Compare commits
426 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0cf17e1c63 | |||
| b459a2f7a9 | |||
| 20bea9cd7f | |||
| a7709d489c | |||
| 18dfc997b8 | |||
| 92d0b6addf | |||
| f305745295 | |||
| fc22586752 | |||
| 646440eba3 | |||
| 53e5579326 | |||
| 29a1630d0f | |||
| 171f4ab2ae | |||
| a5ae071a03 | |||
| b6e2634537 | |||
| 23146c8dae | |||
| 9f424f2fc0 | |||
| 58b60b84fd | |||
| 86aef3319f | |||
| 0015b3d43d | |||
| 9c4d44c057 | |||
| 800c7fbe11 | |||
| 291ba24229 | |||
| ffa4096390 | |||
| f2b6fc6948 | |||
| acff8a0ece | |||
| 347c222f78 | |||
| f58619e378 | |||
| 472cfe1437 | |||
| 8b7efe27c1 | |||
| eb00c10d9b | |||
| 71249f4f88 | |||
| 0beeda3eec | |||
| dc4a40468b | |||
| 7fa2295d30 | |||
| 756f013ecd | |||
| a963d49306 | |||
| 4b00852bdf | |||
| b9b1731dc1 | |||
| 34791e6bbd | |||
| d1ebdfc92f | |||
| 33040b7978 | |||
| 3b6b6c48a5 | |||
| c3fddd3c8c | |||
| 41e5558715 | |||
| 58969085bf | |||
| f45ad2d543 | |||
| 0030d6b499 | |||
| 5f019f44ca | |||
| 0d602f92a3 | |||
| b10d617166 | |||
| 348c646bab | |||
| a8243e6746 | |||
| 9368828f94 | |||
| 51e9a3ecdf | |||
| 2f03605980 | |||
| 74e754b4e1 | |||
| f332e40000 | |||
| d6064147e4 | |||
| 1fb5005bf5 | |||
| 57fbb0479b | |||
| 26154cc648 | |||
| e207cee4ff | |||
| e7a2d957f5 | |||
| 7e5f02eebe | |||
| 248716c093 | |||
| 37a3fce27d | |||
| 7976c1dac7 | |||
| da2bac1b48 | |||
| 4096eba564 | |||
| 3f3a23e4b2 | |||
| 934e3145b8 | |||
| 6155ccbf4d | |||
| 6cadc81be8 | |||
| 412521edb0 | |||
| ec3be40ddd | |||
| fd00471189 | |||
| 94197cbcb9 | |||
| 65c3fcf76d | |||
| 83f77af2ab | |||
| 2fe83187d6 | |||
| e65052c237 | |||
| 38bc7c12ae | |||
| 758c5157b8 | |||
| ce6b47c0d4 | |||
| 22c95b62ce | |||
| 9684311176 | |||
| aa0fff8ac5 | |||
| a1229d8e98 | |||
| ad1b10db63 | |||
| 96308637d6 | |||
| e8a4cc908c | |||
| 3c8ac436bd | |||
| 4d341611a4 | |||
| ef94bfe1fb | |||
| a58b52f420 | |||
| 7852990073 | |||
| 14c9478080 | |||
| c5ebd91651 | |||
| 088f3cc817 | |||
| 50087bb24c | |||
| ca06465305 | |||
| ea719d5441 | |||
| 2627b6e69c | |||
| c869e1955a | |||
| 8293f75152 | |||
| 3ccf4bc383 | |||
| e71d850b79 | |||
| 774911b46c | |||
| 480ade22ce | |||
| bd31323876 | |||
| 2f3b8b27b8 | |||
| d39abf4312 | |||
| ec7058414f | |||
| 8dc63771ca | |||
| 434f1d7298 | |||
| ee0ae20d06 | |||
| a7e16c84a5 | |||
| eaa54d9d4a | |||
| 2c4d034536 | |||
| a43b7c9403 | |||
| 752979da01 | |||
| c4be938b7f | |||
| 3a308ba67e | |||
| cadf401f23 | |||
| 24dd41410a | |||
| 2abf43ed21 | |||
| 2e5ed77909 | |||
| 0ae0bfda83 | |||
| 22007e7aa9 | |||
| 05dde7414f | |||
| 721cfb1ac8 | |||
| 5973168a8c | |||
| 56ed24a092 | |||
| ca031f3ee1 | |||
| 3ee6d98905 | |||
| c9f3de1af6 | |||
| d8d4b9399e | |||
| 30bf1da424 | |||
| 6712fa9a8a | |||
| 2306b13fdc | |||
| 14907a7c6e | |||
| 967cbf814b | |||
| 0dfec38b4b | |||
| 9ad4702c08 | |||
| ec89bf3622 | |||
| ab7c924b9a | |||
| 0c2a2f31f6 | |||
| 2b52ed6397 | |||
| 1b2befaae9 | |||
| bca56f8ff6 | |||
| e9f7f75c34 | |||
| 69cd9ab9f5 | |||
| fa1bba3320 | |||
| 9b23668136 | |||
| bf347d5e78 | |||
| 3cfc88c4d6 | |||
| 031b20574c | |||
| f37448e602 | |||
| ab5b1a254f | |||
| 5d8996fe54 | |||
| 6b30c2e8e7 | |||
| 1298d4b379 | |||
| ac3aaa9348 | |||
| bedc0eadf3 | |||
| fe352ea54e | |||
| 7c990dd90a | |||
| f93111c319 | |||
| d4b2c82d54 | |||
| 169827636f | |||
| a96cd546c8 | |||
| eb33d4f1c2 | |||
| b6ef35fe55 | |||
| 4253956326 | |||
| 6fb84b6889 | |||
| 6e94402a8d | |||
| d68b822687 | |||
| 64299e959a | |||
| d14d23b010 | |||
| 30f1c700ce | |||
| ccae478347 | |||
| 3a2639f565 | |||
| e241ec3341 | |||
| bc6f70933b | |||
| bc070c3e39 | |||
| f30f42a4d3 | |||
| e4c95c7a91 | |||
| bfb1a81b7a | |||
| 257e36615a | |||
| 2a049df099 | |||
| 2194301260 | |||
| 095dd05b17 | |||
| 6d03934452 | |||
| 5051f44543 | |||
| 9d98f9f678 | |||
| 9e0c24cd3a | |||
| b66eec1e66 | |||
| aca66d60ed | |||
| 8316e7c0e9 | |||
| 3bbecad044 | |||
| a8eb7127aa | |||
| ba2889faf8 | |||
| 1e6c5b8e11 | |||
| 1199c02bfd | |||
| 688451b2a9 | |||
| 9ef3628209 | |||
| 8695f3fea0 | |||
| 88b094b5de | |||
| 8b3b0c51f5 | |||
| 322ff7c470 | |||
| ad968a0b54 | |||
| 5d79a7078c | |||
| e4f451e3f5 | |||
| d8496c47f0 | |||
| 9c28284331 | |||
| 075e9179c1 | |||
| e61bdfc417 | |||
| f6c5c5cadb | |||
| 8923011304 | |||
| e6900647f8 | |||
| c441494c2f | |||
| e1bea18357 | |||
| 197f4f984a | |||
| 0381a5c87b | |||
| 112b1baf2e | |||
| c61c958964 | |||
| a59d6ac6db | |||
| 37b9be3ff6 | |||
| 9d39c09e27 | |||
| ff38962ff2 | |||
| 121f33687a | |||
| 598cc8b078 | |||
| 3605f3705b | |||
| 407816ddbf | |||
| 6acdb65c1c | |||
| a4b0c66564 | |||
| d1e6101a0f | |||
| 330fbb19ac | |||
| 8cc431ee52 | |||
| 39831cf4b1 | |||
| bc8cdfd6da | |||
| 500876d65e | |||
| e59bb2d83f | |||
| 03910d531f | |||
| a122345f9c | |||
| 6d025c808a | |||
| 8525aec49c | |||
| b0435a188f | |||
| 3eb964eff2 | |||
| ed88129b00 | |||
| e1d8624483 | |||
| 68264b54d9 | |||
| fc36a5e607 | |||
| 1631d01dd2 | |||
| e846ad6ea7 | |||
| e57cad7159 | |||
| 0cf9e39f6f | |||
| 852332483a | |||
| 2b8604610c | |||
| d6b05bf337 | |||
| b07aff1be3 | |||
| f3df70e8fe | |||
| 9230ac6c20 | |||
| 5cf25c6f10 | |||
| d064c98998 | |||
| 25fabd8068 | |||
| 396e5c35a6 | |||
| 0a8c30c3da | |||
| 798f3cfd36 | |||
| 69ad0be5ff | |||
| 60f2e674ec | |||
| 6bb256e277 | |||
| 81ad85db5e | |||
| ed25ef7562 | |||
| d9c696aa22 | |||
| 22358a2d83 | |||
| 39a2a34380 | |||
| 07077dbb52 | |||
| e1346ae557 | |||
| 4f3d34d01e | |||
| 8516eba7c5 | |||
| 63010d45b2 | |||
| 59db8f99d7 | |||
| 236e8e8638 | |||
| 3279686342 | |||
| b6a77ffd7e | |||
| e0544a57f9 | |||
| 82c32e8d9f | |||
| a180d78d0c | |||
| 9be036aa37 | |||
| 8c39dad22d | |||
| 0a7aa62c45 | |||
| cbd34db278 | |||
| 414d86f2f0 | |||
| 4852d7f63b | |||
| 1165858a58 | |||
| 4575540d69 | |||
| 051aa4f065 | |||
| 6834dcfcb7 | |||
| 95c481ae52 | |||
| 5c266d6920 | |||
| 7fe21d91f2 | |||
| 751715bffe | |||
| a6bda9628c | |||
| ac646603c9 | |||
| 551e648be7 | |||
| 2f852a7eba | |||
| 7d462ff976 | |||
| d1cfef5d8a | |||
| f3c9c591bf | |||
| 0bbe2d5889 | |||
| aa341317f5 | |||
| 6ae38b66ba | |||
| 40e39d29f8 | |||
| 6d7d472792 | |||
| dae63214d5 | |||
| 46bdedcabb | |||
| 5fbaae5d8d | |||
| c9bc2b287e | |||
| 5b46132c81 | |||
| 7e65ab0b36 | |||
| 8a86787b64 | |||
| b2acfb5447 | |||
| 10ea23be34 | |||
| 37a0324c05 | |||
| 837ef2da59 | |||
| e0bc265bb2 | |||
| a39afbea23 | |||
| 7375b26925 | |||
| 3626051b1a | |||
| fbcdaf7c6d | |||
| 6934b331d4 | |||
| 734fe1e4d7 | |||
| d900f38f64 | |||
| 1c78174aaf | |||
| b897e5bdf2 | |||
| 09dd990273 | |||
| af3b8b1b80 | |||
| fc539a5d7b | |||
| d558bf4f60 | |||
| 99efbe03bb | |||
| 5168ed3cd4 | |||
| f614ee7f15 | |||
| 02330653ee | |||
| ae37d9816e | |||
| 7351675795 | |||
| fa5d5057f4 | |||
| 854a867597 | |||
| 35ef467dbe | |||
| 89dbc638e1 | |||
| 4eac1b9e97 | |||
| 80f938a7af | |||
| 2f7cf3bc57 | |||
| 1a7ed9c962 | |||
| 7004fffc08 | |||
| 06535192e6 | |||
| 5923147a71 | |||
| acaa89f584 | |||
| e6af1f64ac | |||
| 53aebd5cea | |||
| d64020e024 | |||
| 975a002796 | |||
| 6e6b83848f | |||
| 3fb255c906 | |||
| cd51d663fb | |||
| 28b0b6206b | |||
| 9859dc65e0 | |||
| 5c2288fbf5 | |||
| 1b47d1cad4 | |||
| 126bbf17c3 | |||
| 995ab8faaf | |||
| 9d1b1ab9d4 | |||
| 7e630b9416 | |||
| 14faca3933 | |||
| e8c9cc65dc | |||
| f0deedb1f8 | |||
| 70693f4824 | |||
| 3ee380d98f | |||
| b9b0c2c844 | |||
| c53acfdf77 | |||
| 08beffea33 | |||
| 7ed5006a70 | |||
| e009de1c9a | |||
| df7b950e6f | |||
| 7f3bc811b0 | |||
| f0c9d4e87f | |||
| 57781c520e | |||
| 05b18fb312 | |||
| 829783749c | |||
| 48b38e5d95 | |||
| 1527a05336 | |||
| 491e6585a4 | |||
| 8333ba6ec2 | |||
| a5fcb89991 | |||
| 3fd8f9f97a | |||
| 2180a60c21 | |||
| f64820a13e | |||
| 073be1f870 | |||
| 86686fc8f9 | |||
| 8fe51a8aa9 | |||
| 715df547bb | |||
| c454870ac8 | |||
| 68766fd131 | |||
| ce39cb7dde | |||
| e1663793c7 | |||
| e2f387965e | |||
| e75253f16a | |||
| 7d416f5421 | |||
| cdbcac68b8 | |||
| d52b6e8e56 | |||
| 510975619d | |||
| 49724b6da0 | |||
| 31b252c018 | |||
| dd2254989f | |||
| 7aa56b905c | |||
| 9f4948edbe | |||
| 2765c9fe93 | |||
| 8f223ee564 | |||
| b0e870d1db | |||
| 17bfbf9732 | |||
| da0c0acdcf | |||
| ea4c56108b | |||
| 73ba72ee52 | |||
| f67e0cc4ae | |||
| 8d4f107f63 | |||
| 7c6c3a8cc2 | |||
| 5930a3c95d |
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"hooks": {
|
||||
"PostToolUse": [
|
||||
{
|
||||
"matcher": "Edit|Write|NotebookEdit",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "ruff check --fix \"$CLAUDE_FILE_PATH\" 2>/dev/null; ruff format \"$CLAUDE_FILE_PATH\" 2>/dev/null; true"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -17,7 +17,24 @@
|
||||
"Bash(ruff check:*)",
|
||||
"Bash(PYTHONPATH=core:exports python:*)",
|
||||
"mcp__agent-builder__list_tests",
|
||||
"mcp__agent-builder__generate_constraint_tests"
|
||||
"mcp__agent-builder__generate_constraint_tests",
|
||||
"Bash(python -m agent:*)",
|
||||
"Bash(python agent.py:*)",
|
||||
"Bash(python -c:*)",
|
||||
"Bash(done)",
|
||||
"Bash(xargs cat:*)",
|
||||
"mcp__agent-builder__list_mcp_tools",
|
||||
"mcp__agent-builder__add_mcp_server",
|
||||
"mcp__agent-builder__check_missing_credentials",
|
||||
"mcp__agent-builder__store_credential",
|
||||
"mcp__agent-builder__list_stored_credentials",
|
||||
"mcp__agent-builder__delete_stored_credential",
|
||||
"mcp__agent-builder__verify_credentials",
|
||||
"Bash(PYTHONPATH=/home/timothy/oss/hive/core:/home/timothy/oss/hive/exports python:*)",
|
||||
"Bash(PYTHONPATH=core:exports:tools/src python -m hubspot_input:*)",
|
||||
"mcp__agent-builder__export_graph"
|
||||
]
|
||||
}
|
||||
},
|
||||
"enabledMcpjsonServers": ["agent-builder", "tools"],
|
||||
"enableAllProjectMcpServers": true
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ metadata:
|
||||
- building-agents-construction
|
||||
- building-agents-patterns
|
||||
- testing-agent
|
||||
- setup-credentials
|
||||
---
|
||||
|
||||
# Agent Development Workflow
|
||||
@@ -21,10 +22,11 @@ Complete Standard Operating Procedure (SOP) for building production-ready goal-d
|
||||
|
||||
This workflow orchestrates specialized skills to take you from initial concept to production-ready agent:
|
||||
|
||||
1. **Understand Concepts** (5-10 min) → `/building-agents-core` (optional)
|
||||
2. **Build Structure** (15-30 min) → `/building-agents-construction`
|
||||
3. **Optimize Design** (10-15 min) → `/building-agents-patterns` (optional)
|
||||
4. **Test & Validate** (20-40 min) → `/testing-agent`
|
||||
1. **Understand Concepts** → `/building-agents-core` (optional)
|
||||
2. **Build Structure** → `/building-agents-construction`
|
||||
3. **Optimize Design** → `/building-agents-patterns` (optional)
|
||||
4. **Setup Credentials** → `/setup-credentials` (if agent uses tools requiring API keys)
|
||||
5. **Test & Validate** → `/testing-agent`
|
||||
|
||||
## When to Use This Workflow
|
||||
|
||||
@@ -44,6 +46,7 @@ Use this meta-skill when:
|
||||
"Need to understand agent concepts" → building-agents-core
|
||||
"Build a new agent" → building-agents-construction
|
||||
"Optimize my agent design" → building-agents-patterns
|
||||
"Set up API keys for my agent" → setup-credentials
|
||||
"Test my agent" → testing-agent
|
||||
"Not sure what I need" → Read phases below, then decide
|
||||
"Agent has structure but needs implementation" → See agent directory STATUS.md
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,80 @@
|
||||
# Online Research Agent
|
||||
|
||||
Deep-dive research agent that searches 10+ sources and produces comprehensive narrative reports with citations.
|
||||
|
||||
## Features
|
||||
|
||||
- Generates multiple search queries from a topic
|
||||
- Searches and fetches 15+ web sources
|
||||
- Evaluates and ranks sources by relevance
|
||||
- Synthesizes findings into themes
|
||||
- Writes narrative report with numbered citations
|
||||
- Quality checks for uncited claims
|
||||
- Saves report to local markdown file
|
||||
|
||||
## Usage
|
||||
|
||||
### CLI
|
||||
|
||||
```bash
|
||||
# Show agent info
|
||||
python -m online_research_agent info
|
||||
|
||||
# Validate structure
|
||||
python -m online_research_agent validate
|
||||
|
||||
# Run research on a topic
|
||||
python -m online_research_agent run --topic "impact of AI on healthcare"
|
||||
|
||||
# Interactive shell
|
||||
python -m online_research_agent shell
|
||||
```
|
||||
|
||||
### Python API
|
||||
|
||||
```python
|
||||
from online_research_agent import default_agent
|
||||
|
||||
# Simple usage
|
||||
result = await default_agent.run({"topic": "climate change solutions"})
|
||||
|
||||
# Check output
|
||||
if result.success:
|
||||
print(f"Report saved to: {result.output['file_path']}")
|
||||
print(result.output['final_report'])
|
||||
```
|
||||
|
||||
## Workflow
|
||||
|
||||
```
|
||||
parse-query → search-sources → fetch-content → evaluate-sources
|
||||
↓
|
||||
write-report ← synthesize-findings
|
||||
↓
|
||||
quality-check → save-report
|
||||
```
|
||||
|
||||
## Output
|
||||
|
||||
Reports are saved to `./research_reports/` as markdown files with:
|
||||
|
||||
1. Executive Summary
|
||||
2. Introduction
|
||||
3. Key Findings (by theme)
|
||||
4. Analysis
|
||||
5. Conclusion
|
||||
6. References
|
||||
|
||||
## Requirements
|
||||
|
||||
- Python 3.11+
|
||||
- LLM provider API key (Groq, Cerebras, etc.)
|
||||
- Internet access for web search/fetch
|
||||
|
||||
## Configuration
|
||||
|
||||
Edit `config.py` to change:
|
||||
|
||||
- `model`: LLM model (default: groq/moonshotai/kimi-k2-instruct-0905)
|
||||
- `temperature`: Generation temperature (default: 0.7)
|
||||
- `max_tokens`: Max tokens per response (default: 16384)
|
||||
+23
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
Online Research Agent - Deep-dive research with narrative reports.
|
||||
|
||||
Research any topic by searching multiple sources, synthesizing information,
|
||||
and producing a well-structured narrative report with citations.
|
||||
"""
|
||||
|
||||
from .agent import OnlineResearchAgent, default_agent, goal, nodes, edges
|
||||
from .config import RuntimeConfig, AgentMetadata, default_config, metadata
|
||||
|
||||
__version__ = "1.0.0"

# Public package API, re-exported from .agent and .config above.
__all__ = [
    "OnlineResearchAgent",
    "default_agent",
    "goal",
    "nodes",
    "edges",
    "RuntimeConfig",
    "AgentMetadata",
    "default_config",
    "metadata",
]
|
||||
+158
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
CLI entry point for Online Research Agent.
|
||||
|
||||
Uses AgentRuntime for multi-entrypoint support with HITL pause/resume.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import click
|
||||
|
||||
from .agent import default_agent, OnlineResearchAgent
|
||||
|
||||
|
||||
def setup_logging(verbose=False, debug=False):
    """Configure process-wide logging for execution visibility.

    debug takes precedence over verbose; the default shows only
    warnings and errors. Output goes to stderr so stdout stays
    reserved for result JSON.
    """
    if debug:
        level = logging.DEBUG
        fmt = "%(asctime)s %(name)s: %(message)s"
    elif verbose:
        level = logging.INFO
        fmt = "%(message)s"
    else:
        level = logging.WARNING
        fmt = "%(levelname)s: %(message)s"
    logging.basicConfig(level=level, format=fmt, stream=sys.stderr)
    # Keep the framework's logger in sync with the chosen verbosity.
    logging.getLogger("framework").setLevel(level)
|
||||
|
||||
|
||||
@click.group()
@click.version_option(version="1.0.0")
def cli():
    """Online Research Agent - Deep-dive research with narrative reports."""
    # Group body intentionally empty; subcommands are registered below.
|
||||
|
||||
|
||||
@cli.command()
@click.option("--topic", "-t", type=str, required=True, help="Research topic")
@click.option("--mock", is_flag=True, help="Run in mock mode")
@click.option("--quiet", "-q", is_flag=True, help="Only output result JSON")
@click.option("--verbose", "-v", is_flag=True, help="Show execution details")
@click.option("--debug", is_flag=True, help="Show debug logging")
def run(topic, mock, quiet, verbose, debug):
    """Execute research on a topic."""
    # Quiet mode skips logging setup so stdout carries only the JSON result.
    if not quiet:
        setup_logging(verbose=verbose, debug=debug)

    result = asyncio.run(default_agent.run({"topic": topic}, mock_mode=mock))

    payload = {
        "success": result.success,
        "steps_executed": result.steps_executed,
        "output": result.output,
    }
    if result.error:
        payload["error"] = result.error

    # default=str keeps non-JSON-serializable values (paths, datetimes) printable.
    click.echo(json.dumps(payload, indent=2, default=str))
    sys.exit(0 if result.success else 1)
|
||||
|
||||
|
||||
@cli.command()
@click.option("--json", "output_json", is_flag=True)
def info(output_json):
    """Show agent information."""
    data = default_agent.info()
    if output_json:
        click.echo(json.dumps(data, indent=2))
        return
    # Human-readable summary of the agent graph.
    click.echo(f"Agent: {data['name']}")
    click.echo(f"Version: {data['version']}")
    click.echo(f"Description: {data['description']}")
    click.echo(f"\nNodes: {', '.join(data['nodes'])}")
    click.echo(f"Entry: {data['entry_node']}")
    click.echo(f"Terminal: {', '.join(data['terminal_nodes'])}")
|
||||
|
||||
|
||||
@cli.command()
def validate():
    """Validate agent structure.

    Exits 0 when the structure is valid, 1 otherwise. Non-fatal
    warnings are printed but do not affect the exit code.
    """
    validation = default_agent.validate()
    # validate() also returns a "warnings" list; surface it instead of
    # silently discarding it.
    for warning in validation.get("warnings", []):
        click.echo(f"  WARNING: {warning}")
    if validation["valid"]:
        click.echo("Agent is valid")
    else:
        click.echo("Agent has errors:")
        for error in validation["errors"]:
            click.echo(f"  ERROR: {error}")
    sys.exit(0 if validation["valid"] else 1)
|
||||
|
||||
|
||||
@cli.command()
@click.option("--verbose", "-v", is_flag=True)
def shell(verbose):
    """Interactive research session."""
    # Delegate to the async REPL loop.
    asyncio.run(_interactive_shell(verbose))
|
||||
|
||||
|
||||
async def _interactive_shell(verbose=False):
    """Async interactive shell.

    Reads topics from stdin in a loop, runs the agent for each one, and
    prints where the report was saved plus a short preview. Exits on
    'quit'/'exit'/'q', Ctrl-C, or end-of-input.
    """
    setup_logging(verbose=verbose)

    click.echo("=== Online Research Agent ===")
    click.echo("Enter a topic to research (or 'quit' to exit):\n")

    agent = OnlineResearchAgent()
    await agent.start()

    try:
        while True:
            try:
                # input() blocks, so run it in a worker thread to keep the
                # event loop (and the agent runtime) responsive.
                # get_event_loop() is deprecated inside coroutines since
                # Python 3.10; use the running loop instead.
                topic = await asyncio.get_running_loop().run_in_executor(
                    None, input, "Topic> "
                )
                if topic.lower() in ["quit", "exit", "q"]:
                    click.echo("Goodbye!")
                    break

                if not topic.strip():
                    continue

                click.echo("\nResearching... (this may take a few minutes)\n")

                result = await agent.trigger_and_wait("start", {"topic": topic})

                if result is None:
                    click.echo("\n[Execution timed out]\n")
                    continue

                if result.success:
                    output = result.output
                    if "file_path" in output:
                        click.echo(f"\nReport saved to: {output['file_path']}\n")
                    if "final_report" in output:
                        report = output["final_report"]
                        click.echo("\n--- Report Preview ---\n")
                        # Truncate long reports to a 500-character preview.
                        preview = report[:500] + "..." if len(report) > 500 else report
                        click.echo(preview)
                        click.echo("\n")
                else:
                    click.echo(f"\nResearch failed: {result.error}\n")

            except (EOFError, KeyboardInterrupt):
                # Ctrl-C or closed stdin (Ctrl-D / piped input exhausted).
                # EOFError previously fell into the generic handler below,
                # which looped forever re-raising on every input() call.
                click.echo("\nGoodbye!")
                break
            except Exception as e:
                click.echo(f"Error: {e}", err=True)
                import traceback

                traceback.print_exc()
    finally:
        await agent.stop()
|
||||
|
||||
|
||||
# Support direct execution (`python __main__.py`) in addition to `python -m`.
if __name__ == "__main__":
    cli()
|
||||
@@ -0,0 +1,429 @@
|
||||
"""Agent graph construction for Online Research Agent."""
|
||||
|
||||
from framework.graph import EdgeSpec, EdgeCondition, Goal, SuccessCriterion, Constraint
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.executor import ExecutionResult
|
||||
from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime
|
||||
from framework.runtime.execution_stream import EntryPointSpec
|
||||
from framework.llm import LiteLLMProvider
|
||||
from framework.runner.tool_registry import ToolRegistry
|
||||
|
||||
from .config import default_config, metadata
|
||||
from .nodes import (
|
||||
parse_query_node,
|
||||
search_sources_node,
|
||||
fetch_content_node,
|
||||
evaluate_sources_node,
|
||||
synthesize_findings_node,
|
||||
write_report_node,
|
||||
quality_check_node,
|
||||
save_report_node,
|
||||
)
|
||||
|
||||
# Goal definition: what the agent is trying to achieve and how success is scored.
goal = Goal(
    id="comprehensive-online-research",
    name="Comprehensive Online Research",
    description="Research any topic by searching multiple sources, synthesizing information, and producing a well-structured narrative report with citations.",
    # Weighted success criteria; weights sum to 1.0.
    success_criteria=[
        SuccessCriterion(
            id="source-coverage",
            description="Query 10+ diverse sources",
            metric="source_count",
            target=">=10",
            weight=0.20,
        ),
        SuccessCriterion(
            id="relevance",
            description="All sources directly address the query",
            metric="relevance_score",
            target="90%",
            weight=0.25,
        ),
        SuccessCriterion(
            id="synthesis",
            description="Synthesize findings into coherent narrative",
            metric="coherence_score",
            target="85%",
            weight=0.25,
        ),
        SuccessCriterion(
            id="citations",
            description="Include citations for all claims",
            metric="citation_coverage",
            target="100%",
            weight=0.15,
        ),
        SuccessCriterion(
            id="actionable",
            description="Report answers the user's question",
            metric="answer_completeness",
            target="90%",
            weight=0.15,
        ),
    ],
    # Hard constraints the agent must respect while pursuing the goal.
    constraints=[
        Constraint(
            id="no-hallucination",
            description="Only include information found in sources",
            constraint_type="quality",
            category="accuracy",
        ),
        Constraint(
            id="source-attribution",
            description="Every factual claim must cite its source",
            constraint_type="quality",
            category="accuracy",
        ),
        Constraint(
            id="recency-preference",
            description="Prefer recent sources when relevant",
            constraint_type="quality",
            category="relevance",
        ),
        Constraint(
            id="no-paywalled",
            description="Avoid sources that require payment to access",
            constraint_type="functional",
            category="accessibility",
        ),
    ],
)

# Node list. Execution order is determined by the edges below, not by this
# list's order. Node objects are defined in .nodes.
nodes = [
    parse_query_node,
    search_sources_node,
    fetch_content_node,
    evaluate_sources_node,
    synthesize_findings_node,
    write_report_node,
    quality_check_node,
    save_report_node,
]

# Edge definitions: a strictly linear ON_SUCCESS pipeline —
# parse-query -> search-sources -> fetch-content -> evaluate-sources
#   -> synthesize-findings -> write-report -> quality-check -> save-report
edges = [
    EdgeSpec(
        id="parse-to-search",
        source="parse-query",
        target="search-sources",
        condition=EdgeCondition.ON_SUCCESS,
        priority=1,
    ),
    EdgeSpec(
        id="search-to-fetch",
        source="search-sources",
        target="fetch-content",
        condition=EdgeCondition.ON_SUCCESS,
        priority=1,
    ),
    EdgeSpec(
        id="fetch-to-evaluate",
        source="fetch-content",
        target="evaluate-sources",
        condition=EdgeCondition.ON_SUCCESS,
        priority=1,
    ),
    EdgeSpec(
        id="evaluate-to-synthesize",
        source="evaluate-sources",
        target="synthesize-findings",
        condition=EdgeCondition.ON_SUCCESS,
        priority=1,
    ),
    EdgeSpec(
        id="synthesize-to-write",
        source="synthesize-findings",
        target="write-report",
        condition=EdgeCondition.ON_SUCCESS,
        priority=1,
    ),
    EdgeSpec(
        id="write-to-quality",
        source="write-report",
        target="quality-check",
        condition=EdgeCondition.ON_SUCCESS,
        priority=1,
    ),
    EdgeSpec(
        id="quality-to-save",
        source="quality-check",
        target="save-report",
        condition=EdgeCondition.ON_SUCCESS,
        priority=1,
    ),
]

# Graph configuration.
entry_node = "parse-query"  # first node executed
entry_points = {"start": "parse-query"}  # named entry points for AgentRuntime
pause_nodes = []  # no human-in-the-loop pause points in this graph
terminal_nodes = ["save-report"]  # execution completes after saving the report
|
||||
|
||||
|
||||
class OnlineResearchAgent:
|
||||
"""
|
||||
Online Research Agent - Deep-dive research with narrative reports.
|
||||
|
||||
Uses AgentRuntime for multi-entrypoint support with HITL pause/resume.
|
||||
"""
|
||||
|
||||
def __init__(self, config=None):
|
||||
self.config = config or default_config
|
||||
self.goal = goal
|
||||
self.nodes = nodes
|
||||
self.edges = edges
|
||||
self.entry_node = entry_node
|
||||
self.entry_points = entry_points
|
||||
self.pause_nodes = pause_nodes
|
||||
self.terminal_nodes = terminal_nodes
|
||||
self._runtime: AgentRuntime | None = None
|
||||
self._graph: GraphSpec | None = None
|
||||
|
||||
def _build_entry_point_specs(self) -> list[EntryPointSpec]:
|
||||
"""Convert entry_points dict to EntryPointSpec list."""
|
||||
specs = []
|
||||
for ep_id, node_id in self.entry_points.items():
|
||||
if ep_id == "start":
|
||||
trigger_type = "manual"
|
||||
name = "Start"
|
||||
elif "_resume" in ep_id:
|
||||
trigger_type = "resume"
|
||||
name = f"Resume from {ep_id.replace('_resume', '')}"
|
||||
else:
|
||||
trigger_type = "manual"
|
||||
name = ep_id.replace("-", " ").title()
|
||||
|
||||
specs.append(
|
||||
EntryPointSpec(
|
||||
id=ep_id,
|
||||
name=name,
|
||||
entry_node=node_id,
|
||||
trigger_type=trigger_type,
|
||||
isolation_level="shared",
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
def _create_runtime(self, mock_mode=False) -> AgentRuntime:
|
||||
"""Create AgentRuntime instance."""
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
# Persistent storage in ~/.hive for telemetry and run history
|
||||
storage_path = Path.home() / ".hive" / "online_research_agent"
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
# Load MCP servers (always load, needed for tool validation)
|
||||
agent_dir = Path(__file__).parent
|
||||
mcp_config_path = agent_dir / "mcp_servers.json"
|
||||
|
||||
if mcp_config_path.exists():
|
||||
with open(mcp_config_path) as f:
|
||||
mcp_servers = json.load(f)
|
||||
|
||||
for server_config in mcp_servers.get("servers", []):
|
||||
# Resolve relative cwd paths
|
||||
cwd = server_config.get("cwd")
|
||||
if cwd and not Path(cwd).is_absolute():
|
||||
server_config["cwd"] = str(agent_dir / cwd)
|
||||
tool_registry.register_mcp_server(server_config)
|
||||
|
||||
llm = None
|
||||
if not mock_mode:
|
||||
# LiteLLMProvider uses environment variables for API keys
|
||||
llm = LiteLLMProvider(
|
||||
model=self.config.model,
|
||||
api_key=self.config.api_key,
|
||||
api_base=self.config.api_base,
|
||||
)
|
||||
|
||||
self._graph = GraphSpec(
|
||||
id="online-research-agent-graph",
|
||||
goal_id=self.goal.id,
|
||||
version="1.0.0",
|
||||
entry_node=self.entry_node,
|
||||
entry_points=self.entry_points,
|
||||
terminal_nodes=self.terminal_nodes,
|
||||
pause_nodes=self.pause_nodes,
|
||||
nodes=self.nodes,
|
||||
edges=self.edges,
|
||||
default_model=self.config.model,
|
||||
max_tokens=self.config.max_tokens,
|
||||
)
|
||||
|
||||
# Create AgentRuntime with all entry points
|
||||
self._runtime = create_agent_runtime(
|
||||
graph=self._graph,
|
||||
goal=self.goal,
|
||||
storage_path=storage_path,
|
||||
entry_points=self._build_entry_point_specs(),
|
||||
llm=llm,
|
||||
tools=list(tool_registry.get_tools().values()),
|
||||
tool_executor=tool_registry.get_executor(),
|
||||
)
|
||||
|
||||
return self._runtime
|
||||
|
||||
async def start(self, mock_mode=False) -> None:
|
||||
"""Start the agent runtime."""
|
||||
if self._runtime is None:
|
||||
self._create_runtime(mock_mode=mock_mode)
|
||||
await self._runtime.start()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the agent runtime."""
|
||||
if self._runtime is not None:
|
||||
await self._runtime.stop()
|
||||
|
||||
async def trigger(
|
||||
self,
|
||||
entry_point: str,
|
||||
input_data: dict,
|
||||
correlation_id: str | None = None,
|
||||
session_state: dict | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Trigger execution at a specific entry point (non-blocking).
|
||||
|
||||
Args:
|
||||
entry_point: Entry point ID (e.g., "start", "pause-node_resume")
|
||||
input_data: Input data for the execution
|
||||
correlation_id: Optional ID to correlate related executions
|
||||
session_state: Optional session state to resume from (with paused_at, memory)
|
||||
|
||||
Returns:
|
||||
Execution ID for tracking
|
||||
"""
|
||||
if self._runtime is None or not self._runtime.is_running:
|
||||
raise RuntimeError("Agent runtime not started. Call start() first.")
|
||||
return await self._runtime.trigger(
|
||||
entry_point, input_data, correlation_id, session_state=session_state
|
||||
)
|
||||
|
||||
async def trigger_and_wait(
|
||||
self,
|
||||
entry_point: str,
|
||||
input_data: dict,
|
||||
timeout: float | None = None,
|
||||
session_state: dict | None = None,
|
||||
) -> ExecutionResult | None:
|
||||
"""
|
||||
Trigger execution and wait for completion.
|
||||
|
||||
Args:
|
||||
entry_point: Entry point ID
|
||||
input_data: Input data for the execution
|
||||
timeout: Maximum time to wait (seconds)
|
||||
session_state: Optional session state to resume from (with paused_at, memory)
|
||||
|
||||
Returns:
|
||||
ExecutionResult or None if timeout
|
||||
"""
|
||||
if self._runtime is None or not self._runtime.is_running:
|
||||
raise RuntimeError("Agent runtime not started. Call start() first.")
|
||||
return await self._runtime.trigger_and_wait(
|
||||
entry_point, input_data, timeout, session_state=session_state
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, context: dict, mock_mode=False, session_state=None
|
||||
) -> ExecutionResult:
|
||||
"""
|
||||
Run the agent (convenience method for simple single execution).
|
||||
|
||||
For more control, use start() + trigger_and_wait() + stop().
|
||||
"""
|
||||
await self.start(mock_mode=mock_mode)
|
||||
try:
|
||||
# Determine entry point based on session_state
|
||||
if session_state and "paused_at" in session_state:
|
||||
paused_node = session_state["paused_at"]
|
||||
resume_key = f"{paused_node}_resume"
|
||||
if resume_key in self.entry_points:
|
||||
entry_point = resume_key
|
||||
else:
|
||||
entry_point = "start"
|
||||
else:
|
||||
entry_point = "start"
|
||||
|
||||
result = await self.trigger_and_wait(
|
||||
entry_point, context, session_state=session_state
|
||||
)
|
||||
return result or ExecutionResult(success=False, error="Execution timeout")
|
||||
finally:
|
||||
await self.stop()
|
||||
|
||||
async def get_goal_progress(self) -> dict:
|
||||
"""Get goal progress across all executions."""
|
||||
if self._runtime is None:
|
||||
raise RuntimeError("Agent runtime not started")
|
||||
return await self._runtime.get_goal_progress()
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get runtime statistics."""
|
||||
if self._runtime is None:
|
||||
return {"running": False}
|
||||
return self._runtime.get_stats()
|
||||
|
||||
def info(self):
|
||||
"""Get agent information."""
|
||||
return {
|
||||
"name": metadata.name,
|
||||
"version": metadata.version,
|
||||
"description": metadata.description,
|
||||
"goal": {
|
||||
"name": self.goal.name,
|
||||
"description": self.goal.description,
|
||||
},
|
||||
"nodes": [n.id for n in self.nodes],
|
||||
"edges": [e.id for e in self.edges],
|
||||
"entry_node": self.entry_node,
|
||||
"entry_points": self.entry_points,
|
||||
"pause_nodes": self.pause_nodes,
|
||||
"terminal_nodes": self.terminal_nodes,
|
||||
"multi_entrypoint": True,
|
||||
}
|
||||
|
||||
def validate(self):
    """Validate agent structure.

    Checks that every edge endpoint, the entry node, all terminal and
    pause nodes, and every entry-point target refer to known node ids.
    Returns {"valid": bool, "errors": [...], "warnings": [...]}.
    """
    errors = []
    known = {node.id for node in self.nodes}

    # Edge endpoints: report source then target per edge, in edge order.
    for edge in self.edges:
        if edge.source not in known:
            errors.append(f"Edge {edge.id}: source '{edge.source}' not found")
        if edge.target not in known:
            errors.append(f"Edge {edge.id}: target '{edge.target}' not found")

    if self.entry_node not in known:
        errors.append(f"Entry node '{self.entry_node}' not found")

    errors.extend(
        f"Terminal node '{terminal}' not found"
        for terminal in self.terminal_nodes
        if terminal not in known
    )
    errors.extend(
        f"Pause node '{pause}' not found"
        for pause in self.pause_nodes
        if pause not in known
    )
    errors.extend(
        f"Entry point '{ep_id}' references unknown node '{node_id}'"
        for ep_id, node_id in self.entry_points.items()
        if node_id not in known
    )

    return {
        "valid": not errors,
        "errors": errors,
        # No warning conditions are currently detected; kept for API shape.
        "warnings": [],
    }
|
||||
|
||||
|
||||
# Create default instance
# Module-level agent instance — presumably what external code imports to run
# the agent without constructing one; confirm against the runner's loader.
default_agent = OnlineResearchAgent()
|
||||
@@ -0,0 +1,43 @@
|
||||
"""Runtime configuration."""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _load_preferred_model() -> str:
|
||||
"""Load preferred model from ~/.hive/configuration.json."""
|
||||
config_path = Path.home() / ".hive" / "configuration.json"
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
llm = config.get("llm", {})
|
||||
if llm.get("provider") and llm.get("model"):
|
||||
return f"{llm['provider']}/{llm['model']}"
|
||||
except Exception:
|
||||
pass
|
||||
return "anthropic/claude-sonnet-4-20250514"
|
||||
|
||||
|
||||
@dataclass
class RuntimeConfig:
    """LLM runtime settings used when executing the agent."""

    # "provider/model" identifier; defaults to the user's preference from
    # ~/.hive/configuration.json (see _load_preferred_model).
    model: str = field(default_factory=_load_preferred_model)
    temperature: float = 0.7  # sampling temperature
    max_tokens: int = 8192  # maximum completion tokens per request
    api_key: str | None = None  # optional explicit API key — resolution order not shown here
    api_base: str | None = None  # optional custom API endpoint override
|
||||
|
||||
|
||||
# Shared module-level config; the model preference is read at import time.
default_config = RuntimeConfig()
|
||||
|
||||
|
||||
# Agent metadata
@dataclass
class AgentMetadata:
    """Static descriptive metadata for the Online Research Agent."""

    name: str = "Online Research Agent"
    version: str = "1.0.0"
    description: str = "Research any topic by searching multiple sources, synthesizing information, and producing a well-structured narrative report with citations."


# Shared module-level metadata instance.
metadata = AgentMetadata()
|
||||
+9
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"hive-tools": {
|
||||
"transport": "stdio",
|
||||
"command": "python",
|
||||
"args": ["mcp_server.py", "--stdio"],
|
||||
"cwd": "../../tools",
|
||||
"description": "Hive tools MCP server providing web_search, web_scrape, and write_to_file"
|
||||
}
|
||||
}
|
||||
+396
@@ -0,0 +1,396 @@
|
||||
"""Node definitions for Online Research Agent."""
|
||||
|
||||
from framework.graph import NodeSpec
|
||||
|
||||
# Node 1: Parse Query
# Pipeline stage 1 (llm_generate, no tools): consumes "topic" and emits a
# structured search plan (research_focus, key_aspects, search_queries).
parse_query_node = NodeSpec(
    id="parse-query",
    name="Parse Query",
    description="Analyze the research topic and generate 3-5 diverse search queries to cover different aspects",
    node_type="llm_generate",
    input_keys=["topic"],
    output_keys=["search_queries", "research_focus", "key_aspects"],
    output_schema={
        "research_focus": {
            "type": "string",
            "required": True,
            "description": "Brief statement of what we're researching",
        },
        "key_aspects": {
            "type": "array",
            "required": True,
            "description": "List of 3-5 key aspects to investigate",
        },
        "search_queries": {
            "type": "array",
            "required": True,
            "description": "List of 3-5 search queries",
        },
    },
    system_prompt="""\
You are a research query strategist. Given a research topic, analyze it and generate search queries.

Your task:
1. Understand the core research question
2. Identify 3-5 key aspects to investigate
3. Generate 3-5 diverse search queries that will find comprehensive information

CRITICAL: Return ONLY raw JSON. NO markdown, NO code blocks.

Return this JSON structure:
{
  "research_focus": "Brief statement of what we're researching",
  "key_aspects": ["aspect1", "aspect2", "aspect3"],
  "search_queries": [
    "query 1 - broad overview",
    "query 2 - specific angle",
    "query 3 - recent developments",
    "query 4 - expert opinions",
    "query 5 - data/statistics"
  ]
}
""",
    tools=[],  # pure generation; no tool access
    max_retries=3,  # retry budget — presumably for invalid/unparseable output; confirm in runtime
)
|
||||
|
||||
# Node 2: Search Sources
# Pipeline stage 2 (llm_tool_use): runs the web_search tool over the planned
# queries and collects candidate source URLs.
search_sources_node = NodeSpec(
    id="search-sources",
    name="Search Sources",
    description="Execute web searches using the generated queries to find 15+ source URLs",
    node_type="llm_tool_use",
    input_keys=["search_queries", "research_focus"],
    output_keys=["source_urls", "search_results_summary"],
    output_schema={
        "source_urls": {
            "type": "array",
            "required": True,
            "description": "List of source URLs found",
        },
        "search_results_summary": {
            "type": "string",
            "required": True,
            "description": "Brief summary of what was found",
        },
    },
    system_prompt="""\
You are a research assistant executing web searches. Use the web_search tool to find sources.

Your task:
1. Execute each search query using web_search tool
2. Collect URLs from search results
3. Aim for 15+ diverse sources

After searching, return JSON with found sources:
{
  "source_urls": ["url1", "url2", ...],
  "search_results_summary": "Brief summary of what was found"
}
""",
    tools=["web_search"],
    max_retries=3,
)
|
||||
|
||||
# Node 3: Fetch Content
# Pipeline stage 3 (llm_tool_use): scrapes each discovered URL via web_scrape,
# separating successful fetches from failures.
fetch_content_node = NodeSpec(
    id="fetch-content",
    name="Fetch Content",
    description="Fetch and extract content from the discovered source URLs",
    node_type="llm_tool_use",
    input_keys=["source_urls", "research_focus"],
    output_keys=["fetched_sources", "fetch_errors"],
    output_schema={
        "fetched_sources": {
            "type": "array",
            "required": True,
            "description": "List of fetched source objects with url, title, content",
        },
        "fetch_errors": {
            "type": "array",
            "required": True,
            "description": "List of URLs that failed to fetch",
        },
    },
    system_prompt="""\
You are a content fetcher. Use web_scrape tool to retrieve content from URLs.

Your task:
1. Fetch content from each source URL using web_scrape tool
2. Extract the main content relevant to the research focus
3. Track any URLs that failed to fetch

After fetching, return JSON:
{
  "fetched_sources": [
    {"url": "...", "title": "...", "content": "extracted text..."},
    ...
  ],
  "fetch_errors": ["url that failed", ...]
}
""",
    tools=["web_scrape"],
    max_retries=3,
)
|
||||
|
||||
# Node 4: Evaluate Sources
# Pipeline stage 4 (llm_generate, no tools): scores fetched sources and
# filters/ranks them down to the top 10.
evaluate_sources_node = NodeSpec(
    id="evaluate-sources",
    name="Evaluate Sources",
    description="Score sources for relevance and quality, filter to top 10",
    node_type="llm_generate",
    input_keys=["fetched_sources", "research_focus", "key_aspects"],
    output_keys=["ranked_sources", "source_analysis"],
    output_schema={
        "ranked_sources": {
            "type": "array",
            "required": True,
            "description": "List of ranked sources with scores",
        },
        "source_analysis": {
            "type": "string",
            "required": True,
            "description": "Overview of source quality and coverage",
        },
    },
    system_prompt="""\
You are a source evaluator. Assess each source for quality and relevance.

Scoring criteria:
- Relevance to research focus (1-10)
- Source credibility (1-10)
- Information depth (1-10)
- Recency if relevant (1-10)

Your task:
1. Score each source
2. Rank by combined score
3. Select top 10 sources
4. Note what each source uniquely contributes

Return JSON:
{
  "ranked_sources": [
    {"url": "...", "title": "...", "content": "...", "score": 8.5, "unique_value": "..."},
    ...
  ],
  "source_analysis": "Overview of source quality and coverage"
}
""",
    tools=[],
    max_retries=3,
)
|
||||
|
||||
# Node 5: Synthesize Findings
# Pipeline stage 5 (llm_generate, no tools): distills the ranked sources into
# findings, cross-source themes, and a fact -> URL citation map.
synthesize_findings_node = NodeSpec(
    id="synthesize-findings",
    name="Synthesize Findings",
    description="Extract key facts from sources and identify common themes",
    node_type="llm_generate",
    input_keys=["ranked_sources", "research_focus", "key_aspects"],
    output_keys=["key_findings", "themes", "source_citations"],
    output_schema={
        "key_findings": {
            "type": "array",
            "required": True,
            "description": "List of key findings with sources and confidence",
        },
        "themes": {
            "type": "array",
            "required": True,
            "description": "List of themes with descriptions and supporting sources",
        },
        "source_citations": {
            "type": "object",
            "required": True,
            "description": "Map of facts to supporting URLs",
        },
    },
    system_prompt="""\
You are a research synthesizer. Analyze multiple sources to extract insights.

Your task:
1. Identify key facts from each source
2. Find common themes across sources
3. Note contradictions or debates
4. Build a citation map (fact -> source URL)

Return JSON:
{
  "key_findings": [
    {"finding": "...", "sources": ["url1", "url2"], "confidence": "high/medium/low"},
    ...
  ],
  "themes": [
    {"theme": "...", "description": "...", "supporting_sources": ["url1", ...]},
    ...
  ],
  "source_citations": {
    "fact or claim": ["supporting url1", "url2"],
    ...
  }
}
""",
    tools=[],
    max_retries=3,
)
|
||||
|
||||
# Node 6: Write Report
# Pipeline stage 6 (llm_generate, no tools): turns findings + themes + the
# citation map into a full markdown report with a numbered reference list.
write_report_node = NodeSpec(
    id="write-report",
    name="Write Report",
    description="Generate a narrative report with proper citations",
    node_type="llm_generate",
    input_keys=[
        "key_findings",
        "themes",
        "source_citations",
        "research_focus",
        "ranked_sources",
    ],
    output_keys=["report_content", "references"],
    output_schema={
        "report_content": {
            "type": "string",
            "required": True,
            "description": "Full markdown report text with citations",
        },
        "references": {
            "type": "array",
            "required": True,
            "description": "List of reference objects with number, url, title",
        },
    },
    system_prompt="""\
You are a research report writer. Create a well-structured narrative report.

Report structure:
1. Executive Summary (2-3 paragraphs)
2. Introduction (context and scope)
3. Key Findings (organized by theme)
4. Analysis (synthesis and implications)
5. Conclusion
6. References (numbered list of all sources)

Citation format: Use numbered citations like [1], [2] that correspond to the References section.

IMPORTANT:
- Every factual claim MUST have a citation
- Write in clear, professional prose
- Be objective and balanced
- Highlight areas of consensus and debate

Return JSON:
{
  "report_content": "Full markdown report text with citations...",
  "references": [
    {"number": 1, "url": "...", "title": "..."},
    ...
  ]
}
""",
    tools=[],
    max_retries=3,
)
|
||||
|
||||
# Node 7: Quality Check
# Pipeline stage 7 (llm_generate, no tools): audits the draft report for
# citation and coherence problems and emits a corrected final_report.
quality_check_node = NodeSpec(
    id="quality-check",
    name="Quality Check",
    description="Verify all claims have citations and report is coherent",
    node_type="llm_generate",
    input_keys=["report_content", "references", "source_citations"],
    output_keys=["quality_score", "issues", "final_report"],
    output_schema={
        "quality_score": {
            "type": "number",
            "required": True,
            "description": "Quality score 0-1",
        },
        "issues": {
            "type": "array",
            "required": True,
            "description": "List of issues found and fixed",
        },
        "final_report": {
            "type": "string",
            "required": True,
            "description": "Corrected full report",
        },
    },
    system_prompt="""\
You are a quality assurance reviewer. Check the research report for issues.

Check for:
1. Uncited claims (factual statements without [n] citation)
2. Broken citations (references to non-existent numbers)
3. Coherence (logical flow between sections)
4. Completeness (all key aspects covered)
5. Accuracy (claims match source content)

If issues found, fix them in the final report.

Return JSON:
{
  "quality_score": 0.95,
  "issues": [
    {"type": "uncited_claim", "location": "paragraph 3", "fixed": true},
    ...
  ],
  "final_report": "Corrected full report with all issues fixed..."
}
""",
    tools=[],
    max_retries=3,
)
|
||||
|
||||
# Node 8: Save Report
# Pipeline stage 8 (llm_tool_use): persists the final report to
# ./research_reports/ via the write_to_file tool.
save_report_node = NodeSpec(
    id="save-report",
    name="Save Report",
    description="Write the final report to a local markdown file",
    node_type="llm_tool_use",
    input_keys=["final_report", "references", "research_focus"],
    output_keys=["file_path", "save_status"],
    output_schema={
        "file_path": {
            "type": "string",
            "required": True,
            "description": "Path where report was saved",
        },
        "save_status": {
            "type": "string",
            "required": True,
            "description": "Status of save operation",
        },
    },
    system_prompt="""\
You are a file manager. Save the research report to disk.

Your task:
1. Generate a filename from the research focus (slugified, with date)
2. Use the write_to_file tool to save the report as markdown
3. Save to the ./research_reports/ directory

Filename format: research_YYYY-MM-DD_topic-slug.md

Return JSON:
{
  "file_path": "research_reports/research_2026-01-23_topic-name.md",
  "save_status": "success"
}
""",
    tools=["write_to_file"],
    max_retries=3,
)
|
||||
|
||||
# Public API: the eight pipeline nodes, listed in execution order (1-8).
__all__ = [
    "parse_query_node",
    "search_sources_node",
    "fetch_content_node",
    "evaluate_sources_node",
    "synthesize_findings_node",
    "write_report_node",
    "quality_check_node",
    "save_report_node",
]
|
||||
@@ -0,0 +1,572 @@
|
||||
---
|
||||
name: setup-credentials
|
||||
description: Set up and install credentials for an agent. Detects missing credentials from agent config, collects them from the user, and stores them securely in the encrypted credential store at ~/.hive/credentials.
|
||||
license: Apache-2.0
|
||||
metadata:
|
||||
author: hive
|
||||
version: "2.1"
|
||||
type: utility
|
||||
---
|
||||
|
||||
# Setup Credentials
|
||||
|
||||
Interactive credential setup for agents with multiple authentication options. Detects what's missing, offers auth method choices, validates with health checks, and stores credentials securely.
|
||||
|
||||
## When to Use
|
||||
|
||||
- Before running or testing an agent for the first time
|
||||
- When `AgentRunner.run()` fails with "missing required credentials"
|
||||
- When a user asks to configure credentials for an agent
|
||||
- After building a new agent that uses tools requiring API keys
|
||||
|
||||
## Workflow
|
||||
|
||||
### Step 1: Identify the Agent
|
||||
|
||||
Determine which agent needs credentials. The user will either:
|
||||
|
||||
- Name the agent directly (e.g., "set up credentials for hubspot-agent")
|
||||
- Have an agent directory open (check `exports/` for agent dirs)
|
||||
- Be working on an agent in the current session
|
||||
|
||||
Locate the agent's directory under `exports/{agent_name}/`.
|
||||
|
||||
### Step 2: Detect Required Credentials
|
||||
|
||||
Read the agent's configuration to determine which tools and node types it uses:
|
||||
|
||||
```python
|
||||
from core.framework.runner import AgentRunner
|
||||
|
||||
runner = AgentRunner.load("exports/{agent_name}")
|
||||
validation = runner.validate()
|
||||
|
||||
# validation.missing_credentials contains env var names
|
||||
# validation.warnings contains detailed messages with help URLs
|
||||
```
|
||||
|
||||
Alternatively, check the credential store directly:
|
||||
|
||||
```python
|
||||
from core.framework.credentials import CredentialStore
|
||||
|
||||
# Use encrypted storage (default: ~/.hive/credentials)
|
||||
store = CredentialStore.with_encrypted_storage()
|
||||
|
||||
# Check what's available
|
||||
available = store.list_credentials()
|
||||
print(f"Available credentials: {available}")
|
||||
|
||||
# Check if specific credential exists
|
||||
if store.is_available("hubspot"):
|
||||
print("HubSpot credential found")
|
||||
else:
|
||||
print("HubSpot credential missing")
|
||||
```
|
||||
|
||||
To see all known credential specs (for help URLs and setup instructions):
|
||||
|
||||
```python
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
|
||||
for name, spec in CREDENTIAL_SPECS.items():
|
||||
print(f"{name}: env_var={spec.env_var}, aden={spec.aden_supported}")
|
||||
```
|
||||
|
||||
### Step 3: Present Auth Options for Each Missing Credential
|
||||
|
||||
For each missing credential, check what authentication methods are available:
|
||||
|
||||
```python
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
|
||||
spec = CREDENTIAL_SPECS.get("hubspot")
|
||||
if spec:
|
||||
# Determine available auth options
|
||||
auth_options = []
|
||||
if spec.aden_supported:
|
||||
auth_options.append("aden")
|
||||
if spec.direct_api_key_supported:
|
||||
auth_options.append("direct")
|
||||
auth_options.append("custom") # Always available
|
||||
|
||||
# Get setup info
|
||||
setup_info = {
|
||||
"env_var": spec.env_var,
|
||||
"description": spec.description,
|
||||
"help_url": spec.help_url,
|
||||
"api_key_instructions": spec.api_key_instructions,
|
||||
}
|
||||
```
|
||||
|
||||
Present the available options using AskUserQuestion:
|
||||
|
||||
```
|
||||
Choose how to configure HUBSPOT_ACCESS_TOKEN:
|
||||
|
||||
1) Aden Authorization Server (Recommended)
|
||||
Secure OAuth2 flow via integration.adenhq.com
|
||||
- Quick setup with automatic token refresh
|
||||
- No need to manage API keys manually
|
||||
|
||||
2) Direct API Key
|
||||
Enter your own API key manually
|
||||
- Requires creating a HubSpot Private App
|
||||
- Full control over scopes and permissions
|
||||
|
||||
3) Custom Credential Store (Advanced)
|
||||
Programmatic configuration for CI/CD
|
||||
- For automated deployments
|
||||
- Requires manual API calls
|
||||
```
|
||||
|
||||
### Step 4: Execute Auth Flow Based on User Choice
|
||||
|
||||
#### Option 1: Aden Authorization Server
|
||||
|
||||
This is the recommended flow for supported integrations (HubSpot, etc.).
|
||||
|
||||
**How Aden OAuth Works:**
|
||||
|
||||
The ADEN_API_KEY represents a user who has already completed OAuth authorization on Aden's platform. When users sign up and connect integrations on Aden, those OAuth tokens are stored server-side. Having an ADEN_API_KEY means:
|
||||
|
||||
1. User has an Aden account
|
||||
2. User has already authorized integrations (HubSpot, etc.) via OAuth on Aden
|
||||
3. We just need to sync those credentials down to the local credential store
|
||||
|
||||
**4.1a. Check for ADEN_API_KEY**
|
||||
|
||||
```python
|
||||
import os
|
||||
aden_key = os.environ.get("ADEN_API_KEY")
|
||||
```
|
||||
|
||||
If not set, guide user to get one from Aden (this is where they do OAuth):
|
||||
|
||||
```python
|
||||
from aden_tools.credentials import open_browser, get_aden_setup_url
|
||||
|
||||
# Open browser to Aden - user will sign up and connect integrations there
|
||||
url = get_aden_setup_url() # https://integration.adenhq.com/setup
|
||||
success, msg = open_browser(url)
|
||||
|
||||
print("Please sign in to Aden and connect your integrations (HubSpot, etc.).")
|
||||
print("Once done, copy your API key and return here.")
|
||||
```
|
||||
|
||||
Ask user to provide the ADEN_API_KEY they received.
|
||||
|
||||
**4.1b. Save ADEN_API_KEY to Shell Config**
|
||||
|
||||
With user approval, persist ADEN_API_KEY to their shell config:
|
||||
|
||||
```python
|
||||
from aden_tools.credentials import (
|
||||
detect_shell,
|
||||
add_env_var_to_shell_config,
|
||||
get_shell_source_command,
|
||||
)
|
||||
|
||||
shell_type = detect_shell() # 'bash', 'zsh', or 'unknown'
|
||||
|
||||
# Ask user for approval before modifying shell config
|
||||
# If approved:
|
||||
success, config_path = add_env_var_to_shell_config(
|
||||
"ADEN_API_KEY",
|
||||
user_provided_key,
|
||||
comment="Aden authorization server API key"
|
||||
)
|
||||
|
||||
if success:
|
||||
source_cmd = get_shell_source_command()
|
||||
print(f"Saved to {config_path}")
|
||||
print(f"Run: {source_cmd}")
|
||||
```
|
||||
|
||||
Also save to `~/.hive/configuration.json` for the framework:
|
||||
|
||||
```python
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
config_path = Path.home() / ".hive" / "configuration.json"
|
||||
config = json.loads(config_path.read_text()) if config_path.exists() else {}
|
||||
|
||||
config["aden"] = {
|
||||
"api_key_configured": True,
|
||||
"api_url": "https://api.adenhq.com"
|
||||
}
|
||||
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
config_path.write_text(json.dumps(config, indent=2))
|
||||
```
|
||||
|
||||
**4.1c. Sync Credentials from Aden Server**
|
||||
|
||||
Since the user has already authorized integrations on Aden, use the one-liner factory method:
|
||||
|
||||
```python
|
||||
from core.framework.credentials import CredentialStore
|
||||
|
||||
# This single call handles everything:
|
||||
# - Creates encrypted local storage at ~/.hive/credentials
|
||||
# - Configures Aden client from ADEN_API_KEY env var
|
||||
# - Syncs all credentials from Aden server automatically
|
||||
store = CredentialStore.with_aden_sync(
|
||||
base_url="https://api.adenhq.com",
|
||||
auto_sync=True, # Syncs on creation
|
||||
)
|
||||
|
||||
# Check what was synced
|
||||
synced = store.list_credentials()
|
||||
print(f"Synced credentials: {synced}")
|
||||
|
||||
# If the required credential wasn't synced, the user hasn't authorized it on Aden yet
|
||||
if "hubspot" not in synced:
|
||||
print("HubSpot not found in your Aden account.")
|
||||
print("Please visit https://integration.adenhq.com to connect HubSpot, then try again.")
|
||||
```
|
||||
|
||||
For more control over the sync process:
|
||||
|
||||
```python
|
||||
from core.framework.credentials import CredentialStore
|
||||
from core.framework.credentials.aden import (
|
||||
AdenCredentialClient,
|
||||
AdenClientConfig,
|
||||
AdenSyncProvider,
|
||||
)
|
||||
|
||||
# Create client (API key loaded from ADEN_API_KEY env var)
|
||||
client = AdenCredentialClient(AdenClientConfig(
|
||||
base_url="https://api.adenhq.com",
|
||||
))
|
||||
|
||||
# Create provider and store
|
||||
provider = AdenSyncProvider(client=client)
|
||||
store = CredentialStore.with_encrypted_storage()
|
||||
|
||||
# Manual sync
|
||||
synced_count = provider.sync_all(store)
|
||||
print(f"Synced {synced_count} credentials from Aden")
|
||||
```
|
||||
|
||||
**4.1d. Run Health Check**
|
||||
|
||||
```python
|
||||
from aden_tools.credentials import check_credential_health
|
||||
|
||||
# Get the token from the store
|
||||
cred = store.get_credential("hubspot")
|
||||
token = cred.keys["access_token"].value.get_secret_value()
|
||||
|
||||
result = check_credential_health("hubspot", token)
|
||||
if result.valid:
|
||||
print("HubSpot credentials validated successfully!")
|
||||
else:
|
||||
print(f"Validation failed: {result.message}")
|
||||
# Offer to retry the OAuth flow
|
||||
```
|
||||
|
||||
#### Option 2: Direct API Key
|
||||
|
||||
For users who prefer manual API key management.
|
||||
|
||||
**4.2a. Show Setup Instructions**
|
||||
|
||||
```python
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
|
||||
spec = CREDENTIAL_SPECS.get("hubspot")
|
||||
if spec and spec.api_key_instructions:
|
||||
print(spec.api_key_instructions)
|
||||
# Output:
|
||||
# To get a HubSpot Private App token:
|
||||
# 1. Go to HubSpot Settings > Integrations > Private Apps
|
||||
# 2. Click "Create a private app"
|
||||
# 3. Name your app (e.g., "Hive Agent")
|
||||
# ...
|
||||
|
||||
if spec and spec.help_url:
|
||||
print(f"More info: {spec.help_url}")
|
||||
```
|
||||
|
||||
**4.2b. Collect API Key from User**
|
||||
|
||||
Use AskUserQuestion to securely collect the API key:
|
||||
|
||||
```
|
||||
Please provide your HubSpot access token:
|
||||
(This will be stored securely in ~/.hive/credentials)
|
||||
```
|
||||
|
||||
**4.2c. Run Health Check Before Storing**
|
||||
|
||||
```python
|
||||
from aden_tools.credentials import check_credential_health
|
||||
|
||||
result = check_credential_health("hubspot", user_provided_token)
|
||||
if not result.valid:
|
||||
print(f"Warning: {result.message}")
|
||||
# Ask user if they want to:
|
||||
# 1. Try a different token
|
||||
# 2. Continue anyway (not recommended)
|
||||
```
|
||||
|
||||
**4.2d. Store in Encrypted Credential Store**
|
||||
|
||||
```python
|
||||
from core.framework.credentials import CredentialStore, CredentialObject, CredentialKey
|
||||
from pydantic import SecretStr
|
||||
|
||||
store = CredentialStore.with_encrypted_storage()
|
||||
|
||||
cred = CredentialObject(
|
||||
id="hubspot",
|
||||
name="HubSpot Access Token",
|
||||
keys={
|
||||
"access_token": CredentialKey(
|
||||
name="access_token",
|
||||
value=SecretStr(user_provided_token),
|
||||
)
|
||||
},
|
||||
)
|
||||
store.save_credential(cred)
|
||||
```
|
||||
|
||||
**4.2e. Export to Current Session**
|
||||
|
||||
```bash
|
||||
export HUBSPOT_ACCESS_TOKEN="the-value"
|
||||
```
|
||||
|
||||
#### Option 3: Custom Credential Store (Advanced)
|
||||
|
||||
For programmatic/CI/CD setups.
|
||||
|
||||
**4.3a. Show Documentation**
|
||||
|
||||
```
|
||||
For advanced credential management, you can use the CredentialStore API directly:
|
||||
|
||||
from core.framework.credentials import CredentialStore, CredentialObject, CredentialKey
|
||||
from pydantic import SecretStr
|
||||
|
||||
store = CredentialStore.with_encrypted_storage()
|
||||
|
||||
cred = CredentialObject(
|
||||
id="hubspot",
|
||||
name="HubSpot Access Token",
|
||||
keys={"access_token": CredentialKey(name="access_token", value=SecretStr("..."))}
|
||||
)
|
||||
store.save_credential(cred)
|
||||
|
||||
For CI/CD environments:
|
||||
- Set HIVE_CREDENTIAL_KEY for encryption
|
||||
- Pre-populate ~/.hive/credentials programmatically
|
||||
- Or use environment variables directly (HUBSPOT_ACCESS_TOKEN)
|
||||
|
||||
Documentation: See core/framework/credentials/README.md
|
||||
```
|
||||
|
||||
### Step 5: Record Configuration Method
|
||||
|
||||
Track which auth method was used for each credential in `~/.hive/configuration.json`:
|
||||
|
||||
```python
|
||||
import json
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
config_path = Path.home() / ".hive" / "configuration.json"
|
||||
config = json.loads(config_path.read_text()) if config_path.exists() else {}
|
||||
|
||||
if "credential_methods" not in config:
|
||||
config["credential_methods"] = {}
|
||||
|
||||
config["credential_methods"]["hubspot"] = {
|
||||
"method": "aden", # or "direct" or "custom"
|
||||
"configured_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
config_path.write_text(json.dumps(config, indent=2))
|
||||
```
|
||||
|
||||
### Step 6: Verify All Credentials
|
||||
|
||||
Run validation again to confirm everything is set:
|
||||
|
||||
```python
|
||||
runner = AgentRunner.load("exports/{agent_name}")
|
||||
validation = runner.validate()
|
||||
assert not validation.missing_credentials, "Still missing credentials!"
|
||||
```
|
||||
|
||||
Report the result to the user.
|
||||
|
||||
## Health Check Reference
|
||||
|
||||
Health checks validate credentials by making lightweight API calls:
|
||||
|
||||
| Credential | Endpoint | What It Checks |
|
||||
| -------------- | --------------------------------------- | --------------------------------- |
|
||||
| `hubspot` | `GET /crm/v3/objects/contacts?limit=1` | Bearer token validity, CRM scopes |
|
||||
| `brave_search` | `GET /res/v1/web/search?q=test&count=1` | API key validity |
|
||||
|
||||
```python
|
||||
from aden_tools.credentials import check_credential_health, HealthCheckResult
|
||||
|
||||
result: HealthCheckResult = check_credential_health("hubspot", token_value)
|
||||
# result.valid: bool
|
||||
# result.message: str
|
||||
# result.details: dict (status_code, rate_limited, etc.)
|
||||
```
|
||||
|
||||
## Encryption Key (HIVE_CREDENTIAL_KEY)
|
||||
|
||||
The encrypted credential store requires `HIVE_CREDENTIAL_KEY` to encrypt/decrypt credentials.
|
||||
|
||||
- If the user doesn't have one, `EncryptedFileStorage` will auto-generate one and log it
|
||||
- The user MUST persist this key (e.g., in `~/.bashrc` or a secrets manager)
|
||||
- Without this key, stored credentials cannot be decrypted
|
||||
- This is the ONLY secret that should live in `~/.bashrc` or environment config
|
||||
|
||||
If `HIVE_CREDENTIAL_KEY` is not set:
|
||||
|
||||
1. Let the store generate one
|
||||
2. Tell the user to save it: `export HIVE_CREDENTIAL_KEY="{generated_key}"`
|
||||
3. Recommend adding it to `~/.bashrc` or their shell profile
|
||||
|
||||
## Security Rules
|
||||
|
||||
- **NEVER** log, print, or echo credential values in tool output
|
||||
- **NEVER** store credentials in plaintext files, git-tracked files, or agent configs
|
||||
- **NEVER** hardcode credentials in source code
|
||||
- **ALWAYS** use `SecretStr` from Pydantic when handling credential values in Python
|
||||
- **ALWAYS** use the encrypted credential store (`~/.hive/credentials`) for persistence
|
||||
- **ALWAYS** run health checks before storing credentials (when possible)
|
||||
- **ALWAYS** verify credentials were stored by re-running validation, not by reading them back
|
||||
- When modifying `~/.bashrc` or `~/.zshrc`, confirm with the user first
|
||||
|
||||
## Credential Sources Reference
|
||||
|
||||
All credential specs are defined in `tools/src/aden_tools/credentials/`:
|
||||
|
||||
| File | Category | Credentials | Aden Supported |
|
||||
| ----------------- | ------------- | --------------------------------------------- | -------------- |
|
||||
| `llm.py` | LLM Providers | `anthropic` | No |
|
||||
| `search.py` | Search Tools | `brave_search`, `google_search`, `google_cse` | No |
|
||||
| `integrations.py` | Integrations | `hubspot` | Yes |
|
||||
|
||||
**Note:** Additional LLM providers (Cerebras, Groq, OpenAI) are handled by LiteLLM via environment
|
||||
variables (`CEREBRAS_API_KEY`, `GROQ_API_KEY`, `OPENAI_API_KEY`) but are not yet in CREDENTIAL_SPECS.
|
||||
Add them to `llm.py` as needed.
|
||||
|
||||
To check what's registered:
|
||||
|
||||
```python
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
for name, spec in CREDENTIAL_SPECS.items():
|
||||
print(f"{name}: aden={spec.aden_supported}, direct={spec.direct_api_key_supported}")
|
||||
```
|
||||
|
||||
## Migration: CredentialManager → CredentialStore
|
||||
|
||||
**CredentialManager is deprecated.** Use CredentialStore instead.
|
||||
|
||||
| Old (Deprecated) | New (Recommended) |
|
||||
| ----------------------------------------- | -------------------------------------------------------------------- |
|
||||
| `CredentialManager()` | `CredentialStore.with_encrypted_storage()` |
|
||||
| `creds.get("hubspot")` | `store.get("hubspot")` or `store.get_key("hubspot", "access_token")` |
|
||||
| `creds.validate_for_tools(tools)` | Use `store.is_available(cred_id)` per credential |
|
||||
| `creds.get_auth_options("hubspot")` | Check `CREDENTIAL_SPECS["hubspot"].aden_supported` |
|
||||
| `creds.get_setup_instructions("hubspot")` | Access `CREDENTIAL_SPECS["hubspot"]` directly |
|
||||
|
||||
**Why migrate?**
|
||||
|
||||
- **CredentialStore** supports encrypted storage, multi-key credentials, template resolution, and automatic token refresh
|
||||
- **CredentialManager** only reads from environment variables and .env files (no encryption, no refresh)
|
||||
- **CredentialStoreAdapter** exists for backward compatibility during migration
|
||||
|
||||
```python
|
||||
# Old way (deprecated)
|
||||
from aden_tools.credentials import CredentialManager
|
||||
creds = CredentialManager()
|
||||
token = creds.get("hubspot")
|
||||
|
||||
# New way (recommended)
|
||||
from core.framework.credentials import CredentialStore
|
||||
store = CredentialStore.with_encrypted_storage()
|
||||
token = store.get("hubspot")
|
||||
|
||||
# With Aden sync (recommended for OAuth integrations)
|
||||
store = CredentialStore.with_aden_sync()
|
||||
token = store.get_key("hubspot", "access_token")
|
||||
```
|
||||
|
||||
## Example Session
|
||||
|
||||
```
|
||||
User: /setup-credentials for my hubspot-agent
|
||||
|
||||
Agent: Let me check what credentials your hubspot-agent needs.
|
||||
|
||||
[Runs validation, finds ANTHROPIC_API_KEY and HUBSPOT_ACCESS_TOKEN missing]
|
||||
|
||||
Agent: Your hubspot-agent requires 2 credentials:
|
||||
|
||||
1. ANTHROPIC_API_KEY - Anthropic API key for LLM calls
|
||||
2. HUBSPOT_ACCESS_TOKEN - HubSpot access token
|
||||
|
||||
Let's set these up one at a time.
|
||||
|
||||
--- Setting up ANTHROPIC_API_KEY ---
|
||||
|
||||
This credential only supports direct API key entry.
|
||||
|
||||
[Shows instructions for getting Anthropic API key]
|
||||
[AskUserQuestion: "Please provide your Anthropic API key:"]
|
||||
[User provides key]
|
||||
[Stores in credential store, exports to env]
|
||||
|
||||
--- Setting up HUBSPOT_ACCESS_TOKEN ---
|
||||
|
||||
[AskUserQuestion with 3 options]
|
||||
|
||||
Choose how to configure HUBSPOT_ACCESS_TOKEN:
|
||||
1) Aden Authorization Server (Recommended)
|
||||
2) Direct API Key
|
||||
3) Custom Credential Store (Advanced)
|
||||
|
||||
User: 1 (Aden)
|
||||
|
||||
Agent: Great! Let me check if you have an Aden API key configured...
|
||||
|
||||
[Checks for ADEN_API_KEY, not found]
|
||||
[Opens browser to integration.adenhq.com/setup]
|
||||
|
||||
Agent: I've opened Aden's setup page. Please:
|
||||
1. Sign in or create an account
|
||||
2. Connect your HubSpot account (OAuth happens on Aden's site)
|
||||
3. Copy the API key shown after setup
|
||||
|
||||
[AskUserQuestion: "Please provide your Aden API key:"]
|
||||
[User provides key]
|
||||
|
||||
Agent: Would you like me to save this to your shell config (~/.zshrc) for future sessions?
|
||||
|
||||
[User approves]
|
||||
[Saves to ~/.zshrc and ~/.hive/configuration.json]
|
||||
|
||||
Agent: Let me sync your credentials from Aden...
|
||||
|
||||
[Syncs credentials from Aden server - OAuth already done on Aden's side]
|
||||
[Runs health check]
|
||||
|
||||
Agent: HubSpot credentials validated successfully!
|
||||
|
||||
All credentials are now configured:
|
||||
- ANTHROPIC_API_KEY: Stored in encrypted credential store
|
||||
- HUBSPOT_ACCESS_TOKEN: Synced from Aden (OAuth completed on Aden)
|
||||
- Validation passed - your agent is ready to run!
|
||||
```
|
||||
@@ -3,64 +3,53 @@ name: testing-agent
|
||||
description: Run goal-based evaluation tests for agents. Use when you need to verify an agent meets its goals, debug failing tests, or iterate on agent improvements based on test results.
|
||||
---
|
||||
|
||||
# ⛔ MANDATORY: USE MCP TOOLS ONLY
|
||||
# Testing Workflow
|
||||
|
||||
**STOP. Read this before doing anything else.**
|
||||
This skill provides tools for testing agents built with the building-agents skill.
|
||||
|
||||
You MUST use MCP tools for ALL testing operations. Never write test files directly.
|
||||
|
||||
## Required MCP Workflow
|
||||
## Workflow Overview
|
||||
|
||||
1. `mcp__agent-builder__list_tests` - Check what tests exist
|
||||
2. `mcp__agent-builder__generate_constraint_tests` or `mcp__agent-builder__generate_success_tests` - Generate tests
|
||||
3. `mcp__agent-builder__get_pending_tests` - Review pending tests
|
||||
4. `mcp__agent-builder__approve_tests` - Approve tests (this writes the files)
|
||||
5. `mcp__agent-builder__run_tests` - Execute tests
|
||||
6. `mcp__agent-builder__debug_test` - Debug failures
|
||||
2. `mcp__agent-builder__generate_constraint_tests` or `mcp__agent-builder__generate_success_tests` - Get test guidelines
|
||||
3. **Write tests directly** using the Write tool with the guidelines provided
|
||||
4. `mcp__agent-builder__run_tests` - Execute tests
|
||||
5. `mcp__agent-builder__debug_test` - Debug failures
|
||||
|
||||
## ❌ WRONG - Never Do This
|
||||
## How Test Generation Works
|
||||
|
||||
The `generate_*_tests` MCP tools return **guidelines and templates** - they do NOT generate test code via LLM.
|
||||
You (Claude) write the tests directly using the Write tool based on the guidelines.
|
||||
|
||||
### Example Workflow
|
||||
|
||||
```python
|
||||
# WRONG: Writing test file directly with Write tool
|
||||
Write(file_path="exports/agent/tests/test_foo.py", content="def test_...")
|
||||
```
|
||||
|
||||
```python
|
||||
# WRONG: Running pytest directly via Bash
|
||||
Bash(command="pytest exports/agent/tests/ -v")
|
||||
```
|
||||
|
||||
```python
|
||||
# WRONG: Creating test code manually
|
||||
test_code = """
|
||||
def test_something():
|
||||
assert True
|
||||
"""
|
||||
```
|
||||
|
||||
## ✅ CORRECT - Always Do This
|
||||
|
||||
```python
|
||||
# CORRECT: Generate tests via MCP tool
|
||||
mcp__agent-builder__generate_constraint_tests(
|
||||
# Step 1: Get test guidelines
|
||||
result = mcp__agent-builder__generate_constraint_tests(
|
||||
goal_id="my-goal",
|
||||
goal_json='{"id": "...", "constraints": [...]}',
|
||||
agent_path="exports/my_agent"
|
||||
)
|
||||
|
||||
# CORRECT: Approve tests via MCP tool (this writes files)
|
||||
mcp__agent-builder__approve_tests(
|
||||
goal_id="my-goal",
|
||||
approvals='[{"test_id": "test-1", "action": "approve"}]'
|
||||
# Step 2: The result contains:
|
||||
# - output_file: where to write tests
|
||||
# - file_header: imports and fixtures to use
|
||||
# - test_template: format for test functions
|
||||
# - constraints_formatted: the constraints to test
|
||||
# - test_guidelines: rules for writing tests
|
||||
|
||||
# Step 3: Write tests directly using the Write tool
|
||||
Write(
|
||||
file_path=result["output_file"],
|
||||
content=result["file_header"] + test_code_you_write
|
||||
)
|
||||
|
||||
# CORRECT: Run tests via MCP tool
|
||||
# Step 4: Run tests via MCP tool
|
||||
mcp__agent-builder__run_tests(
|
||||
goal_id="my-goal",
|
||||
agent_path="exports/my_agent"
|
||||
)
|
||||
|
||||
# CORRECT: Debug failures via MCP tool
|
||||
# Step 5: Debug failures via MCP tool
|
||||
mcp__agent-builder__debug_test(
|
||||
goal_id="my-goal",
|
||||
test_name="test_constraint_foo",
|
||||
@@ -68,22 +57,15 @@ mcp__agent-builder__debug_test(
|
||||
)
|
||||
```
|
||||
|
||||
## Self-Check Before Every Action
|
||||
|
||||
Before you take any testing action, ask yourself:
|
||||
- Am I about to write `def test_...`? → **STOP, use `generate_*_tests` instead**
|
||||
- Am I about to use `Write` for a test file? → **STOP, use `approve_tests` instead**
|
||||
- Am I about to run `pytest` via Bash? → **STOP, use `run_tests` instead**
|
||||
|
||||
---
|
||||
|
||||
# Testing Agents with MCP Tools
|
||||
|
||||
Run goal-based evaluation tests for agents built with the building-agents skill.
|
||||
|
||||
**Key Principle: Tests are generated via MCP tools and written as Python files**
|
||||
- ✅ Generate tests: `generate_constraint_tests`, `generate_success_tests`
|
||||
- ✅ Review and approve: `get_pending_tests`, `approve_tests` → writes to Python files
|
||||
**Key Principle: MCP tools provide guidelines, Claude writes tests directly**
|
||||
- ✅ Get guidelines: `generate_constraint_tests`, `generate_success_tests` → returns templates and guidelines
|
||||
- ✅ Write tests: Use the Write tool with the provided file_header and test_template
|
||||
- ✅ Run tests: `run_tests` (runs pytest via subprocess)
|
||||
- ✅ Debug failures: `debug_test` (re-runs single test with verbose output)
|
||||
- ✅ List tests: `list_tests` (scans Python test files)
|
||||
@@ -118,39 +100,64 @@ async def test_happy_path(mock_mode):
|
||||
assert len(result.output) > 0
|
||||
```
|
||||
|
||||
## Why MCP Tools Are Required
|
||||
## Why This Approach
|
||||
|
||||
- Tests are generated with proper imports, fixtures, and API key enforcement
|
||||
- Approval workflow ensures user review before file creation
|
||||
- MCP tools provide consistent test guidelines with proper imports, fixtures, and API key enforcement
|
||||
- Claude writes tests directly, eliminating circular LLM dependencies in the MCP server
|
||||
- `run_tests` parses pytest output into structured results for iteration
|
||||
- `debug_test` provides formatted output with actionable debugging info
|
||||
- `conftest.py` is auto-created with proper fixtures
|
||||
- File headers include conftest.py setup with proper fixtures
|
||||
|
||||
## Quick Start
|
||||
|
||||
1. **Check existing tests** - `list_tests(goal_id, agent_path)`
|
||||
2. **Generate test files** - `generate_constraint_tests` or `generate_success_tests`
|
||||
3. **User reviews and approves** - `get_pending_tests` → `approve_tests`
|
||||
2. **Get test guidelines** - `generate_constraint_tests` or `generate_success_tests`
|
||||
3. **Write tests** - Use the Write tool with the provided file_header and guidelines
|
||||
4. **Run tests** - `run_tests(goal_id, agent_path)`
|
||||
5. **Debug failures** - `debug_test(goal_id, test_name, agent_path)`
|
||||
6. **Iterate** - Repeat steps 4-5 until all pass
|
||||
|
||||
## ⚠️ API Key Requirement for Real Testing
|
||||
## ⚠️ Credential Requirements for Testing
|
||||
|
||||
**CRITICAL: Real LLM testing requires an API key.** Mock mode only validates structure and does NOT test actual agent behavior.
|
||||
**CRITICAL: Testing requires ALL credentials the agent depends on.** This includes both the LLM API key AND any tool-specific credentials (HubSpot, Brave Search, etc.).
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Before running agent tests, you MUST set your API key:
|
||||
Before running agent tests, you MUST collect ALL required credentials from the user.
|
||||
|
||||
**Step 1: LLM API Key (always required)**
|
||||
```bash
|
||||
export ANTHROPIC_API_KEY="your-key-here"
|
||||
```
|
||||
|
||||
**Why API keys are required:**
|
||||
**Step 2: Tool-specific credentials (depends on agent's tools)**
|
||||
|
||||
Inspect the agent's `mcp_servers.json` and tool configuration to determine which tools the agent uses, then check for all required credentials:
|
||||
|
||||
```python
|
||||
from aden_tools.credentials import CredentialManager, CREDENTIAL_SPECS
|
||||
|
||||
creds = CredentialManager()
|
||||
|
||||
# Determine which tools the agent uses (from agent.json or mcp_servers.json)
|
||||
agent_tools = [...] # e.g., ["hubspot_search_contacts", "web_search", ...]
|
||||
|
||||
# Find all missing credentials for those tools
|
||||
missing = creds.get_missing_for_tools(agent_tools)
|
||||
```
|
||||
|
||||
Common tool credentials:
|
||||
| Tool | Env Var | Help URL |
|
||||
|------|---------|----------|
|
||||
| HubSpot CRM | `HUBSPOT_ACCESS_TOKEN` | https://developers.hubspot.com/docs/api/private-apps |
|
||||
| Brave Search | `BRAVE_SEARCH_API_KEY` | https://brave.com/search/api/ |
|
||||
| Google Search | `GOOGLE_SEARCH_API_KEY` + `GOOGLE_SEARCH_CX` | https://developers.google.com/custom-search |
|
||||
|
||||
**Why ALL credentials are required:**
|
||||
- Tests need to execute the agent's LLM nodes to validate behavior
|
||||
- Mock mode bypasses LLM calls, providing no confidence in real-world performance
|
||||
- Success criteria (personalization, reasoning quality, constraint adherence) can only be tested with real LLM calls
|
||||
- Tools with missing credentials will return error dicts instead of real data
|
||||
- Mock mode bypasses everything, providing no confidence in real-world performance
|
||||
- The `AgentRunner.run()` method validates credentials at startup and will fail fast if any are missing
|
||||
|
||||
### Mock Mode Limitations
|
||||
|
||||
@@ -164,11 +171,11 @@ Mock mode (`--mock` flag or `mock_mode=True`) is **ONLY for structure validation
|
||||
✗ Does NOT test real API integrations or tool use
|
||||
✗ Does NOT test personalization or content quality
|
||||
|
||||
**Bottom line:** If you're testing whether an agent achieves its goal, you MUST use a real API key.
|
||||
**Bottom line:** If you're testing whether an agent achieves its goal, you MUST use real credentials for ALL services.
|
||||
|
||||
### Enforcing API Key in Tests
|
||||
### Enforcing Credentials in Tests
|
||||
|
||||
When generating tests, **ALWAYS include API key checks**:
|
||||
When generating tests, **ALWAYS include credential checks for ALL required services**:
|
||||
|
||||
```python
|
||||
import os
|
||||
@@ -183,11 +190,14 @@ pytestmark = pytest.mark.skipif(
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def check_api_key():
|
||||
"""Ensure API key is set for real testing."""
|
||||
def check_credentials():
|
||||
"""Ensure ALL required credentials are set for real testing."""
|
||||
creds = CredentialManager()
|
||||
mock_mode = os.environ.get("MOCK_MODE")
|
||||
|
||||
# Always check LLM key
|
||||
if not creds.is_available("anthropic"):
|
||||
if os.environ.get("MOCK_MODE"):
|
||||
if mock_mode:
|
||||
print("\n⚠️ Running in MOCK MODE - structure validation only")
|
||||
print(" This does NOT test LLM behavior or agent quality")
|
||||
print(" Set ANTHROPIC_API_KEY for real testing\n")
|
||||
@@ -201,39 +211,69 @@ def check_api_key():
|
||||
" MOCK_MODE=1 pytest exports/{agent}/tests/\n\n"
|
||||
"Note: Mock mode does NOT validate agent behavior or quality."
|
||||
)
|
||||
|
||||
# Check tool-specific credentials (skip in mock mode)
|
||||
if not mock_mode:
|
||||
# List the tools this agent uses - update per agent
|
||||
agent_tools = [] # e.g., ["hubspot_search_contacts", "hubspot_get_contact"]
|
||||
missing = creds.get_missing_for_tools(agent_tools)
|
||||
if missing:
|
||||
lines = ["\n❌ Missing tool credentials!\n"]
|
||||
for name in missing:
|
||||
spec = creds.specs.get(name)
|
||||
if spec:
|
||||
lines.append(f" {spec.env_var} - {spec.description}")
|
||||
if spec.help_url:
|
||||
lines.append(f" Setup: {spec.help_url}")
|
||||
lines.append("\nSet the required environment variables and re-run.")
|
||||
pytest.fail("\n".join(lines))
|
||||
```
|
||||
|
||||
### User Communication
|
||||
|
||||
When the user asks to test an agent, **ALWAYS check for the API key first**:
|
||||
When the user asks to test an agent, **ALWAYS check for ALL credentials first** — not just the LLM key:
|
||||
|
||||
1. **Identify the agent's tools** from `agent.json` or `mcp_servers.json`
|
||||
2. **Check ALL required credentials** using `CredentialManager`
|
||||
3. **Ask the user to provide any missing credentials** before proceeding
|
||||
|
||||
```python
|
||||
from aden_tools.credentials import CredentialManager
|
||||
from aden_tools.credentials import CredentialManager, CREDENTIAL_SPECS
|
||||
|
||||
# Before running any tests
|
||||
creds = CredentialManager()
|
||||
if not creds.is_available("anthropic"):
|
||||
print("⚠️ No ANTHROPIC_API_KEY found!")
|
||||
print()
|
||||
print("Testing requires a real API key to validate agent behavior.")
|
||||
print()
|
||||
print("Options:")
|
||||
print("1. Set your API key (RECOMMENDED):")
|
||||
print(" export ANTHROPIC_API_KEY='your-key-here'")
|
||||
print()
|
||||
print("2. Run in mock mode (structure validation only):")
|
||||
print(" MOCK_MODE=1 pytest exports/{agent}/tests/")
|
||||
print()
|
||||
print("Mock mode does NOT test:")
|
||||
print(" - LLM message generation")
|
||||
print(" - Reasoning or decision quality")
|
||||
print(" - Constraint validation")
|
||||
print(" - Real API integrations")
|
||||
|
||||
# Ask user what to do
|
||||
# 1. Check LLM key
|
||||
missing_creds = []
|
||||
if not creds.is_available("anthropic"):
|
||||
missing_creds.append(("ANTHROPIC_API_KEY", "Anthropic API key for LLM calls"))
|
||||
|
||||
# 2. Check tool-specific credentials
|
||||
agent_tools = [...] # Determined from agent config
|
||||
missing_tools = creds.get_missing_for_tools(agent_tools)
|
||||
for name in missing_tools:
|
||||
spec = CREDENTIAL_SPECS.get(name)
|
||||
if spec:
|
||||
missing_creds.append((spec.env_var, spec.description))
|
||||
|
||||
# 3. Present ALL missing credentials to the user at once
|
||||
if missing_creds:
|
||||
print("⚠️ Missing credentials required by this agent:\n")
|
||||
for env_var, description in missing_creds:
|
||||
print(f" • {env_var} — {description}")
|
||||
print()
|
||||
print("Please set the missing environment variables:")
|
||||
for env_var, _ in missing_creds:
|
||||
print(f" export {env_var}='your-value-here'")
|
||||
print()
|
||||
print("Or run in mock mode (structure validation only):")
|
||||
print(" MOCK_MODE=1 pytest exports/{agent}/tests/")
|
||||
|
||||
# Ask user to provide credentials or choose mock mode
|
||||
AskUserQuestion(...)
|
||||
```
|
||||
|
||||
**IMPORTANT:** Do NOT skip credential collection. If an agent uses HubSpot tools, the user MUST provide `HUBSPOT_ACCESS_TOKEN`. If it uses web search, the user MUST provide the appropriate search API key. Collect ALL missing credentials in a single prompt rather than discovering them one at a time during test failures.
|
||||
|
||||
## The Three-Stage Flow
|
||||
|
||||
```
|
||||
@@ -284,17 +324,17 @@ This shows what test files already exist. If tests exist:
|
||||
- Review the list to see what's covered
|
||||
- Ask user if they want to add more or run existing tests
|
||||
|
||||
### Step 2: Generate Constraint Tests (Goal Stage)
|
||||
### Step 2: Get Constraint Test Guidelines (Goal Stage)
|
||||
|
||||
After goal is defined, generate constraint tests using the MCP tool:
|
||||
After goal is defined, get test guidelines using the MCP tool:
|
||||
|
||||
```python
|
||||
# First, read the goal from agent.py to get the goal JSON
|
||||
goal_code = Read(file_path="exports/your_agent/agent.py")
|
||||
# Extract the goal definition and convert to JSON
|
||||
|
||||
# Generate constraint tests via MCP tool
|
||||
mcp__agent-builder__generate_constraint_tests(
|
||||
# Get constraint test guidelines via MCP tool
|
||||
result = mcp__agent-builder__generate_constraint_tests(
|
||||
goal_id="your-goal-id",
|
||||
goal_json='{"id": "goal-id", "name": "...", "constraints": [...]}',
|
||||
agent_path="exports/your_agent"
|
||||
@@ -302,37 +342,30 @@ mcp__agent-builder__generate_constraint_tests(
|
||||
```
|
||||
|
||||
**Response includes:**
|
||||
- `generated_count`: Number of tests generated
|
||||
- `tests`: List with id, test_name, description, confidence, test_code_preview
|
||||
- `next_step`: "Call approve_tests to approve, modify, or reject each test"
|
||||
- `output_file`: Where tests will be written when approved
|
||||
- `output_file`: Where to write tests (e.g., `exports/your_agent/tests/test_constraints.py`)
|
||||
- `file_header`: Imports, fixtures, and pytest setup to use at the top of the file
|
||||
- `test_template`: Format for test functions
|
||||
- `constraints_formatted`: The constraints to test
|
||||
- `test_guidelines`: Rules and best practices for writing tests
|
||||
- `instruction`: How to proceed
|
||||
|
||||
**USER APPROVAL REQUIRED**: Review generated tests and approve:
|
||||
**Write tests directly** using the provided guidelines:
|
||||
|
||||
```python
|
||||
# Review pending tests
|
||||
mcp__agent-builder__get_pending_tests(goal_id="your-goal-id")
|
||||
|
||||
# Approve tests (this writes them to files)
|
||||
mcp__agent-builder__approve_tests(
|
||||
goal_id="your-goal-id",
|
||||
approvals='[{"test_id": "test-1", "action": "approve"}, {"test_id": "test-2", "action": "approve"}]'
|
||||
# Write tests using the Write tool
|
||||
Write(
|
||||
file_path=result["output_file"],
|
||||
content=result["file_header"] + "\n\n" + your_test_code
|
||||
)
|
||||
```
|
||||
|
||||
**Approval actions:**
|
||||
- `approve` - Accept test as-is, write to file
|
||||
- `modify` - Accept with changes: `{"test_id": "...", "action": "modify", "modified_code": "..."}`
|
||||
- `reject` - Reject with reason: `{"test_id": "...", "action": "reject", "reason": "..."}`
|
||||
- `skip` - Skip for now
|
||||
### Step 3: Get Success Criteria Test Guidelines (Eval Stage)
|
||||
|
||||
### Step 3: Generate Success Criteria Tests (Eval Stage)
|
||||
|
||||
After agent is fully built, generate success criteria tests:
|
||||
After agent is fully built, get success criteria test guidelines:
|
||||
|
||||
```python
|
||||
# Generate success criteria tests via MCP tool
|
||||
mcp__agent-builder__generate_success_tests(
|
||||
# Get success criteria test guidelines via MCP tool
|
||||
result = mcp__agent-builder__generate_success_tests(
|
||||
goal_id="your-goal-id",
|
||||
goal_json='{"id": "goal-id", "name": "...", "success_criteria": [...]}',
|
||||
node_names="analyze_request,search_web,format_results",
|
||||
@@ -341,26 +374,28 @@ mcp__agent-builder__generate_success_tests(
|
||||
)
|
||||
```
|
||||
|
||||
**USER APPROVAL REQUIRED**: Same approval flow as constraint tests:
|
||||
**Write tests directly** using the provided guidelines:
|
||||
|
||||
```python
|
||||
# Review and approve
|
||||
mcp__agent-builder__get_pending_tests(goal_id="your-goal-id")
|
||||
mcp__agent-builder__approve_tests(
|
||||
goal_id="your-goal-id",
|
||||
approvals='[{"test_id": "...", "action": "approve"}]'
|
||||
# Write tests using the Write tool
|
||||
Write(
|
||||
file_path=result["output_file"],
|
||||
content=result["file_header"] + "\n\n" + your_test_code
|
||||
)
|
||||
```
|
||||
|
||||
### Step 4: Test Fixtures (conftest.py)
|
||||
|
||||
**conftest.py is auto-created** when you approve tests via `approve_tests`. It includes:
|
||||
- API key enforcement fixtures
|
||||
- `mock_mode` fixture
|
||||
- `credentials` fixture
|
||||
- `sample_inputs` fixture
|
||||
The `file_header` returned by the MCP tools includes proper imports and fixtures.
|
||||
You should also create a conftest.py file in the tests directory with shared fixtures:
|
||||
|
||||
You do NOT need to create conftest.py manually - the MCP tool handles this.
|
||||
```python
|
||||
# Create conftest.py with the conftest template
|
||||
Write(
|
||||
file_path="exports/your_agent/tests/conftest.py",
|
||||
content=conftest_content # Use PYTEST_CONFTEST_TEMPLATE format
|
||||
)
|
||||
```
|
||||
|
||||
### Step 5: Run Tests
|
||||
|
||||
@@ -739,6 +774,166 @@ This provides **immediate feedback** during development, catching issues early.
|
||||
|
||||
**Note:** All test patterns should include API key enforcement via conftest.py.
|
||||
|
||||
### ⚠️ CRITICAL: Framework Features You Must Know
|
||||
|
||||
#### OutputCleaner - Automatic I/O Cleaning (NEW!)
|
||||
|
||||
**The framework now automatically validates and cleans node outputs** using a fast LLM (Cerebras llama-3.3-70b) at edge traversal time. This prevents cascading failures from malformed output.
|
||||
|
||||
**What OutputCleaner does**:
|
||||
- ✅ Validates output matches next node's input schema
|
||||
- ✅ Detects JSON parsing trap (entire response in one key)
|
||||
- ✅ Cleans malformed output automatically (~200-500ms, ~$0.001 per cleaning)
|
||||
- ✅ Boosts success rates by 1.8-2.2x
|
||||
|
||||
**Impact on tests**: Tests should still use safe patterns because OutputCleaner may not catch all issues in test mode.
|
||||
|
||||
#### Safe Test Patterns (REQUIRED)
|
||||
|
||||
**❌ UNSAFE** (will cause test failures):
|
||||
```python
|
||||
# Direct key access - can crash!
|
||||
approval_decision = result.output["approval_decision"]
|
||||
assert approval_decision == "APPROVED"
|
||||
|
||||
# Nested access without checks
|
||||
category = result.output["analysis"]["category"]
|
||||
|
||||
# Assuming parsed JSON structure
|
||||
for issue in result.output["compliance_issues"]:
|
||||
...
|
||||
```
|
||||
|
||||
**✅ SAFE** (correct patterns):
|
||||
```python
|
||||
# 1. Safe dict access with .get()
|
||||
output = result.output or {}
|
||||
approval_decision = output.get("approval_decision", "UNKNOWN")
|
||||
assert "APPROVED" in approval_decision or approval_decision == "APPROVED"
|
||||
|
||||
# 2. Type checking before operations
|
||||
analysis = output.get("analysis", {})
|
||||
if isinstance(analysis, dict):
|
||||
category = analysis.get("category", "unknown")
|
||||
|
||||
# 3. Parse JSON from strings (the JSON parsing trap!)
|
||||
import json
|
||||
recommendation = output.get("recommendation", "{}")
|
||||
if isinstance(recommendation, str):
|
||||
try:
|
||||
parsed = json.loads(recommendation)
|
||||
if isinstance(parsed, dict):
|
||||
approval = parsed.get("approval_decision", "UNKNOWN")
|
||||
except json.JSONDecodeError:
|
||||
approval = "UNKNOWN"
|
||||
elif isinstance(recommendation, dict):
|
||||
approval = recommendation.get("approval_decision", "UNKNOWN")
|
||||
|
||||
# 4. Safe iteration with type check
|
||||
compliance_issues = output.get("compliance_issues", [])
|
||||
if isinstance(compliance_issues, list):
|
||||
for issue in compliance_issues:
|
||||
...
|
||||
```
|
||||
|
||||
#### Helper Functions for Safe Access
|
||||
|
||||
**Add to conftest.py**:
|
||||
```python
|
||||
import json
|
||||
import re
|
||||
|
||||
def _parse_json_from_output(result, key):
|
||||
"""Parse JSON from agent output (framework may store full LLM response as string)."""
|
||||
response_text = result.output.get(key, "")
|
||||
# Remove markdown code blocks if present
|
||||
json_text = re.sub(r'```json\s*|\s*```', '', response_text).strip()
|
||||
|
||||
try:
|
||||
return json.loads(json_text)
|
||||
except (json.JSONDecodeError, AttributeError, TypeError):
|
||||
return result.output.get(key)
|
||||
|
||||
def safe_get_nested(result, key_path, default=None):
|
||||
"""Safely get nested value from result.output."""
|
||||
output = result.output or {}
|
||||
current = output
|
||||
|
||||
for key in key_path:
|
||||
if isinstance(current, dict):
|
||||
current = current.get(key)
|
||||
elif isinstance(current, str):
|
||||
try:
|
||||
json_text = re.sub(r'```json\s*|\s*```', '', current).strip()
|
||||
parsed = json.loads(json_text)
|
||||
if isinstance(parsed, dict):
|
||||
current = parsed.get(key)
|
||||
else:
|
||||
return default
|
||||
except json.JSONDecodeError:
|
||||
return default
|
||||
else:
|
||||
return default
|
||||
|
||||
return current if current is not None else default
|
||||
|
||||
# Make available in tests
|
||||
pytest.parse_json_from_output = _parse_json_from_output
|
||||
pytest.safe_get_nested = safe_get_nested
|
||||
```
|
||||
|
||||
**Usage in tests**:
|
||||
```python
|
||||
# Use helper to parse JSON safely
|
||||
parsed = pytest.parse_json_from_output(result, "recommendation")
|
||||
if isinstance(parsed, dict):
|
||||
approval = parsed.get("approval_decision", "UNKNOWN")
|
||||
|
||||
# Safe nested access
|
||||
risk_score = pytest.safe_get_nested(result, ["analysis", "risk_score"], default=0.0)
|
||||
```
|
||||
|
||||
#### Test Count Guidance
|
||||
|
||||
**Generate 8-15 tests total, NOT 30+**
|
||||
|
||||
- ✅ 2-3 tests per success criterion
|
||||
- ✅ 1 happy path test
|
||||
- ✅ 1 boundary/edge case test
|
||||
- ✅ 1 error handling test (optional)
|
||||
|
||||
**Why fewer tests?**:
|
||||
- Each test requires real LLM call (~3 seconds, costs money)
|
||||
- 30 tests = 90 seconds, $0.30+ in costs
|
||||
- 12 tests = 36 seconds, $0.12 in costs
|
||||
- Focus on quality over quantity
|
||||
|
||||
#### ExecutionResult Fields (Important!)
|
||||
|
||||
**`result.success=True` means NO exception, NOT goal achieved**
|
||||
|
||||
```python
|
||||
# ❌ WRONG - assumes goal achieved
|
||||
assert result.success
|
||||
|
||||
# ✅ RIGHT - check success AND output
|
||||
assert result.success, f"Agent failed: {result.error}"
|
||||
output = result.output or {}
|
||||
approval = output.get("approval_decision")
|
||||
assert approval == "APPROVED", f"Expected APPROVED, got {approval}"
|
||||
```
|
||||
|
||||
**All ExecutionResult fields**:
|
||||
- `success: bool` - Execution completed without exception (NOT goal achieved!)
|
||||
- `output: dict` - Complete memory snapshot (may contain raw strings)
|
||||
- `error: str | None` - Error message if failed
|
||||
- `steps_executed: int` - Number of nodes executed
|
||||
- `total_tokens: int` - Cumulative token usage
|
||||
- `total_latency_ms: int` - Total execution time
|
||||
- `path: list[str]` - Node IDs traversed
|
||||
- `paused_at: str | None` - Node ID if HITL pause occurred
|
||||
- `session_state: dict` - State for resuming
|
||||
|
||||
### Happy Path Test
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
@@ -803,25 +998,24 @@ async def test_performance_latency(mock_mode):
|
||||
|
||||
## Anti-Patterns
|
||||
|
||||
### MCP Tool Enforcement
|
||||
### Testing Best Practices
|
||||
|
||||
| Don't | Do Instead |
|
||||
|-------|------------|
|
||||
| ❌ Write test files with Write tool | ✅ Use `generate_*_tests` + `approve_tests` |
|
||||
| ❌ Run pytest via Bash | ✅ Use `run_tests` MCP tool |
|
||||
| ❌ Debug tests with Bash pytest -vvs | ✅ Use `debug_test` MCP tool |
|
||||
| ❌ Edit test files directly | ✅ Use `approve_tests` with `action: "modify"` |
|
||||
| ❌ Write tests without getting guidelines first | ✅ Use `generate_*_tests` to get proper file_header and guidelines |
|
||||
| ❌ Run pytest via Bash | ✅ Use `run_tests` MCP tool for structured results |
|
||||
| ❌ Debug tests with Bash pytest -vvs | ✅ Use `debug_test` MCP tool for formatted output |
|
||||
| ❌ Check for tests with Glob | ✅ Use `list_tests` MCP tool |
|
||||
| ❌ Skip the file_header from guidelines | ✅ Always include the file_header for proper imports and fixtures |
|
||||
|
||||
### General Testing
|
||||
|
||||
| Don't | Do Instead |
|
||||
|-------|------------|
|
||||
| ❌ Auto-approve generated tests | ✅ Always require user approval via approve_tests |
|
||||
| ❌ Treat all failures the same | ✅ Use debug_test to categorize and iterate appropriately |
|
||||
| ❌ Rebuild entire agent for small bugs | ✅ Edit code directly, re-run tests |
|
||||
| ❌ Run tests without API key | ✅ Always set ANTHROPIC_API_KEY first |
|
||||
| ❌ Skip user review of generated tests | ✅ Show test code to user before approving |
|
||||
| ❌ Write tests without understanding the constraints/criteria | ✅ Read the formatted constraints/criteria from guidelines |
|
||||
|
||||
## Workflow Summary
|
||||
|
||||
@@ -829,11 +1023,11 @@ async def test_performance_latency(mock_mode):
|
||||
1. Check existing tests: list_tests(goal_id, agent_path)
|
||||
→ Scans exports/{agent}/tests/test_*.py
|
||||
↓
|
||||
2. Generate tests: generate_constraint_tests, generate_success_tests
|
||||
→ Returns pending tests (stored in memory)
|
||||
2. Get test guidelines: generate_constraint_tests, generate_success_tests
|
||||
→ Returns file_header, test_template, constraints/criteria, guidelines
|
||||
↓
|
||||
3. Review and approve: get_pending_tests → approve_tests → USER APPROVAL
|
||||
→ Writes approved tests to exports/{agent}/tests/test_*.py
|
||||
3. Write tests: Use Write tool with the provided guidelines
|
||||
→ Write tests to exports/{agent}/tests/test_*.py
|
||||
↓
|
||||
4. Run tests: run_tests(goal_id, agent_path)
|
||||
→ Executes: pytest exports/{agent}/tests/ -v
|
||||
@@ -861,14 +1055,15 @@ mcp__agent-builder__list_tests(
|
||||
agent_path="exports/your_agent"
|
||||
)
|
||||
|
||||
# Generate constraint tests (returns pending tests for approval)
|
||||
# Get constraint test guidelines (returns templates and guidelines, NOT generated tests)
|
||||
mcp__agent-builder__generate_constraint_tests(
|
||||
goal_id="your-goal-id",
|
||||
goal_json='{"id": "...", "constraints": [...]}',
|
||||
agent_path="exports/your_agent"
|
||||
)
|
||||
# Returns: output_file, file_header, test_template, constraints_formatted, test_guidelines
|
||||
|
||||
# Generate success criteria tests
|
||||
# Get success criteria test guidelines
|
||||
mcp__agent-builder__generate_success_tests(
|
||||
goal_id="your-goal-id",
|
||||
goal_json='{"id": "...", "success_criteria": [...]}',
|
||||
@@ -876,15 +1071,7 @@ mcp__agent-builder__generate_success_tests(
|
||||
tool_names="tool1,tool2",
|
||||
agent_path="exports/your_agent"
|
||||
)
|
||||
|
||||
# Review pending tests
|
||||
mcp__agent-builder__get_pending_tests(goal_id="your-goal-id")
|
||||
|
||||
# Approve tests → writes to Python files at exports/{agent}/tests/
|
||||
mcp__agent-builder__approve_tests(
|
||||
goal_id="your-goal-id",
|
||||
approvals='[{"test_id": "...", "action": "approve"}]'
|
||||
)
|
||||
# Returns: output_file, file_header, test_template, success_criteria_formatted, test_guidelines
|
||||
|
||||
# Run tests via pytest subprocess
|
||||
mcp__agent-builder__run_tests(
|
||||
|
||||
@@ -49,154 +49,155 @@ First, load the goal that was defined during the Goal stage:
|
||||
}
|
||||
```
|
||||
|
||||
## Step 2: Generate Constraint Tests
|
||||
## Step 2: Get Constraint Test Guidelines
|
||||
|
||||
During the Goal stage (or early Eval), generate tests for constraints:
|
||||
During the Goal stage (or early Eval), get test guidelines for constraints:
|
||||
|
||||
```python
|
||||
result = generate_constraint_tests(
|
||||
goal_id="youtube-research",
|
||||
goal_json='<goal JSON above>'
|
||||
goal_json='<goal JSON above>',
|
||||
agent_path="exports/youtube-research"
|
||||
)
|
||||
```
|
||||
|
||||
**Generated tests (awaiting approval):**
|
||||
**The result contains guidelines (not generated tests):**
|
||||
- `output_file`: Where to write tests
|
||||
- `file_header`: Imports and fixtures to use
|
||||
- `test_template`: Format for test functions
|
||||
- `constraints_formatted`: The constraints to test
|
||||
- `test_guidelines`: Rules for writing tests
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Generated Constraint Tests (2 tests) │
|
||||
├─────────────────────────────────────────────────────────────────┤
|
||||
│ [1/2] test_constraint_api_limits_respected │
|
||||
│ Constraint: api_limits │
|
||||
│ Confidence: 88% │
|
||||
│ │
|
||||
│ def test_constraint_api_limits_respected(agent): │
|
||||
│ """Verify API rate limits are not exceeded.""" │
|
||||
│ import time │
|
||||
│ for i in range(10): │
|
||||
│ result = agent.run({"topic": f"test_{i}"}) │
|
||||
│ time.sleep(0.1) │
|
||||
│ # Should complete without rate limit errors │
|
||||
│ assert "rate limit" not in str(result).lower() │
|
||||
│ │
|
||||
│ [a]pprove [r]eject [e]dit [s]kip │
|
||||
├─────────────────────────────────────────────────────────────────┤
|
||||
│ [2/2] test_constraint_content_safety_filter │
|
||||
│ Constraint: content_safety │
|
||||
│ Confidence: 91% │
|
||||
│ │
|
||||
│ def test_constraint_content_safety_filter(agent): │
|
||||
│ """Verify inappropriate content is filtered.""" │
|
||||
│ result = agent.run({"topic": "general topic"}) │
|
||||
│ for video in result.videos: │
|
||||
│ assert video.safe_for_work is True │
|
||||
│ assert video.age_restricted is False │
|
||||
│ │
|
||||
│ [a]pprove [r]eject [e]dit [s]kip │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
## Step 3: Write Constraint Tests
|
||||
|
||||
## Step 3: Approve Constraint Tests
|
||||
|
||||
Review and approve each test:
|
||||
Using the guidelines, write tests directly with the Write tool:
|
||||
|
||||
```python
|
||||
result = approve_tests(
|
||||
goal_id="youtube-research",
|
||||
approvals='[
|
||||
{"test_id": "test_constraint_api_001", "action": "approve"},
|
||||
{"test_id": "test_constraint_content_001", "action": "approve"}
|
||||
]'
|
||||
# Write constraint tests using the provided file_header and guidelines
|
||||
Write(
|
||||
file_path="exports/youtube-research/tests/test_constraints.py",
|
||||
content='''
|
||||
"""Constraint tests for youtube-research agent."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from exports.youtube_research import default_agent
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not os.environ.get("ANTHROPIC_API_KEY") and not os.environ.get("MOCK_MODE"),
|
||||
reason="API key required for real testing."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_constraint_api_limits_respected():
|
||||
"""Verify API rate limits are not exceeded."""
|
||||
import time
|
||||
mock_mode = bool(os.environ.get("MOCK_MODE"))
|
||||
|
||||
for i in range(10):
|
||||
result = await default_agent.run({"topic": f"test_{i}"}, mock_mode=mock_mode)
|
||||
time.sleep(0.1)
|
||||
|
||||
# Should complete without rate limit errors
|
||||
assert "rate limit" not in str(result).lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_constraint_content_safety_filter():
|
||||
"""Verify inappropriate content is filtered."""
|
||||
mock_mode = bool(os.environ.get("MOCK_MODE"))
|
||||
result = await default_agent.run({"topic": "general topic"}, mock_mode=mock_mode)
|
||||
|
||||
for video in result.videos:
|
||||
assert video.safe_for_work is True
|
||||
assert video.age_restricted is False
|
||||
'''
|
||||
)
|
||||
```
|
||||
|
||||
## Step 4: Generate Success Criteria Tests
|
||||
## Step 4: Get Success Criteria Test Guidelines
|
||||
|
||||
After the agent is built, generate success criteria tests:
|
||||
After the agent is built, get success criteria test guidelines:
|
||||
|
||||
```python
|
||||
result = generate_success_tests(
|
||||
goal_id="youtube-research",
|
||||
goal_json='<goal JSON>',
|
||||
node_names="search_node,filter_node,rank_node,format_node",
|
||||
tool_names="youtube_search,video_details,channel_info"
|
||||
tool_names="youtube_search,video_details,channel_info",
|
||||
agent_path="exports/youtube-research"
|
||||
)
|
||||
```
|
||||
|
||||
**Generated tests (awaiting approval):**
|
||||
## Step 5: Write Success Criteria Tests
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Generated Success Criteria Tests (4 tests) │
|
||||
├─────────────────────────────────────────────────────────────────┤
|
||||
│ [1/4] test_find_videos_happy_path │
|
||||
│ Criteria: find_videos │
|
||||
│ Confidence: 95% │
|
||||
│ │
|
||||
│ def test_find_videos_happy_path(agent): │
|
||||
│ """Test finding videos for a common topic.""" │
|
||||
│ result = agent.run({"topic": "machine learning"}) │
|
||||
│ assert result.success │
|
||||
│ assert 3 <= len(result.videos) <= 5 │
|
||||
│ assert all(v.title for v in result.videos) │
|
||||
│ assert all(v.video_id for v in result.videos) │
|
||||
│ │
|
||||
│ [a]pprove [r]eject [e]dit [s]kip │
|
||||
├─────────────────────────────────────────────────────────────────┤
|
||||
│ [2/4] test_find_videos_minimum_boundary │
|
||||
│ Criteria: find_videos │
|
||||
│ Confidence: 87% │
|
||||
│ │
|
||||
│ def test_find_videos_minimum_boundary(agent): │
|
||||
│ """Test at minimum threshold (3 videos).""" │
|
||||
│ result = agent.run({"topic": "niche topic xyz"}) │
|
||||
│ assert len(result.videos) >= 3 │
|
||||
│ │
|
||||
│ [a]pprove [r]eject [e]dit [s]kip │
|
||||
├─────────────────────────────────────────────────────────────────┤
|
||||
│ [3/4] test_relevance_score_threshold │
|
||||
│ Criteria: relevance │
|
||||
│ Confidence: 92% │
|
||||
│ │
|
||||
│ def test_relevance_score_threshold(agent): │
|
||||
│ """Test relevance scoring meets threshold.""" │
|
||||
│ result = agent.run({"topic": "python programming"}) │
|
||||
│ for video in result.videos: │
|
||||
│ assert video.relevance_score > 0.8 │
|
||||
│ │
|
||||
│ [a]pprove [r]eject [e]dit [s]kip │
|
||||
├─────────────────────────────────────────────────────────────────┤
|
||||
│ [4/4] test_find_videos_no_results_graceful │
|
||||
│ Criteria: find_videos │
|
||||
│ Confidence: 84% │
|
||||
│ │
|
||||
│ def test_find_videos_no_results_graceful(agent): │
|
||||
│ """Test graceful handling of no results.""" │
|
||||
│ result = agent.run({"topic": "xyznonexistent123"}) │
|
||||
│ # Should not crash, return empty or message │
|
||||
│ assert result.videos == [] or result.message │
|
||||
│ │
|
||||
│ [a]pprove [r]eject [e]dit [s]kip │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Step 5: Approve Success Criteria Tests
|
||||
Using the guidelines, write success criteria tests:
|
||||
|
||||
```python
|
||||
result = approve_tests(
|
||||
goal_id="youtube-research",
|
||||
approvals='[
|
||||
{"test_id": "test_success_001", "action": "approve"},
|
||||
{"test_id": "test_success_002", "action": "approve"},
|
||||
{"test_id": "test_success_003", "action": "approve"},
|
||||
{"test_id": "test_success_004", "action": "approve"}
|
||||
]'
|
||||
Write(
|
||||
file_path="exports/youtube-research/tests/test_success_criteria.py",
|
||||
content='''
|
||||
"""Success criteria tests for youtube-research agent."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from exports.youtube_research import default_agent
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not os.environ.get("ANTHROPIC_API_KEY") and not os.environ.get("MOCK_MODE"),
|
||||
reason="API key required for real testing."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_videos_happy_path():
|
||||
"""Test finding videos for a common topic."""
|
||||
mock_mode = bool(os.environ.get("MOCK_MODE"))
|
||||
result = await default_agent.run({"topic": "machine learning"}, mock_mode=mock_mode)
|
||||
|
||||
assert result.success
|
||||
assert 3 <= len(result.videos) <= 5
|
||||
assert all(v.title for v in result.videos)
|
||||
assert all(v.video_id for v in result.videos)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_videos_minimum_boundary():
|
||||
"""Test at minimum threshold (3 videos)."""
|
||||
mock_mode = bool(os.environ.get("MOCK_MODE"))
|
||||
result = await default_agent.run({"topic": "niche topic xyz"}, mock_mode=mock_mode)
|
||||
|
||||
assert len(result.videos) >= 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_relevance_score_threshold():
|
||||
"""Test relevance scoring meets threshold."""
|
||||
mock_mode = bool(os.environ.get("MOCK_MODE"))
|
||||
result = await default_agent.run({"topic": "python programming"}, mock_mode=mock_mode)
|
||||
|
||||
for video in result.videos:
|
||||
assert video.relevance_score > 0.8
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_videos_no_results_graceful():
|
||||
"""Test graceful handling of no results."""
|
||||
mock_mode = bool(os.environ.get("MOCK_MODE"))
|
||||
result = await default_agent.run({"topic": "xyznonexistent123"}, mock_mode=mock_mode)
|
||||
|
||||
# Should not crash, return empty or message
|
||||
assert result.videos == [] or result.message
|
||||
'''
|
||||
)
|
||||
```
|
||||
|
||||
## Step 6: Run All Tests
|
||||
|
||||
Execute all approved tests:
|
||||
Execute all tests:
|
||||
|
||||
```python
|
||||
result = run_tests(
|
||||
@@ -238,7 +239,8 @@ result = run_tests(
|
||||
```python
|
||||
result = debug_test(
|
||||
goal_id="youtube-research",
|
||||
test_id="test_success_004"
|
||||
test_name="test_find_videos_no_results_graceful",
|
||||
agent_path="exports/youtube-research"
|
||||
)
|
||||
```
|
||||
|
||||
@@ -335,14 +337,15 @@ result = run_tests(
|
||||
|
||||
## Summary
|
||||
|
||||
1. **Generated** constraint tests during Goal stage
|
||||
2. **Generated** success criteria tests during Eval stage
|
||||
3. **Approved** all tests with user review
|
||||
4. **Ran** tests in parallel
|
||||
5. **Debugged** the one failure
|
||||
6. **Categorized** as IMPLEMENTATION_ERROR
|
||||
7. **Fixed** the agent (not the goal)
|
||||
8. **Re-ran** Eval only (didn't restart full flow)
|
||||
9. **Passed** all tests
|
||||
1. **Got guidelines** for constraint tests during Goal stage
|
||||
2. **Wrote** constraint tests using Write tool
|
||||
3. **Got guidelines** for success criteria tests during Eval stage
|
||||
4. **Wrote** success criteria tests using Write tool
|
||||
5. **Ran** tests in parallel
|
||||
6. **Debugged** the one failure
|
||||
7. **Categorized** as IMPLEMENTATION_ERROR
|
||||
8. **Fixed** the agent (not the goal)
|
||||
9. **Re-ran** Eval only (didn't restart full flow)
|
||||
10. **Passed** all tests
|
||||
|
||||
The agent is now validated and ready for production use.
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
# Triage Issue Skill
|
||||
|
||||
Analyze a GitHub issue, verify claims against the codebase, and close invalid issues with a technical response.
|
||||
|
||||
## Trigger
|
||||
|
||||
User provides a GitHub issue URL or number, e.g.:
|
||||
- `/triage-issue 1970`
|
||||
- `/triage-issue https://github.com/adenhq/hive/issues/1970`
|
||||
|
||||
## Workflow
|
||||
|
||||
### Step 1: Fetch Issue Details
|
||||
|
||||
```bash
|
||||
gh issue view <number> --repo adenhq/hive --json title,body,state,labels,author
|
||||
```
|
||||
|
||||
Extract:
|
||||
- Title
|
||||
- Body (the claim/bug report)
|
||||
- Current state
|
||||
- Labels
|
||||
- Author
|
||||
|
||||
If issue is already closed, inform user and stop.
|
||||
|
||||
### Step 2: Analyze the Claim
|
||||
|
||||
Read the issue body and identify:
|
||||
1. **The core claim** - What is the user asserting?
|
||||
2. **Technical specifics** - File paths, function names, code snippets mentioned
|
||||
3. **Expected behavior** - What do they think should happen?
|
||||
4. **Severity claimed** - Security issue? Bug? Feature request?
|
||||
|
||||
### Step 3: Investigate the Codebase
|
||||
|
||||
For each technical claim:
|
||||
1. Find the referenced code using Grep/Glob/Read
|
||||
2. Understand the actual implementation
|
||||
3. Check if the claim accurately describes the behavior
|
||||
4. Look for related tests, documentation, or design decisions
|
||||
|
||||
### Step 4: Evaluate Validity
|
||||
|
||||
Categorize the issue as one of:
|
||||
|
||||
| Category | Action |
|
||||
|----------|--------|
|
||||
| **Valid Bug** | Do NOT close. Inform user this is a real issue. |
|
||||
| **Valid Feature Request** | Do NOT close. Suggest labeling appropriately. |
|
||||
| **Misunderstanding** | Prepare technical explanation for why behavior is correct. |
|
||||
| **Fundamentally Flawed** | Prepare critique explaining the technical impossibility or design rationale. |
|
||||
| **Duplicate** | Find the original issue and prepare duplicate notice. |
|
||||
| **Incomplete** | Prepare request for more information. |
|
||||
|
||||
### Step 5: Draft Response
|
||||
|
||||
For issues to be closed, draft a response that:
|
||||
|
||||
1. **Acknowledges the concern** - Don't be dismissive
|
||||
2. **Explains the actual behavior** - With code references
|
||||
3. **Provides technical rationale** - Why it works this way
|
||||
4. **References industry standards** - If applicable
|
||||
5. **Offers alternatives** - If there's a better approach for the user
|
||||
|
||||
Use this template:
|
||||
|
||||
```markdown
|
||||
## Analysis
|
||||
|
||||
[Brief summary of what was investigated]
|
||||
|
||||
## Technical Details
|
||||
|
||||
[Explanation with code references]
|
||||
|
||||
## Why This Is Working As Designed
|
||||
|
||||
[Rationale]
|
||||
|
||||
## Recommendation
|
||||
|
||||
[What the user should do instead, if applicable]
|
||||
|
||||
---
|
||||
*This issue was reviewed and closed by the maintainers.*
|
||||
```
|
||||
|
||||
### Step 6: User Review
|
||||
|
||||
Present the draft to the user with:
|
||||
|
||||
```
|
||||
## Issue #<number>: <title>
|
||||
|
||||
**Claim:** <summary of claim>
|
||||
|
||||
**Finding:** <valid/invalid/misunderstanding/etc>
|
||||
|
||||
**Draft Response:**
|
||||
<the markdown response>
|
||||
|
||||
---
|
||||
Do you want me to post this comment and close the issue?
|
||||
```
|
||||
|
||||
Use AskUserQuestion with options:
|
||||
- "Post and close" - Post comment, close issue
|
||||
- "Edit response" - Let user modify the response
|
||||
- "Skip" - Don't take action
|
||||
|
||||
### Step 7: Execute Action
|
||||
|
||||
If user approves:
|
||||
|
||||
```bash
|
||||
# Post comment
|
||||
gh issue comment <number> --repo adenhq/hive --body "<response>"
|
||||
|
||||
# Close issue
|
||||
gh issue close <number> --repo adenhq/hive --reason "not planned"
|
||||
```
|
||||
|
||||
Report success with link to the issue.
|
||||
|
||||
## Important Guidelines
|
||||
|
||||
1. **Never close valid issues** - If there's any merit to the claim, don't close it
|
||||
2. **Be respectful** - The reporter took time to file the issue
|
||||
3. **Be technical** - Provide code references and evidence
|
||||
4. **Be educational** - Help them understand, don't just dismiss
|
||||
5. **Check twice** - Make sure you understand the code before declaring something invalid
|
||||
6. **Consider edge cases** - Maybe their environment reveals a real issue
|
||||
|
||||
## Example Critiques
|
||||
|
||||
### Security Misunderstanding
|
||||
> "The claim that secrets are exposed in plaintext misunderstands the encryption architecture. While `SecretStr` is used for logging protection, actual encryption is provided by Fernet (AES-128-CBC) at the storage layer. The code path is: serialize → encrypt → write. Only encrypted bytes touch disk."
|
||||
|
||||
### Impossible Request
|
||||
> "The requested feature would require [X] which violates [fundamental constraint]. This is not a limitation of our implementation but a fundamental property of [technology/protocol]."
|
||||
|
||||
### Already Handled
|
||||
> "This scenario is already handled by [code reference]. The reporter may be using an older version or misconfigured environment."
|
||||
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"agent-builder": {
|
||||
"command": "python",
|
||||
"args": ["-m", "framework.mcp.agent_builder_server"],
|
||||
"cwd": "core",
|
||||
"env": {
|
||||
"PYTHONPATH": "../tools/src"
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"command": "python",
|
||||
"args": ["mcp_server.py", "--stdio"],
|
||||
"cwd": "tools",
|
||||
"env": {
|
||||
"PYTHONPATH": "src"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/agent-workflow
|
||||
@@ -0,0 +1 @@
|
||||
../../.claude/skills/building-agents-construction
|
||||
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/building-agents-core
|
||||
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/building-agents-patterns
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/testing-agent
|
||||
@@ -0,0 +1,18 @@
|
||||
This project uses ruff for Python linting and formatting.
|
||||
|
||||
Rules:
|
||||
- Line length: 100 characters
|
||||
- Python target: 3.11+
|
||||
- Use double quotes for strings
|
||||
- Sort imports with isort (ruff I rules): stdlib, third-party, first-party (framework), local
|
||||
- Combine as-imports
|
||||
- Use type hints on all function signatures
|
||||
- Use `from __future__ import annotations` for modern type syntax
|
||||
- Raise exceptions with `from` in except blocks (B904)
|
||||
- No unused imports (F401), no unused variables (F841)
|
||||
- Prefer list/dict/set comprehensions over map/filter (C4)
|
||||
|
||||
Run `make lint` to auto-fix, `make check` to verify without modifying files.
|
||||
Run `make format` to apply ruff formatting.
|
||||
|
||||
The ruff config lives in core/pyproject.toml under [tool.ruff].
|
||||
@@ -11,6 +11,9 @@ indent_size = 2
|
||||
insert_final_newline = true
|
||||
trim_trailing_whitespace = true
|
||||
|
||||
[*.py]
|
||||
indent_size = 4
|
||||
|
||||
[*.md]
|
||||
trim_trailing_whitespace = false
|
||||
|
||||
|
||||
+124
@@ -0,0 +1,124 @@
|
||||
# Normalize line endings for all text files
|
||||
* text=auto
|
||||
|
||||
# Source code
|
||||
*.py text diff=python
|
||||
*.js text
|
||||
*.ts text
|
||||
*.jsx text
|
||||
*.tsx text
|
||||
*.json text
|
||||
*.yaml text
|
||||
*.yml text
|
||||
*.toml text
|
||||
*.ini text
|
||||
*.cfg text
|
||||
|
||||
# Shell scripts (must use LF)
|
||||
*.sh text eol=lf
|
||||
quickstart.sh text eol=lf
|
||||
|
||||
# PowerShell scripts (Windows-friendly)
|
||||
*.ps1 text eol=lf
|
||||
*.psm1 text eol=lf
|
||||
|
||||
# Windows batch files (must use CRLF)
|
||||
*.bat text eol=crlf
|
||||
*.cmd text eol=crlf
|
||||
|
||||
# Documentation
|
||||
*.md text
|
||||
*.txt text
|
||||
*.rst text
|
||||
*.tex text
|
||||
|
||||
# Configuration files
|
||||
.gitignore text
|
||||
.gitattributes text
|
||||
.editorconfig text
|
||||
Dockerfile text
|
||||
docker-compose.yml text
|
||||
requirements*.txt text
|
||||
pyproject.toml text
|
||||
setup.py text
|
||||
setup.cfg text
|
||||
MANIFEST.in text
|
||||
LICENSE text
|
||||
README* text
|
||||
CHANGELOG* text
|
||||
CONTRIBUTING* text
|
||||
CODE_OF_CONDUCT* text
|
||||
|
||||
# Web files
|
||||
*.html text
|
||||
*.css text
|
||||
*.scss text
|
||||
*.sass text
|
||||
|
||||
# Data files
|
||||
*.xml text
|
||||
*.csv text
|
||||
*.sql text
|
||||
|
||||
# Graphics (binary)
|
||||
*.png binary
|
||||
*.jpg binary
|
||||
*.jpeg binary
|
||||
*.gif binary
|
||||
*.ico binary
|
||||
*.svg binary
|
||||
*.eps binary
|
||||
*.bmp binary
|
||||
*.tif binary
|
||||
*.tiff binary
|
||||
|
||||
# Archives (binary)
|
||||
*.zip binary
|
||||
*.tar binary
|
||||
*.gz binary
|
||||
*.bz2 binary
|
||||
*.7z binary
|
||||
*.rar binary
|
||||
|
||||
# Python compiled (binary)
|
||||
*.pyc binary
|
||||
*.pyo binary
|
||||
*.pyd binary
|
||||
*.whl binary
|
||||
*.egg binary
|
||||
|
||||
# System libraries (binary)
|
||||
*.so binary
|
||||
*.dll binary
|
||||
*.dylib binary
|
||||
*.lib binary
|
||||
*.a binary
|
||||
|
||||
# Documents (binary)
|
||||
*.pdf binary
|
||||
*.doc binary
|
||||
*.docx binary
|
||||
*.ppt binary
|
||||
*.pptx binary
|
||||
*.xls binary
|
||||
*.xlsx binary
|
||||
|
||||
# Fonts (binary)
|
||||
*.ttf binary
|
||||
*.otf binary
|
||||
*.woff binary
|
||||
*.woff2 binary
|
||||
*.eot binary
|
||||
|
||||
# Audio/Video (binary)
|
||||
*.mp3 binary
|
||||
*.mp4 binary
|
||||
*.wav binary
|
||||
*.avi binary
|
||||
*.mov binary
|
||||
*.flv binary
|
||||
|
||||
# Database files (binary)
|
||||
*.db binary
|
||||
*.sqlite binary
|
||||
*.sqlite3 binary
|
||||
@@ -8,7 +8,6 @@
|
||||
/hive/ @adenhq/maintainers
|
||||
|
||||
# Infrastructure
|
||||
/docker-compose*.yml @adenhq/maintainers
|
||||
/.github/ @adenhq/maintainers
|
||||
|
||||
# Documentation
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
name: Auto-close duplicate issues
|
||||
description: Auto-closes issues that are duplicates of existing issues
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 */6 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
auto-close-duplicates:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
|
||||
- name: Run auto-close-duplicates tests
|
||||
run: bun test scripts/auto-close-duplicates
|
||||
|
||||
- name: Auto-close duplicate issues
|
||||
run: bun run scripts/auto-close-duplicates.ts
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
|
||||
GITHUB_REPOSITORY_NAME: ${{ github.event.repository.name }}
|
||||
STATSIG_API_KEY: ${{ secrets.STATSIG_API_KEY }}
|
||||
@@ -29,14 +29,22 @@ jobs:
|
||||
pip install -e .
|
||||
pip install -r requirements-dev.txt
|
||||
|
||||
- name: Run ruff
|
||||
- name: Ruff lint
|
||||
run: |
|
||||
cd core
|
||||
ruff check .
|
||||
ruff check core/
|
||||
ruff check tools/
|
||||
|
||||
- name: Ruff format
|
||||
run: |
|
||||
ruff format --check core/
|
||||
ruff format --check tools/
|
||||
|
||||
test:
|
||||
name: Test Python Framework
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
@@ -57,10 +65,31 @@ jobs:
|
||||
cd core
|
||||
pytest tests/ -v
|
||||
|
||||
test-tools:
|
||||
name: Test Tools
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies and run tests
|
||||
run: |
|
||||
cd tools
|
||||
uv sync --extra dev
|
||||
uv pip install --python .venv/bin/python -e ../core
|
||||
uv run --extra dev pytest tests/ -v
|
||||
|
||||
validate:
|
||||
name: Validate Agent Exports
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint, test]
|
||||
needs: [lint, test, test-tools]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
@@ -79,9 +108,31 @@ jobs:
|
||||
- name: Validate exported agents
|
||||
run: |
|
||||
# Check that agent exports have valid structure
|
||||
for agent_dir in exports/*/; do
|
||||
if [ ! -d "exports" ]; then
|
||||
echo "No exports/ directory found, skipping validation"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
shopt -s nullglob
|
||||
agent_dirs=(exports/*/)
|
||||
shopt -u nullglob
|
||||
|
||||
if [ ${#agent_dirs[@]} -eq 0 ]; then
|
||||
echo "No agent directories in exports/, skipping validation"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
validated=0
|
||||
for agent_dir in "${agent_dirs[@]}"; do
|
||||
if [ -f "$agent_dir/agent.json" ]; then
|
||||
echo "Validating $agent_dir"
|
||||
python -c "import json; json.load(open('$agent_dir/agent.json'))"
|
||||
validated=$((validated + 1))
|
||||
fi
|
||||
done
|
||||
|
||||
if [ "$validated" -eq 0 ]; then
|
||||
echo "No agent.json files found in exports/, skipping validation"
|
||||
else
|
||||
echo "Validated $validated agent(s)"
|
||||
fi
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
name: Issue Triage
|
||||
|
||||
on:
|
||||
issues:
|
||||
types: [opened]
|
||||
|
||||
jobs:
|
||||
triage:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
id-token: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
- name: Triage and check for duplicates
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
allowed_non_write_users: "*"
|
||||
prompt: |
|
||||
Analyze this new issue and perform triage tasks.
|
||||
|
||||
Issue: #${{ github.event.issue.number }}
|
||||
Repository: ${{ github.repository }}
|
||||
|
||||
## Your Tasks:
|
||||
|
||||
### 1. Get issue details
|
||||
Use mcp__github__get_issue to get the full details of issue #${{ github.event.issue.number }}
|
||||
|
||||
### 2. Check for duplicates
|
||||
Search for similar existing issues using mcp__github__search_issues with relevant keywords from the issue title and body.
|
||||
|
||||
Criteria for duplicates:
|
||||
- Same bug or error being reported
|
||||
- Same feature request (even if worded differently)
|
||||
- Same question being asked
|
||||
- Issues describing the same root problem
|
||||
|
||||
If you find a duplicate:
|
||||
- Add a comment using EXACTLY this format (required for auto-close to work):
|
||||
"Found a possible duplicate of #<issue_number>: <brief explanation of why it's a duplicate>"
|
||||
- Do NOT apply the "duplicate" label yet (the auto-close script will add it after 12 hours if no objections)
|
||||
- Suggest the user react with a thumbs-down if they disagree
|
||||
|
||||
### 3. Check for Low-Quality / AI Spam
|
||||
Analyze the issue quality. We are receiving many low-effort, AI-generated spam issues.
|
||||
Flag the issue as INVALID if it matches these criteria:
|
||||
- **Vague/Generic**: Title is "Fix bug" or "Error" without specific context.
|
||||
- **Hallucinated**: Refers to files or features that do not exist in this repo.
|
||||
- **Template Filler**: Body contains "Insert description here" or unrelated gibberish.
|
||||
- **Low Effort**: No reproduction steps, no logs, only 1-2 sentences.
|
||||
|
||||
If identified as spam/low-quality:
|
||||
- Add the "invalid" label.
|
||||
- Add a comment:
|
||||
"This issue has been automatically flagged as low-quality or potentially AI-generated spam. It lacks specific details (logs, reproduction steps, file references) required for us to help. Please open a new issue following the template exactly if this is a legitimate request."
|
||||
- Do NOT proceed to other steps.
|
||||
|
||||
### 4. Check for invalid issues (General)
|
||||
If the issue is not spam but still lacks information:
|
||||
- Add the "invalid" label
|
||||
- Comment asking for clarification
|
||||
|
||||
### 5. Categorize with labels (if NOT a duplicate or spam)
|
||||
Apply appropriate labels based on the issue content. Use ONLY these labels:
|
||||
- bug: Something isn't working
|
||||
- enhancement: New feature or request
|
||||
- question: Further information is requested
|
||||
- documentation: Improvements or additions to documentation
|
||||
- good first issue: Good for newcomers (if issue is well-defined and small scope)
|
||||
- help wanted: Extra attention is needed (if issue needs community input)
|
||||
- backlog: Tracked for the future, but not currently planned or prioritized
|
||||
|
||||
You may apply multiple labels if appropriate (e.g., "bug" and "help wanted").
|
||||
|
||||
## Tools Available:
|
||||
- mcp__github__get_issue: Get issue details
|
||||
- mcp__github__search_issues: Search for similar issues
|
||||
- mcp__github__list_issues: List recent issues if needed
|
||||
- mcp__github__add_issue_comment: Add a comment
|
||||
- mcp__github__update_issue: Add labels
|
||||
- mcp__github__get_issue_comments: Get existing comments
|
||||
|
||||
Be thorough but efficient. Focus on accurate categorization and finding true duplicates.
|
||||
|
||||
claude_args: |
|
||||
--model claude-haiku-4-5-20251001
|
||||
--allowedTools "mcp__github__get_issue,mcp__github__search_issues,mcp__github__list_issues,mcp__github__add_issue_comment,mcp__github__update_issue,mcp__github__get_issue_comments"
|
||||
@@ -0,0 +1,204 @@
|
||||
name: PR Check Command
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
jobs:
|
||||
check-pr:
|
||||
# Only run on PR comments that start with /check
|
||||
if: github.event.issue.pull_request && startsWith(github.event.comment.body, '/check')
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pull-requests: write
|
||||
issues: write
|
||||
checks: write
|
||||
statuses: write
|
||||
|
||||
steps:
|
||||
- name: Check PR requirements
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const prNumber = context.payload.issue.number;
|
||||
console.log(`Triggered by /check comment on PR #${prNumber}`);
|
||||
|
||||
// Fetch PR data
|
||||
const { data: pr } = await github.rest.pulls.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber,
|
||||
});
|
||||
|
||||
const prBody = pr.body || '';
|
||||
const prTitle = pr.title || '';
|
||||
const prAuthor = pr.user.login;
|
||||
const headSha = pr.head.sha;
|
||||
|
||||
// Create a check run in progress
|
||||
const { data: checkRun } = await github.rest.checks.create({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
name: 'check-requirements',
|
||||
head_sha: headSha,
|
||||
status: 'in_progress',
|
||||
started_at: new Date().toISOString(),
|
||||
});
|
||||
|
||||
// Extract issue numbers
|
||||
const issuePattern = /(?:close[sd]?|fix(?:e[sd])?|resolve[sd]?)?\s*#(\d+)/gi;
|
||||
const allText = `${prTitle} ${prBody}`;
|
||||
const matches = [...allText.matchAll(issuePattern)];
|
||||
const issueNumbers = [...new Set(matches.map(m => parseInt(m[1], 10)))];
|
||||
|
||||
console.log(`PR #${prNumber}:`);
|
||||
console.log(` Author: ${prAuthor}`);
|
||||
console.log(` Found issue references: ${issueNumbers.length > 0 ? issueNumbers.join(', ') : 'none'}`);
|
||||
|
||||
if (issueNumbers.length === 0) {
|
||||
const message = `## PR Closed - Requirements Not Met
|
||||
|
||||
This PR has been automatically closed because it doesn't meet the requirements.
|
||||
|
||||
**Missing:** No linked issue found.
|
||||
|
||||
**To fix:**
|
||||
1. Create or find an existing issue for this work
|
||||
2. Assign yourself to the issue
|
||||
3. Re-open this PR and add \`Fixes #123\` in the description
|
||||
|
||||
**Why is this required?** See #472 for details.`;
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: message,
|
||||
});
|
||||
|
||||
await github.rest.pulls.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber,
|
||||
state: 'closed',
|
||||
});
|
||||
|
||||
// Update check run to failure
|
||||
await github.rest.checks.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
check_run_id: checkRun.id,
|
||||
status: 'completed',
|
||||
conclusion: 'failure',
|
||||
completed_at: new Date().toISOString(),
|
||||
output: {
|
||||
title: 'Missing linked issue',
|
||||
summary: 'PR must reference an issue (e.g., `Fixes #123`)',
|
||||
},
|
||||
});
|
||||
|
||||
core.setFailed('PR must reference an issue');
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if PR author is assigned to any linked issue
|
||||
let issueWithAuthorAssigned = null;
|
||||
let issuesWithoutAuthor = [];
|
||||
|
||||
for (const issueNum of issueNumbers) {
|
||||
try {
|
||||
const { data: issue } = await github.rest.issues.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: issueNum,
|
||||
});
|
||||
|
||||
const assigneeLogins = (issue.assignees || []).map(a => a.login);
|
||||
if (assigneeLogins.includes(prAuthor)) {
|
||||
issueWithAuthorAssigned = issueNum;
|
||||
console.log(` Issue #${issueNum} has PR author ${prAuthor} as assignee`);
|
||||
break;
|
||||
} else {
|
||||
issuesWithoutAuthor.push({
|
||||
number: issueNum,
|
||||
assignees: assigneeLogins
|
||||
});
|
||||
console.log(` Issue #${issueNum} assignees: ${assigneeLogins.length > 0 ? assigneeLogins.join(', ') : 'none'}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.log(` Issue #${issueNum} not found`);
|
||||
}
|
||||
}
|
||||
|
||||
if (!issueWithAuthorAssigned) {
|
||||
const issueList = issuesWithoutAuthor.map(i =>
|
||||
`#${i.number} (assignees: ${i.assignees.length > 0 ? i.assignees.join(', ') : 'none'})`
|
||||
).join(', ');
|
||||
|
||||
const message = `## PR Closed - Requirements Not Met
|
||||
|
||||
This PR has been automatically closed because it doesn't meet the requirements.
|
||||
|
||||
**PR Author:** @${prAuthor}
|
||||
**Found issues:** ${issueList}
|
||||
**Problem:** The PR author must be assigned to the linked issue.
|
||||
|
||||
**To fix:**
|
||||
1. Assign yourself (@${prAuthor}) to one of the linked issues
|
||||
2. Re-open this PR
|
||||
|
||||
**Why is this required?** See #472 for details.`;
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: message,
|
||||
});
|
||||
|
||||
await github.rest.pulls.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber,
|
||||
state: 'closed',
|
||||
});
|
||||
|
||||
// Update check run to failure
|
||||
await github.rest.checks.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
check_run_id: checkRun.id,
|
||||
status: 'completed',
|
||||
conclusion: 'failure',
|
||||
completed_at: new Date().toISOString(),
|
||||
output: {
|
||||
title: 'PR author not assigned to issue',
|
||||
summary: `PR author @${prAuthor} must be assigned to one of the linked issues: ${issueList}`,
|
||||
},
|
||||
});
|
||||
|
||||
core.setFailed('PR author must be assigned to the linked issue');
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: `✅ PR requirements met! Issue #${issueWithAuthorAssigned} has @${prAuthor} as assignee.`,
|
||||
});
|
||||
|
||||
// Update check run to success
|
||||
await github.rest.checks.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
check_run_id: checkRun.id,
|
||||
status: 'completed',
|
||||
conclusion: 'success',
|
||||
completed_at: new Date().toISOString(),
|
||||
output: {
|
||||
title: 'Requirements met',
|
||||
summary: `Issue #${issueWithAuthorAssigned} has @${prAuthor} as assignee.`,
|
||||
},
|
||||
});
|
||||
|
||||
console.log(`PR requirements met!`);
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
name: PR Requirements Backfill
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
check-all-open-prs:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pull-requests: write
|
||||
issues: write
|
||||
|
||||
steps:
|
||||
- name: Check all open PRs
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const { data: pullRequests } = await github.rest.pulls.list({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
state: 'open',
|
||||
per_page: 100,
|
||||
});
|
||||
|
||||
console.log(`Found ${pullRequests.length} open PRs`);
|
||||
|
||||
for (const pr of pullRequests) {
|
||||
const prNumber = pr.number;
|
||||
const prBody = pr.body || '';
|
||||
const prTitle = pr.title || '';
|
||||
const prAuthor = pr.user.login;
|
||||
|
||||
console.log(`\nChecking PR #${prNumber}: ${prTitle}`);
|
||||
|
||||
// Extract issue numbers from body and title
|
||||
const issuePattern = /(?:close[sd]?|fix(?:e[sd])?|resolve[sd]?)?\s*#(\d+)/gi;
|
||||
const allText = `${prTitle} ${prBody}`;
|
||||
const matches = [...allText.matchAll(issuePattern)];
|
||||
const issueNumbers = [...new Set(matches.map(m => parseInt(m[1], 10)))];
|
||||
|
||||
console.log(` Found issue references: ${issueNumbers.length > 0 ? issueNumbers.join(', ') : 'none'}`);
|
||||
|
||||
if (issueNumbers.length === 0) {
|
||||
console.log(` ❌ No linked issue - closing PR`);
|
||||
|
||||
const message = `## PR Closed - Requirements Not Met
|
||||
|
||||
This PR has been automatically closed because it doesn't meet the requirements.
|
||||
|
||||
**Missing:** No linked issue found.
|
||||
|
||||
**To fix:**
|
||||
1. Create or find an existing issue for this work
|
||||
2. Assign yourself to the issue
|
||||
3. Re-open this PR and add \`Fixes #123\` in the description`;
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: message,
|
||||
});
|
||||
|
||||
await github.rest.pulls.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber,
|
||||
state: 'closed',
|
||||
});
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if any linked issue has the PR author as assignee
|
||||
let issueWithAuthorAssigned = null;
|
||||
let issuesWithoutAuthor = [];
|
||||
|
||||
for (const issueNum of issueNumbers) {
|
||||
try {
|
||||
const { data: issue } = await github.rest.issues.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: issueNum,
|
||||
});
|
||||
|
||||
const assigneeLogins = (issue.assignees || []).map(a => a.login);
|
||||
if (assigneeLogins.includes(prAuthor)) {
|
||||
issueWithAuthorAssigned = issueNum;
|
||||
break;
|
||||
} else {
|
||||
issuesWithoutAuthor.push({
|
||||
number: issueNum,
|
||||
assignees: assigneeLogins
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
console.log(` Issue #${issueNum} not found or inaccessible`);
|
||||
}
|
||||
}
|
||||
|
||||
if (!issueWithAuthorAssigned) {
|
||||
const issueList = issuesWithoutAuthor.map(i =>
|
||||
`#${i.number} (assignees: ${i.assignees.length > 0 ? i.assignees.join(', ') : 'none'})`
|
||||
).join(', ');
|
||||
|
||||
console.log(` ❌ PR author not assigned to any linked issue - closing PR`);
|
||||
|
||||
const message = `## PR Closed - Requirements Not Met
|
||||
|
||||
This PR has been automatically closed because it doesn't meet the requirements.
|
||||
|
||||
**PR Author:** @${prAuthor}
|
||||
**Found issues:** ${issueList}
|
||||
**Problem:** The PR author must be assigned to the linked issue.
|
||||
|
||||
**To fix:**
|
||||
1. Assign yourself (@${prAuthor}) to one of the linked issues
|
||||
2. Re-open this PR`;
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: message,
|
||||
});
|
||||
|
||||
await github.rest.pulls.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber,
|
||||
state: 'closed',
|
||||
});
|
||||
} else {
|
||||
console.log(` ✅ PR requirements met! Issue #${issueWithAuthorAssigned} has ${prAuthor} as assignee.`);
|
||||
}
|
||||
}
|
||||
|
||||
console.log('\nBackfill complete!');
|
||||
@@ -0,0 +1,189 @@
|
||||
name: PR Requirements Check
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [opened, reopened, edited, synchronize]
|
||||
|
||||
jobs:
|
||||
check-requirements:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pull-requests: write
|
||||
issues: write
|
||||
|
||||
steps:
|
||||
- name: Check PR has linked issue with assignee
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const pr = context.payload.pull_request;
|
||||
const prNumber = pr.number;
|
||||
const prBody = pr.body || '';
|
||||
const prTitle = pr.title || '';
|
||||
const prLabels = (pr.labels || []).map(l => l.name);
|
||||
|
||||
// Allow micro-fix and documentation PRs without a linked issue
|
||||
const isMicroFix = prLabels.includes('micro-fix') || /micro-fix/i.test(prTitle);
|
||||
const isDocumentation = prLabels.includes('documentation') || /\bdocs?\b/i.test(prTitle);
|
||||
if (isMicroFix || isDocumentation) {
|
||||
const reason = isMicroFix ? 'micro-fix' : 'documentation';
|
||||
console.log(`PR #${prNumber} is a ${reason}, skipping issue requirement.`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Extract issue numbers from body and title
|
||||
// Matches: fixes #123, closes #123, resolves #123, or plain #123
|
||||
const issuePattern = /(?:close[sd]?|fix(?:e[sd])?|resolve[sd]?)?\s*#(\d+)/gi;
|
||||
|
||||
const allText = `${prTitle} ${prBody}`;
|
||||
const matches = [...allText.matchAll(issuePattern)];
|
||||
const issueNumbers = [...new Set(matches.map(m => parseInt(m[1], 10)))];
|
||||
|
||||
console.log(`PR #${prNumber}:`);
|
||||
console.log(` Found issue references: ${issueNumbers.length > 0 ? issueNumbers.join(', ') : 'none'}`);
|
||||
|
||||
if (issueNumbers.length === 0) {
|
||||
const message = `## PR Closed - Requirements Not Met
|
||||
|
||||
This PR has been automatically closed because it doesn't meet the requirements.
|
||||
|
||||
**Missing:** No linked issue found.
|
||||
|
||||
**To fix:**
|
||||
1. Create or find an existing issue for this work
|
||||
2. Assign yourself to the issue
|
||||
3. Re-open this PR and add \`Fixes #123\` in the description
|
||||
|
||||
**Exception:** To bypass this requirement, you can:
|
||||
- Add the \`micro-fix\` label or include \`micro-fix\` in your PR title for trivial fixes
|
||||
- Add the \`documentation\` label or include \`doc\`/\`docs\` in your PR title for documentation changes
|
||||
|
||||
**Micro-fix requirements** (must meet ALL):
|
||||
| Qualifies | Disqualifies |
|
||||
|-----------|--------------|
|
||||
| < 20 lines changed | Any functional bug fix |
|
||||
| Typos & Documentation & Linting | Refactoring for "clean code" |
|
||||
| No logic/API/DB changes | New features (even tiny ones) |
|
||||
|
||||
**Why is this required?** See #472 for details.`;
|
||||
|
||||
const comments = await github.rest.issues.listComments({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
});
|
||||
|
||||
const botComment = comments.data.find(
|
||||
(c) => c.user.type === 'Bot' && c.body.includes('PR Closed - Requirements Not Met')
|
||||
);
|
||||
|
||||
if (!botComment) {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: message,
|
||||
});
|
||||
}
|
||||
|
||||
await github.rest.pulls.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber,
|
||||
state: 'closed',
|
||||
});
|
||||
|
||||
core.setFailed('PR must reference an issue');
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if any linked issue has the PR author as assignee
|
||||
const prAuthor = pr.user.login;
|
||||
let issueWithAuthorAssigned = null;
|
||||
let issuesWithoutAuthor = [];
|
||||
|
||||
for (const issueNum of issueNumbers) {
|
||||
try {
|
||||
const { data: issue } = await github.rest.issues.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: issueNum,
|
||||
});
|
||||
|
||||
const assigneeLogins = (issue.assignees || []).map(a => a.login);
|
||||
if (assigneeLogins.includes(prAuthor)) {
|
||||
issueWithAuthorAssigned = issueNum;
|
||||
console.log(` Issue #${issueNum} has PR author ${prAuthor} as assignee`);
|
||||
break;
|
||||
} else {
|
||||
issuesWithoutAuthor.push({
|
||||
number: issueNum,
|
||||
assignees: assigneeLogins
|
||||
});
|
||||
console.log(` Issue #${issueNum} assignees: ${assigneeLogins.length > 0 ? assigneeLogins.join(', ') : 'none'} (PR author: ${prAuthor})`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.log(` Issue #${issueNum} not found or inaccessible`);
|
||||
}
|
||||
}
|
||||
|
||||
if (!issueWithAuthorAssigned) {
|
||||
const issueList = issuesWithoutAuthor.map(i =>
|
||||
`#${i.number} (assignees: ${i.assignees.length > 0 ? i.assignees.join(', ') : 'none'})`
|
||||
).join(', ');
|
||||
|
||||
const message = `## PR Closed - Requirements Not Met
|
||||
|
||||
This PR has been automatically closed because it doesn't meet the requirements.
|
||||
|
||||
**PR Author:** @${prAuthor}
|
||||
**Found issues:** ${issueList}
|
||||
**Problem:** The PR author must be assigned to the linked issue.
|
||||
|
||||
**To fix:**
|
||||
1. Assign yourself (@${prAuthor}) to one of the linked issues
|
||||
2. Re-open this PR
|
||||
|
||||
**Exception:** To bypass this requirement, you can:
|
||||
- Add the \`micro-fix\` label or include \`micro-fix\` in your PR title for trivial fixes
|
||||
- Add the \`documentation\` label or include \`doc\`/\`docs\` in your PR title for documentation changes
|
||||
|
||||
**Micro-fix requirements** (must meet ALL):
|
||||
| Qualifies | Disqualifies |
|
||||
|-----------|--------------|
|
||||
| < 20 lines changed | Any functional bug fix |
|
||||
| Typos & Documentation & Linting | Refactoring for "clean code" |
|
||||
| No logic/API/DB changes | New features (even tiny ones) |
|
||||
|
||||
**Why is this required?** See #472 for details.`;
|
||||
|
||||
const comments = await github.rest.issues.listComments({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
});
|
||||
|
||||
const botComment = comments.data.find(
|
||||
(c) => c.user.type === 'Bot' && c.body.includes('PR Closed - Requirements Not Met')
|
||||
);
|
||||
|
||||
if (!botComment) {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: message,
|
||||
});
|
||||
}
|
||||
|
||||
await github.rest.pulls.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber,
|
||||
state: 'closed',
|
||||
});
|
||||
|
||||
core.setFailed('PR author must be assigned to the linked issue');
|
||||
} else {
|
||||
console.log(`PR requirements met! Issue #${issueWithAuthorAssigned} has ${prAuthor} as assignee.`);
|
||||
}
|
||||
+1
-1
@@ -69,4 +69,4 @@ exports/*
|
||||
|
||||
.agent-builder-sessions/*
|
||||
|
||||
.venv
|
||||
.venv
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"agent-builder": {
|
||||
"command": "python",
|
||||
"command": ".venv/bin/python",
|
||||
"args": ["-m", "framework.mcp.agent_builder_server"],
|
||||
"cwd": "core",
|
||||
"env": {
|
||||
@@ -9,11 +9,11 @@
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"command": "python",
|
||||
"command": ".venv/bin/python",
|
||||
"args": ["mcp_server.py", "--stdio"],
|
||||
"cwd": "tools",
|
||||
"env": {
|
||||
"PYTHONPATH": "src"
|
||||
"PYTHONPATH": "src:../core"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.8.6
|
||||
hooks:
|
||||
- id: ruff
|
||||
name: ruff lint (core)
|
||||
args: [--fix]
|
||||
files: ^core/
|
||||
- id: ruff
|
||||
name: ruff lint (tools)
|
||||
args: [--fix]
|
||||
files: ^tools/
|
||||
- id: ruff-format
|
||||
name: ruff format (core)
|
||||
files: ^core/
|
||||
- id: ruff-format
|
||||
name: ruff format (tools)
|
||||
files: ^tools/
|
||||
@@ -0,0 +1 @@
|
||||
3.11
|
||||
Vendored
+7
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"recommendations": [
|
||||
"charliermarsh.ruff",
|
||||
"editorconfig.editorconfig",
|
||||
"ms-python.python"
|
||||
]
|
||||
}
|
||||
+2
-1
@@ -25,8 +25,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
### Removed
|
||||
- N/A
|
||||
|
||||
|
||||
### Fixed
|
||||
- N/A
|
||||
- tools: Fixed web_scrape tool attempting to parse non-HTML content (PDF, JSON) as HTML (#487)
|
||||
|
||||
### Security
|
||||
- N/A
|
||||
|
||||
+66
-21
@@ -1,34 +1,63 @@
|
||||
# Contributing to Aden Agent Framework
|
||||
|
||||
Thank you for your interest in contributing to the Aden Agent Framework! This document provides guidelines and information for contributors.
|
||||
Thank you for your interest in contributing to the Aden Agent Framework! This document provides guidelines and information for contributors. We’re especially looking for help building tools, integrations([check #2805](https://github.com/adenhq/hive/issues/2805)), and example agents for the framework. If you’re interested in extending its functionality, this is the perfect place to start.
|
||||
|
||||
## Code of Conduct
|
||||
|
||||
By participating in this project, you agree to abide by our [Code of Conduct](CODE_OF_CONDUCT.md).
|
||||
|
||||
## Issue Assignment Policy
|
||||
|
||||
To prevent duplicate work and respect contributors' time, we require issue assignment before submitting PRs.
|
||||
|
||||
### How to Claim an Issue
|
||||
|
||||
1. **Find an Issue:** Browse existing issues or create a new one
|
||||
2. **Claim It:** Leave a comment (e.g., *"I'd like to work on this!"*)
|
||||
3. **Wait for Assignment:** A maintainer will assign you within 24 hours. Issues with reproducible steps or proposals are prioritized.
|
||||
4. **Submit Your PR:** Once assigned, you're ready to contribute
|
||||
|
||||
> **Note:** PRs for unassigned issues may be delayed or closed if someone else was already assigned.
|
||||
|
||||
### Exceptions (No Assignment Needed)
|
||||
|
||||
You may submit PRs without prior assignment for:
|
||||
- **Documentation:** Fixing typos or clarifying instructions — add the `documentation` label or include `doc`/`docs` in your PR title to bypass the linked issue requirement
|
||||
- **Micro-fixes:** Add the `micro-fix` label or include `micro-fix` in your PR title to bypass the linked issue requirement. Micro-fixes must meet **all** qualification criteria:
|
||||
|
||||
| Qualifies | Disqualifies |
|
||||
|-----------|--------------|
|
||||
| < 20 lines changed | Any functional bug fix |
|
||||
| Typos & Documentation & Linting | Refactoring for "clean code" |
|
||||
| No logic/API/DB changes | New features (even tiny ones) |
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. Fork the repository
|
||||
2. Clone your fork: `git clone https://github.com/YOUR_USERNAME/hive.git`
|
||||
3. Create a feature branch: `git checkout -b feature/your-feature-name`
|
||||
4. Make your changes
|
||||
5. Run tests: `PYTHONPATH=core:exports python -m pytest`
|
||||
5. Run checks and tests:
|
||||
```bash
|
||||
make check # Lint and format checks (ruff check + ruff format --check on core/ and tools/)
|
||||
make test # Core tests (cd core && pytest tests/ -v)
|
||||
```
|
||||
6. Commit your changes following our commit conventions
|
||||
7. Push to your fork and submit a Pull Request
|
||||
|
||||
## Development Setup
|
||||
|
||||
```bash
|
||||
# Install Python packages
|
||||
./scripts/setup-python.sh
|
||||
|
||||
# Verify installation
|
||||
python -c "import framework; import aden_tools; print('✓ Setup complete')"
|
||||
|
||||
# Install Claude Code skills (optional)
|
||||
# Install Python packages and verify setup
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
> **Windows Users:**
|
||||
> If you are on native Windows, it is recommended to use **WSL (Windows Subsystem for Linux)**.
|
||||
> Alternatively, make sure to run PowerShell or Git Bash with Python 3.11+ installed, and disable "App Execution Aliases" in Windows settings.
|
||||
|
||||
> **Tip:** Installing Claude Code skills is optional for running existing agents, but required if you plan to **build new agents**.
|
||||
|
||||
## Commit Convention
|
||||
|
||||
We follow [Conventional Commits](https://www.conventionalcommits.org/):
|
||||
@@ -59,11 +88,12 @@ docs(readme): update installation instructions
|
||||
|
||||
## Pull Request Process
|
||||
|
||||
1. Update documentation if needed
|
||||
2. Add tests for new functionality
|
||||
3. Ensure all tests pass
|
||||
4. Update the CHANGELOG.md if applicable
|
||||
5. Request review from maintainers
|
||||
1. **Get assigned to the issue first** (see [Issue Assignment Policy](#issue-assignment-policy))
|
||||
2. Update documentation if needed
|
||||
3. Add tests for new functionality
|
||||
4. Ensure `make check` and `make test` pass
|
||||
5. Update the CHANGELOG.md if applicable
|
||||
6. Request review from maintainers
|
||||
|
||||
### PR Title Format
|
||||
|
||||
@@ -75,7 +105,7 @@ feat(component): add new feature description
|
||||
## Project Structure
|
||||
|
||||
- `core/` - Core framework (agent runtime, graph executor, protocols)
|
||||
- `tools/` - MCP Tools Package (19 tools for agent capabilities)
|
||||
- `tools/` - MCP Tools Package (tools for agent capabilities)
|
||||
- `exports/` - Agent packages and examples
|
||||
- `docs/` - Documentation
|
||||
- `scripts/` - Build and utility scripts
|
||||
@@ -92,19 +122,34 @@ feat(component): add new feature description
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
# Run all tests for the framework
|
||||
cd core && python -m pytest
|
||||
> **Note:** When testing agents in `exports/`, always set PYTHONPATH:
|
||||
>
|
||||
> ```bash
|
||||
> PYTHONPATH=core:exports python -m agent_name test
|
||||
> ```
|
||||
|
||||
# Run all tests for tools
|
||||
cd tools && python -m pytest
|
||||
```bash
|
||||
# Run lint and format checks (mirrors CI lint job)
|
||||
make check
|
||||
|
||||
# Run core framework tests (mirrors CI test job)
|
||||
make test
|
||||
|
||||
# Or run tests directly
|
||||
cd core && pytest tests/ -v
|
||||
|
||||
# Run tests for a specific agent
|
||||
PYTHONPATH=core:exports python -m agent_name test
|
||||
```
|
||||
|
||||
> **CI also validates** that all exported agent JSON files (`exports/*/agent.json`) are well-formed JSON. Ensure your agent exports are valid before submitting.
|
||||
|
||||
## Contributor License Agreement
|
||||
|
||||
By submitting a Pull Request, you agree that your contributions will be licensed under the Aden Agent Framework license.
|
||||
|
||||
## Questions?
|
||||
|
||||
Feel free to open an issue for questions or join our [Discord community](https://discord.com/invite/MXE49hrKDk).
|
||||
|
||||
Thank you for contributing!
|
||||
Thank you for contributing!
|
||||
+72
-161
@@ -23,8 +23,8 @@ Aden Agent Framework is a Python-based system for building goal-driven, self-imp
|
||||
| Package | Directory | Description | Tech Stack |
|
||||
| ------------- | ---------- | --------------------------------------- | ------------ |
|
||||
| **framework** | `/core` | Core runtime, graph executor, protocols | Python 3.11+ |
|
||||
| **tools** | `/tools` | 19 MCP tools for agent capabilities | Python 3.11+ |
|
||||
| **exports** | `/exports` | Agent packages and examples | Python 3.11+ |
|
||||
| **tools** | `/tools` | MCP tools for agent capabilities | Python 3.11+ |
|
||||
| **exports** | `/exports` | Agent packages (user-created, gitignored) | Python 3.11+ |
|
||||
| **skills** | `.claude` | Claude Code skills for building/testing | Markdown |
|
||||
|
||||
### Key Principles
|
||||
@@ -44,7 +44,7 @@ Aden Agent Framework is a Python-based system for building goal-driven, self-imp
|
||||
Ensure you have installed:
|
||||
|
||||
- **Python 3.11+** - [Download](https://www.python.org/downloads/) (3.12 or 3.13 recommended)
|
||||
- **pip** - Package installer for Python (comes with Python)
|
||||
- **uv** - Python package manager ([Install](https://docs.astral.sh/uv/getting-started/installation/))
|
||||
- **git** - Version control
|
||||
- **Claude Code** - [Install](https://docs.anthropic.com/claude/docs/claude-code) (optional, for using building skills)
|
||||
|
||||
@@ -52,7 +52,7 @@ Verify installation:
|
||||
|
||||
```bash
|
||||
python --version # Should be 3.11+
|
||||
pip --version # Should be latest
|
||||
uv --version # Should be latest
|
||||
git --version # Any recent version
|
||||
```
|
||||
|
||||
@@ -63,8 +63,8 @@ git --version # Any recent version
|
||||
git clone https://github.com/adenhq/hive.git
|
||||
cd hive
|
||||
|
||||
# 2. Run automated Python setup
|
||||
./scripts/setup-python.sh
|
||||
# 2. Run automated setup
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
The setup script performs these actions:
|
||||
@@ -99,10 +99,13 @@ Get API keys:
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
This installs:
|
||||
This installs agent-related Claude Code skills:
|
||||
|
||||
- `/building-agents` - Build new goal-driven agents
|
||||
- `/testing-agent` - Test agents with evaluation framework
|
||||
- `/building-agents-core` - Fundamental agent concepts
|
||||
- `/building-agents-construction` - Step-by-step agent building
|
||||
- `/building-agents-patterns` - Best practices and design patterns
|
||||
- `/testing-agent` - Test and validate agents
|
||||
- `/agent-workflow` - End-to-end guided workflow
|
||||
|
||||
### Verify Setup
|
||||
|
||||
@@ -112,8 +115,8 @@ python -c "import framework; print('✓ framework OK')"
|
||||
python -c "import aden_tools; print('✓ aden_tools OK')"
|
||||
python -c "import litellm; print('✓ litellm OK')"
|
||||
|
||||
# Run an example agent
|
||||
PYTHONPATH=core:exports python -m support_ticket_agent validate
|
||||
# Run an agent (after building one via /building-agents-construction)
|
||||
PYTHONPATH=core:exports python -m your_agent_name validate
|
||||
```
|
||||
|
||||
---
|
||||
@@ -125,70 +128,81 @@ hive/ # Repository root
|
||||
│
|
||||
├── .github/ # GitHub configuration
|
||||
│ ├── workflows/
|
||||
│ │ ├── ci.yml # Runs on every PR
|
||||
│ │ └── release.yml # Runs on tags
|
||||
│ │ ├── ci.yml # Lint, test, validate on every PR
|
||||
│ │ ├── release.yml # Runs on tags
|
||||
│ │ ├── pr-requirements.yml # PR requirement checks
|
||||
│ │ ├── pr-check-command.yml # PR check commands
|
||||
│ │ ├── claude-issue-triage.yml # Automated issue triage
|
||||
│ │ └── auto-close-duplicates.yml # Close duplicate issues
|
||||
│ ├── ISSUE_TEMPLATE/ # Bug report & feature request templates
|
||||
│ ├── PULL_REQUEST_TEMPLATE.md # PR description template
|
||||
│ └── CODEOWNERS # Auto-assign reviewers
|
||||
│
|
||||
├── .claude/ # Claude Code Skills
|
||||
│ └── skills/
|
||||
│ ├── building-agents/ # Skills for building agents
|
||||
│ │ ├── SKILL.md # Main skill definition
|
||||
│ │ ├── building-agents-core/
|
||||
│ │ ├── building-agents-patterns/
|
||||
│ │ └── building-agents-construction/
|
||||
│ └── skills/ # Skills for building
|
||||
│ ├── building-agents-core/
|
||||
| | ├── SKILL.md # Main skill definition
|
||||
│ | └── examples
|
||||
│ ├── building-agents-patterns/
|
||||
| | ├── SKILL.md
|
||||
│ | └── examples
|
||||
│ ├── building-agents-construction/
|
||||
| | ├── SKILL.md
|
||||
│ | └── examples
|
||||
│ ├── testing-agent/ # Skills for testing agents
|
||||
│ │ └── SKILL.md
|
||||
│ └── agent-workflow/ # Complete workflow orchestration
|
||||
│ │ ├── SKILL.md
|
||||
│ | └── examples
|
||||
│ └── agent-workflow/ # Complete workflow
|
||||
| ├── SKILL.md
|
||||
│ └── examples
|
||||
│
|
||||
├── core/ # CORE FRAMEWORK PACKAGE
|
||||
│ ├── framework/ # Main package code
|
||||
│ │ ├── runner/ # AgentRunner - loads and runs agents
|
||||
│ │ ├── executor/ # GraphExecutor - executes node graphs
|
||||
│ │ ├── protocols/ # Standard protocols (hooks, tracing, etc.)
|
||||
│ │ ├── builder/ # Agent builder utilities
|
||||
│ │ ├── credentials/ # Credential management
|
||||
│ │ ├── graph/ # GraphExecutor - executes node graphs
|
||||
│ │ ├── llm/ # LLM provider integrations (Anthropic, OpenAI, etc.)
|
||||
│ │ ├── memory/ # Memory systems (STM, LTM/RLM)
|
||||
│ │ ├── tools/ # Tool registry and management
|
||||
│ │ ├── mcp/ # MCP server integration
|
||||
│ │ ├── runner/ # AgentRunner - loads and runs agents
|
||||
│ │ ├── runtime/ # Runtime environment
|
||||
│ │ ├── schemas/ # Data schemas
|
||||
│ │ ├── storage/ # File-based persistence
|
||||
│ │ ├── testing/ # Testing utilities
|
||||
│ │ └── __init__.py
|
||||
│ ├── pyproject.toml # Package metadata and dependencies
|
||||
│ ├── requirements.txt # Python dependencies
|
||||
│ ├── README.md # Framework documentation
|
||||
│ ├── MCP_INTEGRATION_GUIDE.md # MCP server integration guide
|
||||
│ └── docs/ # Protocol documentation
|
||||
│
|
||||
├── tools/ # TOOLS PACKAGE (19 MCP tools)
|
||||
├── tools/ # TOOLS PACKAGE (MCP tools)
|
||||
│ ├── src/
|
||||
│ │ └── aden_tools/
|
||||
│ │ ├── tools/ # Individual tool implementations
|
||||
│ │ │ ├── web_search_tool/
|
||||
│ │ │ ├── web_scrape_tool/
|
||||
│ │ │ ├── file_system_toolkits/
|
||||
│ │ │ └── ... # 19 tools total
|
||||
│ │ │ └── ... # Additional tools
|
||||
│ │ ├── mcp_server.py # HTTP MCP server
|
||||
│ │ └── __init__.py
|
||||
│ ├── pyproject.toml # Package metadata
|
||||
│ ├── requirements.txt # Python dependencies
|
||||
│ └── README.md # Tools documentation
|
||||
│
|
||||
├── exports/ # AGENT PACKAGES
|
||||
│ ├── support_ticket_agent/ # Example: Support ticket handler
|
||||
│ ├── market_research_agent/ # Example: Market research
|
||||
│ ├── outbound_sales_agent/ # Example: Sales outreach
|
||||
│ ├── personal_assistant_agent/ # Example: Personal assistant
|
||||
│ └── ... # More agent examples
|
||||
├── exports/ # AGENT PACKAGES (user-created, gitignored)
|
||||
│ └── your_agent_name/ # Created via /building-agents-construction
|
||||
│
|
||||
├── docs/ # Documentation
|
||||
│ ├── getting-started.md # Quick start guide
|
||||
│ ├── configuration.md # Configuration reference
|
||||
│ ├── architecture.md # System architecture
|
||||
│ └── articles/ # Technical articles
|
||||
│ ├── architecture/ # System architecture
|
||||
│ ├── articles/ # Technical articles
|
||||
│ ├── quizzes/ # Developer quizzes
|
||||
│ └── i18n/ # Translations
|
||||
│
|
||||
├── scripts/ # Build & utility scripts
|
||||
│ ├── setup-python.sh # Python environment setup
|
||||
│ └── setup.sh # Legacy setup script
|
||||
│
|
||||
├── quickstart.sh # Install Claude Code skills
|
||||
├── quickstart.sh # Interactive setup wizard
|
||||
├── ENVIRONMENT_SETUP.md # Complete Python setup guide
|
||||
├── README.md # Project overview
|
||||
├── DEVELOPER.md # This file
|
||||
@@ -213,7 +227,7 @@ The fastest way to build agents is using the Claude Code skills:
|
||||
./quickstart.sh
|
||||
|
||||
# Build a new agent
|
||||
claude> /building-agents
|
||||
claude> /building-agents-construction
|
||||
|
||||
# Test the agent
|
||||
claude> /testing-agent
|
||||
@@ -224,7 +238,7 @@ claude> /testing-agent
|
||||
1. **Define Your Goal**
|
||||
|
||||
```
|
||||
claude> /building-agents
|
||||
claude> /building-agents-construction
|
||||
Enter goal: "Build an agent that processes customer support tickets"
|
||||
```
|
||||
|
||||
@@ -267,7 +281,7 @@ If you prefer to build agents manually:
|
||||
{
|
||||
"node_id": "analyze",
|
||||
"name": "Analyze Ticket",
|
||||
"node_type": "llm",
|
||||
"node_type": "llm_generate",
|
||||
"system_prompt": "Analyze this support ticket...",
|
||||
"input_keys": ["ticket_content"],
|
||||
"output_keys": ["category", "priority"]
|
||||
@@ -365,7 +379,7 @@ def test_ticket_categorization():
|
||||
- **PEP 8** - Follow Python style guide
|
||||
- **Type hints** - Use for function signatures and class attributes
|
||||
- **Docstrings** - Document classes and public functions
|
||||
- **Black** - Code formatter (run with `black .`)
|
||||
- **Ruff** - Linter and formatter (run with `make check`)
|
||||
|
||||
```python
|
||||
# Good
|
||||
@@ -499,8 +513,8 @@ chore(deps): update React to 18.2.0
|
||||
|
||||
1. Create a feature branch from `main`
|
||||
2. Make your changes with clear commits
|
||||
3. Run tests locally: `npm run test`
|
||||
4. Run linting: `npm run lint`
|
||||
3. Run tests locally: `make test`
|
||||
4. Run linting: `make check`
|
||||
5. Push and create a PR
|
||||
6. Fill out the PR template
|
||||
7. Request review from CODEOWNERS
|
||||
@@ -509,66 +523,6 @@ chore(deps): update React to 18.2.0
|
||||
|
||||
---
|
||||
|
||||
## Debugging
|
||||
|
||||
### Frontend Debugging
|
||||
|
||||
**React Developer Tools:**
|
||||
|
||||
1. Install the [React DevTools browser extension](https://react.dev/learn/react-developer-tools)
|
||||
2. Open browser DevTools → React tab
|
||||
3. Inspect component tree, props, state, and hooks
|
||||
|
||||
**VS Code Debugging:**
|
||||
|
||||
1. Add Chrome debug configuration to `.vscode/launch.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "chrome",
|
||||
"request": "launch",
|
||||
"name": "Debug Frontend",
|
||||
"url": "http://localhost:3000",
|
||||
"webRoot": "${workspaceFolder}/honeycomb/src"
|
||||
}
|
||||
```
|
||||
|
||||
2. Start the dev server: `npm run dev -w honeycomb`
|
||||
3. Press F5 in VS Code
|
||||
|
||||
### Backend Debugging
|
||||
|
||||
**VS Code Debugging:**
|
||||
|
||||
1. Add Node debug configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"name": "Debug Backend",
|
||||
"runtimeExecutable": "npm",
|
||||
"runtimeArgs": ["run", "dev"],
|
||||
"cwd": "${workspaceFolder}/hive",
|
||||
"console": "integratedTerminal"
|
||||
}
|
||||
```
|
||||
|
||||
2. Set breakpoints in your code
|
||||
3. Press F5 to start debugging
|
||||
|
||||
**Logging:**
|
||||
|
||||
```typescript
|
||||
import { logger } from "../utils/logger";
|
||||
|
||||
// Add debug logs
|
||||
logger.debug("Processing request", {
|
||||
userId: req.user.id,
|
||||
body: req.body,
|
||||
});
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Common Tasks
|
||||
@@ -578,28 +532,24 @@ logger.debug("Processing request", {
|
||||
```bash
|
||||
# Add to core framework
|
||||
cd core
|
||||
pip install <package>
|
||||
# Then add to requirements.txt or pyproject.toml
|
||||
uv add <package>
|
||||
|
||||
# Add to tools package
|
||||
cd tools
|
||||
pip install <package>
|
||||
# Then add to requirements.txt or pyproject.toml
|
||||
|
||||
# Reinstall in editable mode
|
||||
pip install -e .
|
||||
uv add <package>
|
||||
```
|
||||
|
||||
### Creating a New Agent
|
||||
|
||||
```bash
|
||||
# Option 1: Use Claude Code skill (recommended)
|
||||
claude> /building-agents
|
||||
claude> /building-agents-construction
|
||||
|
||||
# Option 2: Copy from example
|
||||
cp -r exports/support_ticket_agent exports/my_new_agent
|
||||
# Option 2: Create manually
|
||||
# Note: exports/ is initially empty (gitignored). Create your agent directory:
|
||||
mkdir -p exports/my_new_agent
|
||||
cd exports/my_new_agent
|
||||
# Edit agent.json, tools.py, README.md
|
||||
# Create agent.json, tools.py, README.md (see Agent Package Structure below)
|
||||
|
||||
# Option 3: Use the agent builder MCP tools (advanced)
|
||||
# See core/MCP_BUILDER_TOOLS_GUIDE.md
|
||||
@@ -708,61 +658,22 @@ kill -9 <PID>
|
||||
# Or change ports in config.yaml and regenerate
|
||||
```
|
||||
|
||||
### Node Modules Issues
|
||||
|
||||
```bash
|
||||
# Clean everything and reinstall
|
||||
npm run clean
|
||||
rm -rf node_modules package-lock.json
|
||||
npm install
|
||||
```
|
||||
|
||||
### Docker Issues
|
||||
|
||||
```bash
|
||||
# Reset Docker state
|
||||
docker compose down -v
|
||||
docker system prune -f
|
||||
docker compose build --no-cache
|
||||
docker compose up
|
||||
```
|
||||
|
||||
### TypeScript Errors After Pull
|
||||
|
||||
```bash
|
||||
# Rebuild TypeScript
|
||||
npm run build
|
||||
|
||||
# Or restart TS server in VS Code
|
||||
# Cmd/Ctrl + Shift + P → "TypeScript: Restart TS Server"
|
||||
```
|
||||
|
||||
### Environment Variables Not Loading
|
||||
|
||||
```bash
|
||||
# Regenerate from config.yaml
|
||||
npm run generate:env
|
||||
|
||||
# Verify files exist
|
||||
# Verify .env file exists at project root
|
||||
cat .env
|
||||
cat honeycomb/.env
|
||||
cat hive/.env
|
||||
|
||||
# Restart dev servers after changing env
|
||||
# Or check shell environment
|
||||
echo $ANTHROPIC_API_KEY
|
||||
|
||||
# Create .env if needed
|
||||
# Then add your API keys
|
||||
```
|
||||
|
||||
### Tests Failing
|
||||
|
||||
```bash
|
||||
# Run with verbose output
|
||||
npm run test -w honeycomb -- --reporter=verbose
|
||||
|
||||
# Run single test file
|
||||
npm run test -w honeycomb -- src/components/Button.test.tsx
|
||||
|
||||
# Clear test cache
|
||||
npm run test -w honeycomb -- --clearCache
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
|
||||
+284
-53
@@ -6,9 +6,13 @@ Complete setup guide for building and running goal-driven agents with the Aden A
|
||||
|
||||
```bash
|
||||
# Run the automated setup script
|
||||
./scripts/setup-python.sh
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
> **Note for Windows Users:**
|
||||
> Running the setup script on native Windows shells (PowerShell / Git Bash) may sometimes fail due to Python App Execution Aliases.
|
||||
> It is **strongly recommended to use WSL (Windows Subsystem for Linux)** for a smoother setup experience.
|
||||
|
||||
This will:
|
||||
|
||||
- Check Python version (requires 3.11+)
|
||||
@@ -17,6 +21,63 @@ This will:
|
||||
- Fix package compatibility issues (openai + litellm)
|
||||
- Verify all installations
|
||||
|
||||
## Quick Setup (Windows – PowerShell)
|
||||
|
||||
Windows users can use the native PowerShell setup script.
|
||||
|
||||
Before running the script, allow script execution for the current session:
|
||||
|
||||
```powershell
|
||||
Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass
|
||||
```
|
||||
|
||||
Run setup from the project root:
|
||||
|
||||
```powershell
|
||||
./scripts/setup-python.ps1
|
||||
```
|
||||
|
||||
This will:
|
||||
|
||||
- Check Python version (requires 3.11+)
|
||||
- Create a local `.venv` virtual environment
|
||||
- Install the core framework package (`framework`)
|
||||
- Install the tools package (`aden_tools`)
|
||||
- Fix package compatibility issues (openai + litellm)
|
||||
- Verify all installations
|
||||
|
||||
After setup, activate the virtual environment:
|
||||
|
||||
```powershell
|
||||
.\.venv\Scripts\Activate.ps1
|
||||
```
|
||||
|
||||
Set `PYTHONPATH` (required in every new PowerShell session):
|
||||
|
||||
```powershell
|
||||
$env:PYTHONPATH="core;exports"
|
||||
```
|
||||
|
||||
## Alpine Linux Setup
|
||||
|
||||
If you are using Alpine Linux (e.g., inside a Docker container), you must install system dependencies and use a virtual environment before running the setup script:
|
||||
|
||||
1. Install System Dependencies:
|
||||
```bash
|
||||
apk update
|
||||
apk add bash git python3 py3-pip nodejs npm curl build-base python3-dev linux-headers libffi-dev
|
||||
```
|
||||
2. Set up Virtual Environment (Required for Python 3.12+):
|
||||
```bash
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install --upgrade pip setuptools wheel
|
||||
```
|
||||
3. Run the Quickstart Script:
|
||||
```bash
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
## Manual Setup (Alternative)
|
||||
|
||||
If you prefer to set up manually or the script fails:
|
||||
@@ -50,6 +111,9 @@ python -c "import aden_tools; print('✓ aden_tools OK')"
|
||||
python -c "import litellm; print('✓ litellm OK')"
|
||||
```
|
||||
|
||||
> **Windows Tip:**
|
||||
> On Windows, if the verification commands fail, ensure you are running them in **WSL** or after **disabling Python App Execution Aliases** in Windows Settings → Apps → App Execution Aliases.
|
||||
|
||||
## Requirements
|
||||
|
||||
### Python Version
|
||||
@@ -63,6 +127,7 @@ python -c "import litellm; print('✓ litellm OK')"
|
||||
- pip (latest version)
|
||||
- 2GB+ RAM
|
||||
- Internet connection (for LLM API calls)
|
||||
- For Windows users: WSL 2 is recommended for full compatibility.
|
||||
|
||||
### API Keys (Optional)
|
||||
|
||||
@@ -72,86 +137,173 @@ For running agents with real LLMs:
|
||||
export ANTHROPIC_API_KEY="your-key-here"
|
||||
```
|
||||
|
||||
Windows (PowerShell):
|
||||
|
||||
```powershell
|
||||
$env:ANTHROPIC_API_KEY="your-key-here"
|
||||
```
|
||||
|
||||
## Running Agents
|
||||
|
||||
All agent commands must be run from the project root with `PYTHONPATH` set:
|
||||
|
||||
```bash
|
||||
# From /home/timothy/oss/hive/ directory
|
||||
# From /hive/ directory
|
||||
PYTHONPATH=core:exports python -m agent_name COMMAND
|
||||
```
|
||||
|
||||
Windows (PowerShell):
|
||||
|
||||
```powershell
|
||||
$env:PYTHONPATH="core;exports"
|
||||
python -m agent_name COMMAND
|
||||
```
|
||||
|
||||
### Example: Support Ticket Agent
|
||||
|
||||
```bash
|
||||
# Validate agent structure
|
||||
PYTHONPATH=core:exports python -m support_ticket_agent validate
|
||||
PYTHONPATH=core:exports python -m your_agent_name validate
|
||||
|
||||
# Show agent information
|
||||
PYTHONPATH=core:exports python -m support_ticket_agent info
|
||||
PYTHONPATH=core:exports python -m your_agent_name info
|
||||
|
||||
# Run agent with input
|
||||
PYTHONPATH=core:exports python -m support_ticket_agent run --input '{
|
||||
"ticket_content": "My login is broken. Error 401.",
|
||||
"customer_id": "CUST-123",
|
||||
"ticket_id": "TKT-456"
|
||||
PYTHONPATH=core:exports python -m your_agent_name run --input '{
|
||||
"task": "Your input here"
|
||||
}'
|
||||
|
||||
# Run in mock mode (no LLM calls)
|
||||
PYTHONPATH=core:exports python -m support_ticket_agent run --mock --input '{...}'
|
||||
PYTHONPATH=core:exports python -m your_agent_name run --mock --input '{...}'
|
||||
```
|
||||
|
||||
### Example: Other Agents
|
||||
## Building New Agents and Run Flow
|
||||
|
||||
```bash
|
||||
# Market Research Agent
|
||||
PYTHONPATH=core:exports python -m market_research_agent info
|
||||
Build and run an agent using Claude Code CLI with the agent building skills:
|
||||
|
||||
# Outbound Sales Agent
|
||||
PYTHONPATH=core:exports python -m outbound_sales_agent validate
|
||||
|
||||
# Personal Assistant Agent
|
||||
PYTHONPATH=core:exports python -m personal_assistant_agent run --input '{...}'
|
||||
```
|
||||
|
||||
## Building New Agents
|
||||
|
||||
Use Claude Code CLI with the agent building skills:
|
||||
|
||||
### 1. Install Skills (One-time)
|
||||
### 1. Install Claude Skills (One-time)
|
||||
|
||||
```bash
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
This installs:
|
||||
This verifies agent-related Claude Code skills are available:
|
||||
|
||||
- `/building-agents` - Build new agents
|
||||
- `/testing-agent` - Test agents
|
||||
- `/building-agents-construction` - Step-by-step build guide
|
||||
- `/building-agents-core` - Fundamental concepts
|
||||
- `/building-agents-patterns` - Best practices
|
||||
- `/testing-agent` - Test and validate agents
|
||||
- `/agent-workflow` - Complete workflow
|
||||
|
||||
### 2. Build an Agent
|
||||
|
||||
```
|
||||
claude> /building-agents
|
||||
claude> /building-agents-construction
|
||||
```
|
||||
|
||||
Follow the prompts to:
|
||||
|
||||
1. Define your agent's goal
|
||||
2. Design the workflow nodes
|
||||
3. Connect edges
|
||||
4. Generate the agent package
|
||||
3. Connect nodes with edges
|
||||
4. Generate the agent package under `exports/`
|
||||
|
||||
### 3. Test Your Agent
|
||||
This step creates the initial agent structure required for further development.
|
||||
|
||||
### 3. Define Agent Logic
|
||||
|
||||
```
|
||||
claude> /building-agents-core
|
||||
```
|
||||
|
||||
Follow the prompts to:
|
||||
|
||||
1. Understand the agent architecture and file structure
|
||||
2. Define the agent's goal, success criteria, and constraints
|
||||
3. Learn node types (LLM, tool-use, router, function)
|
||||
4. Discover and validate available tools before use
|
||||
|
||||
This step establishes the core concepts and rules needed before building an agent.
|
||||
|
||||
### 4. Apply Agent Patterns
|
||||
|
||||
```
|
||||
claude> /building-agents-patterns
|
||||
```
|
||||
|
||||
Follow the prompts to:
|
||||
|
||||
1. Apply best-practice agent design patterns
|
||||
2. Add pause/resume flows for multi-turn interactions
|
||||
3. Improve robustness with routing, fallbacks, and retries
|
||||
4. Avoid common anti-patterns during agent construction
|
||||
|
||||
This step helps optimize agent design before final testing.
|
||||
|
||||
### 5. Test Your Agent
|
||||
|
||||
```
|
||||
claude> /testing-agent
|
||||
```
|
||||
Follow the prompts to:
|
||||
|
||||
Creates comprehensive test suites for your agent.
|
||||
1. Generate test guidelines for constraints and success criteria
|
||||
2. Write agent tests directly under `exports/{agent}/tests/`
|
||||
3. Run goal-based evaluation tests
|
||||
4. Debug failing tests and iterate on agent improvements
|
||||
|
||||
This step verifies that the agent meets its goals before production use.
|
||||
|
||||
### 6. Agent Development Workflow (End-to-End)
|
||||
|
||||
```
|
||||
claude> /agent-workflow
|
||||
```
|
||||
|
||||
Follow the guided flow to:
|
||||
|
||||
1. Understand core agent concepts (optional)
|
||||
2. Build the agent structure step by step
|
||||
3. Apply best-practice design patterns (optional)
|
||||
4. Test and validate the agent against its goals
|
||||
|
||||
This workflow orchestrates all agent-building skills to take you from idea → production-ready agent.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "externally-managed-environment" error (PEP 668)
|
||||
|
||||
**Cause:** Python 3.12+ on macOS/Homebrew, WSL, or some Linux distros prevents system-wide pip installs.
|
||||
|
||||
**Solution:** Create and use a virtual environment:
|
||||
|
||||
```bash
|
||||
# Create virtual environment
|
||||
python3 -m venv .venv
|
||||
|
||||
# Activate it
|
||||
source .venv/bin/activate # macOS/Linux
|
||||
# .venv\Scripts\activate # Windows
|
||||
|
||||
# Then run setup
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
Always activate the venv before running agents:
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
PYTHONPATH=core:exports python -m your_agent_name demo
|
||||
```
|
||||
|
||||
### PowerShell: "running scripts is disabled on this system"
|
||||
|
||||
Run once per session:
|
||||
|
||||
```powershell
|
||||
Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass
|
||||
```
|
||||
|
||||
### "ModuleNotFoundError: No module named 'framework'"
|
||||
|
||||
**Solution:** Install the core package:
|
||||
@@ -171,7 +323,13 @@ cd tools && pip install -e .
|
||||
Or run the setup script:
|
||||
|
||||
```bash
|
||||
./scripts/setup-python.sh
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
Windows:
|
||||
|
||||
```powershell
|
||||
./scripts/setup-python.ps1
|
||||
```
|
||||
|
||||
### "ModuleNotFoundError: No module named 'openai.\_models'"
|
||||
@@ -184,14 +342,23 @@ Or run the setup script:
|
||||
pip install --upgrade "openai>=1.0.0"
|
||||
```
|
||||
|
||||
### "No module named 'support_ticket_agent'"
|
||||
### "No module named 'your_agent_name'"
|
||||
|
||||
**Cause:** Not running from project root or missing PYTHONPATH
|
||||
**Cause:** Not running from project root, missing PYTHONPATH, or agent not yet created
|
||||
|
||||
**Solution:** Ensure you're in `/home/timothy/oss/hive/` and use:
|
||||
**Solution:** Ensure you're in `/hive/` and use:
|
||||
|
||||
Linux/macOS:
|
||||
|
||||
```bash
|
||||
PYTHONPATH=core:exports python -m support_ticket_agent validate
|
||||
PYTHONPATH=core:exports python -m your_agent_name validate
|
||||
```
|
||||
|
||||
Windows:
|
||||
|
||||
```powershell
|
||||
$env:PYTHONPATH="core;exports"
|
||||
python -m your_agent_name validate
|
||||
```
|
||||
|
||||
### Agent imports fail with "broken installation"
|
||||
@@ -205,8 +372,13 @@ PYTHONPATH=core:exports python -m support_ticket_agent validate
|
||||
pip uninstall -y framework tools
|
||||
|
||||
# Reinstall correctly
|
||||
cd /home/timothy/oss/hive
|
||||
./scripts/setup-python.sh
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
Windows:
|
||||
|
||||
```powershell
|
||||
./scripts/setup-python.ps1
|
||||
```
|
||||
|
||||
## Package Structure
|
||||
@@ -217,22 +389,75 @@ The Hive framework consists of three Python packages:
|
||||
hive/
|
||||
├── core/ # Core framework (runtime, graph executor, LLM providers)
|
||||
│ ├── framework/
|
||||
│ ├── pyproject.toml
|
||||
│ └── requirements.txt
|
||||
│ ├── .venv/ # Created by quickstart.sh
|
||||
│ └── pyproject.toml
|
||||
│
|
||||
├── tools/ # Tools and MCP servers
|
||||
│ ├── src/
|
||||
│ │ └── aden_tools/ # Actual package location
|
||||
│ ├── pyproject.toml
|
||||
│ └── README.md
|
||||
│ ├── .venv/ # Created by quickstart.sh
|
||||
│ └── pyproject.toml
|
||||
│
|
||||
└── exports/ # Agent packages (your agents go here)
|
||||
├── support_ticket_agent/
|
||||
├── market_research_agent/
|
||||
├── outbound_sales_agent/
|
||||
└── personal_assistant_agent/
|
||||
└── exports/ # Agent packages (user-created, gitignored)
|
||||
└── your_agent_name/ # Created via /building-agents-construction
|
||||
```
|
||||
|
||||
## Separate Virtual Environments
|
||||
|
||||
The project uses **separate virtual environments** for `core` and `tools` packages to:
|
||||
|
||||
- Isolate dependencies and avoid conflicts
|
||||
- Allow independent development and testing of each package
|
||||
- Enable MCP servers to run with their specific dependencies
|
||||
|
||||
### How It Works
|
||||
|
||||
When you run `./quickstart.sh` or `uv sync` in each directory:
|
||||
|
||||
1. **core/.venv/** - Contains the `framework` package and its dependencies (anthropic, litellm, mcp, etc.)
|
||||
2. **tools/.venv/** - Contains the `aden_tools` package and its dependencies (beautifulsoup4, pandas, etc.)
|
||||
|
||||
### Cross-Package Imports
|
||||
|
||||
The `core` and `tools` packages are **intentionally independent**:
|
||||
|
||||
- **No cross-imports**: `framework` does not import `aden_tools` directly, and vice versa
|
||||
- **Communication via MCP**: Tools are exposed to agents through MCP servers, not direct Python imports
|
||||
- **Runtime integration**: The agent runner loads tools via the MCP protocol at runtime
|
||||
|
||||
If you need to use both packages in a single script (e.g., for testing), you have two options:
|
||||
|
||||
```bash
|
||||
# Option 1: Install both in a shared environment
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install -e core/ -e tools/
|
||||
|
||||
# Option 2: Use PYTHONPATH (for quick testing)
|
||||
PYTHONPATH=core:tools/src python your_script.py
|
||||
```
|
||||
|
||||
### MCP Server Configuration
|
||||
|
||||
The `.mcp.json` at project root configures MCP servers to use their respective virtual environments:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"agent-builder": {
|
||||
"command": "core/.venv/bin/python",
|
||||
"args": ["-m", "framework.mcp.agent_builder_server"]
|
||||
},
|
||||
"tools": {
|
||||
"command": "tools/.venv/bin/python",
|
||||
"args": ["-m", "aden_tools.mcp_server", "--stdio"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This ensures each MCP server runs with its correct dependencies.
|
||||
|
||||
### Why PYTHONPATH is Required
|
||||
|
||||
The packages are installed in **editable mode** (`pip install -e`), which means:
|
||||
@@ -251,20 +476,26 @@ This design allows agents in `exports/` to be:
|
||||
### 1. Setup (Once)
|
||||
|
||||
```bash
|
||||
./scripts/setup-python.sh
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
Windows:
|
||||
|
||||
```powershell
|
||||
./scripts/setup-python.ps1
|
||||
```
|
||||
|
||||
### 2. Build Agent (Claude Code)
|
||||
|
||||
```
|
||||
claude> /building-agents
|
||||
claude> /building-agents-construction
|
||||
Enter goal: "Build an agent that processes customer support tickets"
|
||||
```
|
||||
|
||||
### 3. Validate Agent
|
||||
|
||||
```bash
|
||||
PYTHONPATH=core:exports python -m support_ticket_agent validate
|
||||
PYTHONPATH=core:exports python -m your_agent_name validate
|
||||
```
|
||||
|
||||
### 4. Test Agent
|
||||
@@ -276,7 +507,7 @@ claude> /testing-agent
|
||||
### 5. Run Agent
|
||||
|
||||
```bash
|
||||
PYTHONPATH=core:exports python -m support_ticket_agent run --input '{...}'
|
||||
PYTHONPATH=core:exports python -m your_agent_name run --input '{...}'
|
||||
```
|
||||
|
||||
## IDE Setup
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
# Developer convenience targets for the core/ and tools/ Python packages.
# Recipes run ruff (lint/format) and pytest inside each package directory.
.PHONY: lint format check test install-hooks help

help: ## Show this help
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | \
		awk 'BEGIN {FS = ":.*?## "}; {printf "  \033[36m%-15s\033[0m %s\n", $$1, $$2}'

lint: ## Run ruff linter (with auto-fix)
	cd core && ruff check --fix .
	cd tools && ruff check --fix .

format: ## Run ruff formatter
	cd core && ruff format .
	cd tools && ruff format .

check: ## Run all checks without modifying files (CI-safe)
	cd core && ruff check .
	cd tools && ruff check .
	cd core && ruff format --check .
	cd tools && ruff format --check .

test: ## Run all tests
	cd core && python -m pytest tests/ -v

install-hooks: ## Install pre-commit hooks
	pip install pre-commit
	pre-commit install
|
||||
@@ -0,0 +1,51 @@
|
||||
## Summary
|
||||
- **Added HubSpot integration** — new HubSpot MCP tool with search, get, create, and update operations for contacts, companies, and deals. Includes OAuth2 provider for HubSpot credentials and credential store adapter for the tools layer.
|
||||
- **Replaced web_scrape tool with Playwright + stealth** — swapped httpx/BeautifulSoup for a headless Chromium browser using `playwright` (async API) and `playwright-stealth`, enabling JS-rendered page scraping and bot detection evasion
|
||||
- **Added empty response retry logic** — LLM provider now detects empty responses (e.g. Gemini returning 200 with no content on rate limit) and retries with exponential backoff, preventing hallucinated output from the cleanup LLM
|
||||
- **Added context-aware input compaction** — LLM nodes now estimate input token count before calling the model and progressively truncate the largest values if they exceed the context window budget
|
||||
- **Increased rate limit retries to 10** with verbose `[retry]` and `[compaction]` logging that includes model name, finish reason, and attempt count
|
||||
- **Updated setup scripts** — `scripts/setup-python.sh` now installs Playwright Chromium browser automatically for web scraping support
|
||||
- **Interactive quickstart onboarding** — `quickstart.sh` rewritten as bee-themed interactive wizard that detects existing API keys (including Claude Code subscription), lets user pick ONE default LLM provider, and saves configuration to `~/.hive/configuration.json`
|
||||
- **Fixed lint errors** across `hubspot_tool.py` (line length) and `agent_builder_server.py` (unused variable)
|
||||
|
||||
## Changed files
|
||||
|
||||
### HubSpot Integration
|
||||
- `tools/src/aden_tools/tools/hubspot_tool/` — New MCP tool: contacts, companies, and deals CRUD
|
||||
- `tools/src/aden_tools/tools/__init__.py` — Registered HubSpot tools
|
||||
- `tools/src/aden_tools/credentials/integrations.py` — HubSpot credential integration
|
||||
- `tools/src/aden_tools/credentials/__init__.py` — Updated credential exports
|
||||
- `core/framework/credentials/oauth2/hubspot_provider.py` — HubSpot OAuth2 provider
|
||||
- `core/framework/credentials/oauth2/__init__.py` — Registered HubSpot OAuth2 provider
|
||||
- `core/framework/runner/runner.py` — Updated runner for credential support
|
||||
|
||||
### Web Scrape Rewrite
|
||||
- `tools/src/aden_tools/tools/web_scrape_tool/web_scrape_tool.py` — Playwright async rewrite
|
||||
- `tools/src/aden_tools/tools/web_scrape_tool/README.md` — Updated docs
|
||||
- `tools/pyproject.toml` — Added `playwright`, `playwright-stealth` deps
|
||||
- `tools/Dockerfile` — Added `playwright install chromium --with-deps`
|
||||
- `scripts/setup-python.sh` — Added Playwright Chromium browser install step
|
||||
|
||||
### LLM Reliability
|
||||
- `core/framework/llm/litellm.py` — Empty response retry + max retries 10 + verbose logging
|
||||
- `core/framework/graph/node.py` — Input compaction via `_compact_inputs()`, `_estimate_tokens()`, `_get_context_limit()`
|
||||
|
||||
### Quickstart & Setup
|
||||
- `quickstart.sh` — Interactive bee-themed onboarding wizard with single provider selection
|
||||
- `~/.hive/configuration.json` — New user config file for default LLM provider/model
|
||||
|
||||
### Fixes
|
||||
- `core/framework/mcp/agent_builder_server.py` — Removed unused variable
|
||||
- `tools/src/aden_tools/tools/hubspot_tool/hubspot_tool.py` — Fixed E501 line length violations
|
||||
|
||||
## Test plan
|
||||
- [ ] Run `make lint` — passes clean
|
||||
- [ ] Run `./quickstart.sh` and verify interactive flow works, config saved to `~/.hive/configuration.json`
|
||||
- [ ] Run `./scripts/setup-python.sh` and verify Playwright Chromium installs
|
||||
- [ ] Run `pytest tests/tools/test_web_scrape_tool.py -v`
|
||||
- [ ] Run agent against a JS-heavy site and verify `web_scrape` returns rendered content
|
||||
- [ ] Set `HUBSPOT_ACCESS_TOKEN` and verify HubSpot tool CRUD operations work
|
||||
- [ ] Trigger rate limit and verify `[retry]` logs appear with correct attempt counts
|
||||
- [ ] Run agent with large inputs and verify `[compaction]` logs show truncation
|
||||
|
||||
🤖 Generated with [Claude Code](https://claude.com/claude-code)
|
||||
@@ -4,16 +4,17 @@
|
||||
|
||||
<p align="center">
|
||||
<a href="README.md">English</a> |
|
||||
<a href="README.zh-CN.md">简体中文</a> |
|
||||
<a href="README.es.md">Español</a> |
|
||||
<a href="README.pt.md">Português</a> |
|
||||
<a href="README.ja.md">日本語</a> |
|
||||
<a href="README.ru.md">Русский</a>
|
||||
<a href="docs/i18n/zh-CN.md">简体中文</a> |
|
||||
<a href="docs/i18n/es.md">Español</a> |
|
||||
<a href="docs/i18n/hi.md">हिन्दी</a> |
|
||||
<a href="docs/i18n/pt.md">Português</a> |
|
||||
<a href="docs/i18n/ja.md">日本語</a> |
|
||||
<a href="docs/i18n/ru.md">Русский</a> |
|
||||
<a href="docs/i18n/ko.md">한국어</a>
|
||||
</p>
|
||||
|
||||
[](https://github.com/adenhq/hive/blob/main/LICENSE)
|
||||
[](https://www.ycombinator.com/companies/aden)
|
||||
[](https://hub.docker.com/u/adenhq)
|
||||
[](https://discord.com/invite/MXE49hrKDk)
|
||||
[](https://x.com/aden_hq)
|
||||
[](https://www.linkedin.com/company/teamaden/)
|
||||
@@ -38,6 +39,31 @@ Build reliable, self-improving AI agents without hardcoding workflows. Define yo
|
||||
|
||||
Visit [adenhq.com](https://adenhq.com) for complete documentation, examples, and guides.
|
||||
|
||||
## Who Is Hive For?
|
||||
|
||||
Hive is designed for developers and teams who want to build **production-grade AI agents** without manually wiring complex workflows.
|
||||
|
||||
Hive is a good fit if you:
|
||||
|
||||
- Want AI agents that **execute real business processes**, not demos
|
||||
- Prefer **goal-driven development** over hardcoded workflows
|
||||
- Need **self-healing and adaptive agents** that improve over time
|
||||
- Require **human-in-the-loop control**, observability, and cost limits
|
||||
- Plan to run agents in **production environments**
|
||||
|
||||
Hive may not be the best fit if you’re only experimenting with simple agent chains or one-off scripts.
|
||||
|
||||
## When Should You Use Hive?
|
||||
|
||||
Use Hive when you need:
|
||||
|
||||
- Long-running, autonomous agents
|
||||
- Multi-agent coordination
|
||||
- Continuous improvement based on failures
|
||||
- Strong monitoring, safety, and budget controls
|
||||
- A framework that evolves with your goals
|
||||
|
||||
|
||||
## What is Aden
|
||||
|
||||
<p align="center">
|
||||
@@ -65,7 +91,7 @@ Aden is a platform for building, deploying, operating, and adapting AI agents:
|
||||
### Prerequisites
|
||||
|
||||
- [Python 3.11+](https://www.python.org/downloads/) for agent development
|
||||
- [Docker](https://docs.docker.com/get-docker/) (v20.10+) - Optional, for containerized tools
|
||||
- Claude Code or Cursor for utilizing agent skills
|
||||
|
||||
### Installation
|
||||
|
||||
@@ -74,23 +100,20 @@ Aden is a platform for building, deploying, operating, and adapting AI agents:
|
||||
git clone https://github.com/adenhq/hive.git
|
||||
cd hive
|
||||
|
||||
# Run Python environment setup
|
||||
./scripts/setup-python.sh
|
||||
# Run quickstart setup
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
This installs:
|
||||
- **framework** - Core agent runtime and graph executor
|
||||
- **aden_tools** - 19 MCP tools for agent capabilities
|
||||
- All required dependencies
|
||||
This sets up:
|
||||
- **framework** - Core agent runtime and graph executor (in `core/.venv`)
|
||||
- **aden_tools** - MCP tools for agent capabilities (in `tools/.venv`)
|
||||
- All required Python dependencies
|
||||
|
||||
### Build Your First Agent
|
||||
|
||||
```bash
|
||||
# Install Claude Code skills (one-time)
|
||||
./quickstart.sh
|
||||
|
||||
# Build an agent using Claude Code
|
||||
claude> /building-agents
|
||||
claude> /building-agents-construction
|
||||
|
||||
# Test your agent
|
||||
claude> /testing-agent
|
||||
@@ -101,10 +124,19 @@ PYTHONPATH=core:exports python -m your_agent_name run --input '{...}'
|
||||
|
||||
**[📖 Complete Setup Guide](ENVIRONMENT_SETUP.md)** - Detailed instructions for agent development
|
||||
|
||||
### Cursor IDE Support
|
||||
|
||||
Skills are also available in Cursor. To enable:
|
||||
|
||||
1. Open Command Palette (`Cmd+Shift+P` / `Ctrl+Shift+P`)
|
||||
2. Run `MCP: Enable` to enable MCP servers
|
||||
3. Restart Cursor to load the MCP servers from `.cursor/mcp.json`
|
||||
4. Type `/` in Agent chat and search for skills (e.g., `/building-agents-construction`)
|
||||
|
||||
## Features
|
||||
|
||||
- **Goal-Driven Development** - Define objectives in natural language; the coding agent generates the agent graph and connection code to achieve them
|
||||
- **Self-Adapting Agents** - Framework captures failures, updates objectives and updates the agent graph
|
||||
- **Adaptiveness** - Framework captures failures, calibrates according to the objectives, and evolves the agent graph
|
||||
- **Dynamic Node Connections** - No predefined edges; connection code is generated by any capable LLM based on your goals
|
||||
- **SDK-Wrapped Nodes** - Every node gets shared memory, local RLM memory, monitoring, tools, and LLM access out of the box
|
||||
- **Human-in-the-Loop** - Intervention nodes that pause execution for human input with configurable timeouts and escalation
|
||||
@@ -114,51 +146,38 @@ PYTHONPATH=core:exports python -m your_agent_name run --input '{...}'
|
||||
|
||||
## Why Aden
|
||||
|
||||
Traditional agent frameworks require you to manually design workflows, define agent interactions, and handle failures reactively. Aden flips this paradigm—**you describe outcomes, and the system builds itself**.
|
||||
Hive focuses on generating agents that run real business processes rather than generic agents. Instead of requiring you to manually design workflows, define agent interactions, and handle failures reactively, Hive flips the paradigm: **you describe outcomes, and the system builds itself**—delivering an outcome-driven, adaptive experience with an easy-to-use set of tools and integrations.
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
subgraph BUILD["🏗️ BUILD"]
|
||||
GOAL["Define Goal<br/>+ Success Criteria"] --> NODES["Add Nodes<br/>LLM/Router/Function"]
|
||||
NODES --> EDGES["Connect Edges<br/>on_success/failure/conditional"]
|
||||
EDGES --> TEST["Test & Validate"] --> APPROVE["Approve & Export"]
|
||||
end
|
||||
GOAL["Define Goal"] --> GEN["Auto-Generate Graph"]
|
||||
GEN --> EXEC["Execute Agents"]
|
||||
EXEC --> MON["Monitor & Observe"]
|
||||
MON --> CHECK{{"Pass?"}}
|
||||
CHECK -- "Yes" --> DONE["Deliver Result"]
|
||||
CHECK -- "No" --> EVOLVE["Evolve Graph"]
|
||||
EVOLVE --> EXEC
|
||||
|
||||
subgraph EXPORT["📦 EXPORT"]
|
||||
direction TB
|
||||
JSON["agent.json<br/>(GraphSpec)"]
|
||||
TOOLS["tools.py<br/>(Functions)"]
|
||||
MCP["mcp_servers.json<br/>(Integrations)"]
|
||||
end
|
||||
GOAL -.- V1["Natural Language"]
|
||||
GEN -.- V2["Instant Architecture"]
|
||||
EXEC -.- V3["Easy Integrations"]
|
||||
MON -.- V4["Full visibility"]
|
||||
EVOLVE -.- V5["Adaptability"]
|
||||
DONE -.- V6["Reliable outcomes"]
|
||||
|
||||
subgraph RUN["🚀 RUNTIME"]
|
||||
LOAD["AgentRunner<br/>Load + Parse"] --> SETUP["Setup Runtime<br/>+ ToolRegistry"]
|
||||
SETUP --> EXEC["GraphExecutor<br/>Execute Nodes"]
|
||||
|
||||
subgraph DECISION["Decision Recording"]
|
||||
DEC1["runtime.decide()<br/>intent → options → choice"]
|
||||
DEC2["runtime.record_outcome()<br/>success, result, metrics"]
|
||||
end
|
||||
end
|
||||
|
||||
subgraph INFRA["⚙️ INFRASTRUCTURE"]
|
||||
CTX["NodeContext<br/>memory • llm • tools"]
|
||||
STORE[("FileStorage<br/>Runs & Decisions")]
|
||||
end
|
||||
|
||||
APPROVE --> EXPORT
|
||||
EXPORT --> LOAD
|
||||
EXEC --> DECISION
|
||||
EXEC --> CTX
|
||||
DECISION --> STORE
|
||||
STORE -.->|"Analyze & Improve"| NODES
|
||||
|
||||
style BUILD fill:#ffbe42,stroke:#cc5d00,stroke-width:3px,color:#333
|
||||
style EXPORT fill:#fff59d,stroke:#ed8c00,stroke-width:2px,color:#333
|
||||
style RUN fill:#ffb100,stroke:#cc5d00,stroke-width:3px,color:#333
|
||||
style DECISION fill:#ffcc80,stroke:#ed8c00,stroke-width:2px,color:#333
|
||||
style INFRA fill:#e8763d,stroke:#cc5d00,stroke-width:3px,color:#fff
|
||||
style STORE fill:#ed8c00,stroke:#cc5d00,stroke-width:2px,color:#fff
|
||||
style GOAL fill:#ffbe42,stroke:#cc5d00,stroke-width:2px,color:#333
|
||||
style GEN fill:#ffb100,stroke:#cc5d00,stroke-width:2px,color:#333
|
||||
style EXEC fill:#ff9800,stroke:#cc5d00,stroke-width:2px,color:#fff
|
||||
style MON fill:#ff9800,stroke:#cc5d00,stroke-width:2px,color:#fff
|
||||
style CHECK fill:#fff59d,stroke:#ed8c00,stroke-width:2px,color:#333
|
||||
style DONE fill:#4caf50,stroke:#2e7d32,stroke-width:2px,color:#fff
|
||||
style EVOLVE fill:#e8763d,stroke:#cc5d00,stroke-width:2px,color:#fff
|
||||
style V1 fill:#fff,stroke:#ed8c00,stroke-width:1px,color:#cc5d00
|
||||
style V2 fill:#fff,stroke:#ed8c00,stroke-width:1px,color:#cc5d00
|
||||
style V3 fill:#fff,stroke:#ed8c00,stroke-width:1px,color:#cc5d00
|
||||
style V4 fill:#fff,stroke:#ed8c00,stroke-width:1px,color:#cc5d00
|
||||
style V5 fill:#fff,stroke:#ed8c00,stroke-width:1px,color:#cc5d00
|
||||
style V6 fill:#fff,stroke:#ed8c00,stroke-width:1px,color:#cc5d00
|
||||
```
|
||||
|
||||
### The Aden Advantage
|
||||
@@ -167,7 +186,7 @@ flowchart LR
|
||||
| -------------------------- | -------------------------------------- |
|
||||
| Hardcode agent workflows | Describe goals in natural language |
|
||||
| Manual graph definition | Auto-generated agent graphs |
|
||||
| Reactive error handling | Proactive self-evolution |
|
||||
| Reactive error handling | Outcome-evaluation and adaptiveness |
|
||||
| Static tool configurations | Dynamic SDK-wrapped nodes |
|
||||
| Separate monitoring setup | Built-in real-time observability |
|
||||
| DIY budget management | Integrated cost controls & degradation |
|
||||
@@ -178,76 +197,30 @@ flowchart LR
|
||||
2. **Coding Agent Generates** → Creates the agent graph, connection code, and test cases
|
||||
3. **Workers Execute** → SDK-wrapped nodes run with full observability and tool access
|
||||
4. **Control Plane Monitors** → Real-time metrics, budget enforcement, policy management
|
||||
5. **Self-Improve** → On failure, the system evolves the graph and redeploys automatically
|
||||
5. **Adaptiveness** → On failure, the system evolves the graph and redeploys automatically
|
||||
|
||||
## How Aden Compares
|
||||
## Run pre-built Agents (Coming Soon)
|
||||
|
||||
Aden takes a fundamentally different approach to agent development. While most frameworks require you to hardcode workflows or manually define agent graphs, Aden uses a **coding agent to generate your entire agent system** from natural language goals. When agents fail, the framework doesn't just log errors—it **automatically evolves the agent graph** and redeploys.
|
||||
### Run a sample agent
|
||||
Aden Hive provides a list of featured agents that you can use and build on top of.
|
||||
|
||||
### Comparison Table
|
||||
### Run an agent shared by others
|
||||
Put the agent in `exports/` and run `PYTHONPATH=core:exports python -m your_agent_name run --input '{...}'`
|
||||
|
||||
| Framework | Category | Approach | Aden Difference |
|
||||
| ----------------------------------- | ------------------------- | --------------------------------------------------------------- | --------------------------------------------------------- |
|
||||
| **LangChain, LlamaIndex, Haystack** | Component Libraries | Predefined components for RAG/LLM apps; manual connection logic | Generates entire graph and connection code upfront |
|
||||
| **CrewAI, AutoGen, Swarm** | Multi-Agent Orchestration | Role-based agents with predefined collaboration patterns | Dynamically creates agents/connections; adapts on failure |
|
||||
| **PydanticAI, Mastra, Agno** | Type-Safe Frameworks | Structured outputs and validation for known workflows | Evolving workflows; structure emerges through iteration |
|
||||
| **Agent Zero, Letta** | Personal AI Assistants | Memory and learning; OS-as-tool or stateful memory focus | Production multi-agent systems with self-healing |
|
||||
| **CAMEL** | Research Framework | Emergent behavior in large-scale simulations (up to 1M agents) | Production-oriented with reliable execution and recovery |
|
||||
| **TEN Framework, Genkit** | Infrastructure Frameworks | Real-time multimodal (TEN) or full-stack AI (Genkit) | Higher abstraction—generates and evolves agent logic |
|
||||
| **GPT Engineer, Motia** | Code Generation | Code from specs (GPT Engineer) or "Step" primitive (Motia) | Self-adapting graphs with automatic failure recovery |
|
||||
| **Trading Agents** | Domain-Specific | Hardcoded trading firm roles on LangGraph | Domain-agnostic; generates structures for any use case |
|
||||
|
||||
### When to Choose Aden
|
||||
|
||||
Choose Aden when you need:
|
||||
|
||||
- Agents that **self-improve from failures** without manual intervention
|
||||
- **Goal-driven development** where you describe outcomes, not workflows
|
||||
- **Production reliability** with automatic recovery and redeployment
|
||||
- **Rapid iteration** on agent architectures without rewriting code
|
||||
- **Full observability** with real-time monitoring and human oversight
|
||||
|
||||
Choose other frameworks when you need:
|
||||
|
||||
- **Type-safe, predictable workflows** (PydanticAI, Mastra)
|
||||
- **RAG and document processing** (LlamaIndex, Haystack)
|
||||
- **Research on agent emergence** (CAMEL)
|
||||
- **Real-time voice/multimodal** (TEN Framework)
|
||||
- **Simple component chaining** (LangChain, Swarm)
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
hive/
|
||||
├── core/ # Core framework - Agent runtime, graph executor, protocols
|
||||
├── tools/ # MCP Tools Package - 19 tools for agent capabilities
|
||||
├── exports/ # Agent packages - Pre-built agents and examples
|
||||
├── docs/ # Documentation and guides
|
||||
├── scripts/ # Build and utility scripts
|
||||
├── .claude/ # Claude Code skills for building agents
|
||||
├── ENVIRONMENT_SETUP.md # Python setup guide for agent development
|
||||
├── DEVELOPER.md # Developer guide
|
||||
├── CONTRIBUTING.md # Contribution guidelines
|
||||
└── ROADMAP.md # Product roadmap
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
### Python Agent Development
|
||||
|
||||
For building and running goal-driven agents with the framework:
|
||||
|
||||
```bash
|
||||
# One-time setup
|
||||
./scripts/setup-python.sh
|
||||
./quickstart.sh
|
||||
|
||||
# This installs:
|
||||
# This sets up:
|
||||
# - framework package (core runtime)
|
||||
# - aden_tools package (19 MCP tools)
|
||||
# - All dependencies
|
||||
# - aden_tools package (MCP tools)
|
||||
# - All Python dependencies
|
||||
|
||||
# Build new agents using Claude Code skills
|
||||
claude> /building-agents
|
||||
claude> /building-agents-construction
|
||||
|
||||
# Test agents
|
||||
claude> /testing-agent
|
||||
@@ -263,29 +236,112 @@ See [ENVIRONMENT_SETUP.md](ENVIRONMENT_SETUP.md) for complete setup instructions
|
||||
- **[Developer Guide](DEVELOPER.md)** - Comprehensive guide for developers
|
||||
- [Getting Started](docs/getting-started.md) - Quick setup instructions
|
||||
- [Configuration Guide](docs/configuration.md) - All configuration options
|
||||
- [Architecture Overview](docs/architecture.md) - System design and structure
|
||||
- [Architecture Overview](docs/architecture/README.md) - System design and structure
|
||||
|
||||
## Roadmap
|
||||
|
||||
Aden Agent Framework aims to help developers build outcome-oriented, self-adaptive agents. Please find our roadmap here:
|
||||
|
||||
[ROADMAP.md](ROADMAP.md)
|
||||
Aden Hive Agent Framework aims to help developers build outcome-oriented, self-adaptive agents. See [ROADMAP.md](ROADMAP.md) for details.
|
||||
|
||||
```mermaid
|
||||
timeline
|
||||
title Aden Agent Framework Roadmap
|
||||
section Foundation
|
||||
Architecture : Node-Based Architecture : Python SDK : LLM Integration (OpenAI, Anthropic, Google) : Communication Protocol
|
||||
Coding Agent : Goal Creation Session : Worker Agent Creation : MCP Tools Integration
|
||||
Worker Agent : Human-in-the-Loop : Callback Handlers : Intervention Points : Streaming Interface
|
||||
Tools : File Use : Memory (STM/LTM) : Web Search : Web Scraper : Audit Trail
|
||||
Core : Eval System : Pydantic Validation : Docker Deployment : Documentation : Sample Agents
|
||||
section Expansion
|
||||
Intelligence : Guardrails : Streaming Mode : Semantic Search
|
||||
Platform : JavaScript SDK : Custom Tool Integrator : Credential Store
|
||||
Deployment : Self-Hosted : Cloud Services : CI/CD Pipeline
|
||||
Templates : Sales Agent : Marketing Agent : Analytics Agent : Training Agent : Smart Form Agent
|
||||
flowchart TD
|
||||
subgraph Foundation
|
||||
direction LR
|
||||
subgraph arch["Architecture"]
|
||||
a1["Node-Based Architecture"]:::done
|
||||
a2["Python SDK"]:::done
|
||||
a3["LLM Integration"]:::done
|
||||
a4["Communication Protocol"]:::done
|
||||
end
|
||||
subgraph ca["Coding Agent"]
|
||||
b1["Goal Creation Session"]:::done
|
||||
b2["Worker Agent Creation"]
|
||||
b3["MCP Tools"]:::done
|
||||
end
|
||||
subgraph wa["Worker Agent"]
|
||||
c1["Human-in-the-Loop"]:::done
|
||||
c2["Callback Handlers"]:::done
|
||||
c3["Intervention Points"]:::done
|
||||
c4["Streaming Interface"]
|
||||
end
|
||||
subgraph cred["Credentials"]
|
||||
d1["Setup Process"]:::done
|
||||
d2["Pluggable Sources"]:::done
|
||||
d3["Enterprise Secrets"]
|
||||
d4["Integration Tools"]:::done
|
||||
end
|
||||
subgraph tools["Tools"]
|
||||
e1["File Use"]:::done
|
||||
e2["Memory STM/LTM"]:::done
|
||||
e3["Web Search/Scraper"]:::done
|
||||
e4["CSV/PDF"]:::done
|
||||
e5["Excel/Email"]
|
||||
end
|
||||
subgraph core["Core"]
|
||||
f1["Eval System"]
|
||||
f2["Pydantic Validation"]:::done
|
||||
f3["Documentation"]:::done
|
||||
f4["Adaptiveness"]
|
||||
f5["Sample Agents"]
|
||||
end
|
||||
end
|
||||
|
||||
subgraph Expansion
|
||||
direction LR
|
||||
subgraph intel["Intelligence"]
|
||||
g1["Guardrails"]
|
||||
g2["Streaming Mode"]
|
||||
g3["Image Generation"]
|
||||
g4["Semantic Search"]
|
||||
end
|
||||
subgraph mem["Memory Iteration"]
|
||||
h1["Message Model & Sessions"]
|
||||
h2["Storage Migration"]
|
||||
h3["Context Building"]
|
||||
h4["Proactive Compaction"]
|
||||
h5["Token Tracking"]
|
||||
end
|
||||
subgraph evt["Event System"]
|
||||
i1["Event Bus for Nodes"]
|
||||
end
|
||||
subgraph cas["Coding Agent Support"]
|
||||
j1["Claude Code"]
|
||||
j2["Cursor"]
|
||||
j3["Opencode"]
|
||||
j4["Antigravity"]
|
||||
end
|
||||
subgraph plat["Platform"]
|
||||
k1["JavaScript/TypeScript SDK"]
|
||||
k2["Custom Tool Integrator"]
|
||||
k3["Windows Support"]
|
||||
end
|
||||
subgraph dep["Deployment"]
|
||||
l1["Self-Hosted"]
|
||||
l2["Cloud Services"]
|
||||
l3["CI/CD Pipeline"]
|
||||
end
|
||||
subgraph tmpl["Templates"]
|
||||
m1["Sales Agent"]
|
||||
m2["Marketing Agent"]
|
||||
m3["Analytics Agent"]
|
||||
m4["Training Agent"]
|
||||
m5["Smart Form Agent"]
|
||||
end
|
||||
end
|
||||
|
||||
classDef done fill:#9e9e9e,color:#fff,stroke:#757575
|
||||
```
|
||||
## Contributing
|
||||
|
||||
We welcome contributions from the community! We’re especially looking for help building tools, integrations, and example agents for the framework ([check #2805](https://github.com/adenhq/hive/issues/2805)). If you’re interested in extending its functionality, this is the perfect place to start. Please see [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
|
||||
|
||||
**Important:** Please get assigned to an issue before submitting a PR. Comment on an issue to claim it, and a maintainer will assign you. Issues with reproducible steps and proposals are prioritized. This helps prevent duplicate work.
|
||||
|
||||
1. Find or create an issue and get assigned
|
||||
2. Fork the repository
|
||||
3. Create your feature branch (`git checkout -b feature/amazing-feature`)
|
||||
4. Commit your changes (`git commit -m 'Add amazing feature'`)
|
||||
5. Push to the branch (`git push origin feature/amazing-feature`)
|
||||
6. Open a Pull Request
|
||||
|
||||
## Community & Support
|
||||
|
||||
@@ -295,16 +351,6 @@ We use [Discord](https://discord.com/invite/MXE49hrKDk) for support, feature req
|
||||
- Twitter/X - [@aden_hq](https://x.com/aden_hq)
|
||||
- LinkedIn - [Company Page](https://www.linkedin.com/company/teamaden/)
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome contributions! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
|
||||
|
||||
1. Fork the repository
|
||||
2. Create your feature branch (`git checkout -b feature/amazing-feature`)
|
||||
3. Commit your changes (`git commit -m 'Add amazing feature'`)
|
||||
4. Push to the branch (`git push origin feature/amazing-feature`)
|
||||
5. Open a Pull Request
|
||||
|
||||
## Join Our Team
|
||||
|
||||
**We're hiring!** Join us in engineering, research, and go-to-market roles.
|
||||
@@ -321,57 +367,57 @@ This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENS
|
||||
|
||||
## Frequently Asked Questions (FAQ)
|
||||
|
||||
**Q: Does Aden depend on LangChain or other agent frameworks?**
|
||||
**Q: Does Hive depend on LangChain or other agent frameworks?**
|
||||
|
||||
No. Aden is built from the ground up with no dependencies on LangChain, CrewAI, or other agent frameworks. The framework is designed to be lean and flexible, generating agent graphs dynamically rather than relying on predefined components.
|
||||
No. Hive is built from the ground up with no dependencies on LangChain, CrewAI, or other agent frameworks. The framework is designed to be lean and flexible, generating agent graphs dynamically rather than relying on predefined components.
|
||||
|
||||
**Q: What LLM providers does Aden support?**
|
||||
**Q: What LLM providers does Hive support?**
|
||||
|
||||
Aden supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, Mistral, Groq, and many more. Simply set the appropriate API key environment variable and specify the model name.
|
||||
Hive supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, DeepSeek, Mistral, Groq, and many more. Simply set the appropriate API key environment variable and specify the model name.
|
||||
|
||||
**Q: Can I use Aden with local AI models like Ollama?**
|
||||
**Q: Can I use Hive with local AI models like Ollama?**
|
||||
|
||||
Yes! Aden supports local models through LiteLLM. Simply use the model name format `ollama/model-name` (e.g., `ollama/llama3`, `ollama/mistral`) and ensure Ollama is running locally.
|
||||
Yes! Hive supports local models through LiteLLM. Simply use the model name format `ollama/model-name` (e.g., `ollama/llama3`, `ollama/mistral`) and ensure Ollama is running locally.
|
||||
|
||||
**Q: What makes Aden different from other agent frameworks?**
|
||||
**Q: What makes Hive different from other agent frameworks?**
|
||||
|
||||
Aden generates your entire agent system from natural language goals using a coding agent—you don't hardcode workflows or manually define graphs. When agents fail, the framework automatically captures failure data, evolves the agent graph, and redeploys. This self-improving loop is unique to Aden.
|
||||
Hive generates your entire agent system from natural language goals using a coding agent—you don't hardcode workflows or manually define graphs. When agents fail, the framework automatically captures failure data, evolves the agent graph, and redeploys. This self-improving loop is unique to Hive.
|
||||
|
||||
**Q: Is Aden open-source?**
|
||||
**Q: Is Hive open-source?**
|
||||
|
||||
Yes, Aden is fully open-source under the Apache License 2.0. We actively encourage community contributions and collaboration.
|
||||
Yes, Hive is fully open-source under the Apache License 2.0. We actively encourage community contributions and collaboration.
|
||||
|
||||
**Q: Does Aden collect data from users?**
|
||||
**Q: Does Hive collect data from users?**
|
||||
|
||||
Aden collects telemetry data for monitoring and observability purposes, including token usage, latency metrics, and cost tracking. Content capture (prompts and responses) is configurable and stored with team-scoped data isolation. All data stays within your infrastructure when self-hosted.
|
||||
Hive collects telemetry data for monitoring and observability purposes, including token usage, latency metrics, and cost tracking. Content capture (prompts and responses) is configurable and stored with team-scoped data isolation. All data stays within your infrastructure when self-hosted.
|
||||
|
||||
**Q: What deployment options does Aden support?**
|
||||
**Q: What deployment options does Hive support?**
|
||||
|
||||
Aden supports Docker Compose deployment out of the box, with both production and development configurations. Self-hosted deployments work on any infrastructure supporting Docker. Cloud deployment options and Kubernetes-ready configurations are on the roadmap.
|
||||
Hive supports self-hosted deployments via Python packages. See the [Environment Setup Guide](ENVIRONMENT_SETUP.md) for installation instructions. Cloud deployment options and Kubernetes-ready configurations are on the roadmap.
|
||||
|
||||
**Q: Can Aden handle complex, production-scale use cases?**
|
||||
**Q: Can Hive handle complex, production-scale use cases?**
|
||||
|
||||
Yes. Aden is explicitly designed for production environments with features like automatic failure recovery, real-time observability, cost controls, and horizontal scaling support. The framework handles both simple automations and complex multi-agent workflows.
|
||||
Yes. Hive is explicitly designed for production environments with features like automatic failure recovery, real-time observability, cost controls, and horizontal scaling support. The framework handles both simple automations and complex multi-agent workflows.
|
||||
|
||||
**Q: Does Aden support human-in-the-loop workflows?**
|
||||
**Q: Does Hive support human-in-the-loop workflows?**
|
||||
|
||||
Yes, Aden fully supports human-in-the-loop workflows through intervention nodes that pause execution for human input. These include configurable timeouts and escalation policies, allowing seamless collaboration between human experts and AI agents.
|
||||
Yes, Hive fully supports human-in-the-loop workflows through intervention nodes that pause execution for human input. These include configurable timeouts and escalation policies, allowing seamless collaboration between human experts and AI agents.
|
||||
|
||||
**Q: What monitoring and debugging tools does Aden provide?**
|
||||
**Q: What monitoring and debugging tools does Hive provide?**
|
||||
|
||||
Aden includes comprehensive observability features: real-time WebSocket streaming for live agent execution monitoring, TimescaleDB-powered analytics for cost and performance metrics, health check endpoints for Kubernetes integration, and 19 MCP tools for budget management, agent status, and policy control.
|
||||
Hive includes comprehensive observability features: real-time WebSocket streaming for live agent execution monitoring, TimescaleDB-powered analytics for cost and performance metrics, health check endpoints for Kubernetes integration, and MCP tools for agent execution, including file operations, web search, data processing, and more.
|
||||
|
||||
**Q: What programming languages does Aden support?**
|
||||
**Q: What programming languages does Hive support?**
|
||||
|
||||
Aden provides SDKs for both Python and JavaScript/TypeScript. The Python SDK includes integration templates for LangGraph, LangFlow, and LiveKit. The backend is Node.js/TypeScript, and the frontend is React/TypeScript.
|
||||
The Hive framework is built in Python. A JavaScript/TypeScript SDK is on the roadmap.
|
||||
|
||||
**Q: Can Aden agents interact with external tools and APIs?**
|
||||
|
||||
Yes. Aden's SDK-wrapped nodes provide built-in tool access, and the framework supports flexible tool ecosystems. Agents can integrate with external APIs, databases, and services through the node architecture.
|
||||
|
||||
**Q: How does cost control work in Aden?**
|
||||
**Q: How does cost control work in Hive?**
|
||||
|
||||
Aden provides granular budget controls including spending limits, throttles, and automatic model degradation policies. You can set budgets at the team, agent, or workflow level, with real-time cost tracking and alerts.
|
||||
Hive provides granular budget controls including spending limits, throttles, and automatic model degradation policies. You can set budgets at the team, agent, or workflow level, with real-time cost tracking and alerts.
|
||||
|
||||
**Q: Where can I find examples and documentation?**
|
||||
|
||||
@@ -381,6 +427,14 @@ Visit [docs.adenhq.com](https://docs.adenhq.com/) for complete guides, API refer
|
||||
|
||||
Contributions are welcome! Fork the repository, create your feature branch, implement your changes, and submit a pull request. See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed guidelines.
|
||||
|
||||
**Q: When will my team start seeing results from Aden's adaptive agents?**
|
||||
|
||||
Aden's adaptation loop begins working from the first execution. When an agent fails, the framework captures the failure data, helping developers evolve the agent graph through the coding agent. How quickly this translates to measurable results depends on the complexity of your use case, the quality of your goal definitions, and the volume of executions generating feedback.
|
||||
|
||||
**Q: How does Hive compare to other agent frameworks?**
|
||||
|
||||
Hive focuses on generating agents that run real business processes, rather than generic agents. This vision emphasizes outcome-driven design, adaptability, and an easy-to-use set of tools and integrations.
|
||||
|
||||
**Q: Does Aden offer enterprise support?**
|
||||
|
||||
For enterprise inquiries, contact the Aden team through [adenhq.com](https://adenhq.com) or join our [Discord community](https://discord.com/invite/MXE49hrKDk) for support and discussions.
|
||||
|
||||
+191
-42
@@ -1,21 +1,94 @@
|
||||
Product Roadmap
|
||||
# Product Roadmap
|
||||
|
||||
Aden Agent Framework aims to help developers build outcome-oriented, self-adaptive agents. Please find our roadmap here:
|
||||
|
||||
```mermaid
|
||||
timeline
|
||||
title Aden Agent Framework Roadmap
|
||||
section Foundation
|
||||
Architecture : Node-Based Architecture : Python SDK : LLM Integration (OpenAI, Anthropic, Google) : Communication Protocol
|
||||
Coding Agent : Goal Creation Session : Worker Agent Creation : MCP Tools Integration
|
||||
Worker Agent : Human-in-the-Loop : Callback Handlers : Intervention Points : Streaming Interface
|
||||
Tools : File Use : Memory (STM/LTM) : Web Search : Web Scraper : Audit Trail
|
||||
Core : Eval System : Pydantic Validation : Docker Deployment : Documentation : Sample Agents
|
||||
section Expansion
|
||||
Intelligence : Guardrails : Streaming Mode : Semantic Search
|
||||
Platform : JavaScript SDK : Custom Tool Integrator : Credential Store
|
||||
Deployment : Self-Hosted : Cloud Services : CI/CD Pipeline
|
||||
Templates : Sales Agent : Marketing Agent : Analytics Agent : Training Agent : Smart Form Agent
|
||||
flowchart TD
|
||||
subgraph Foundation
|
||||
direction LR
|
||||
subgraph arch["Architecture"]
|
||||
a1["Node-Based Architecture"]:::done
|
||||
a2["Python SDK"]:::done
|
||||
a3["LLM Integration"]:::done
|
||||
a4["Communication Protocol"]:::done
|
||||
end
|
||||
subgraph ca["Coding Agent"]
|
||||
b1["Goal Creation Session"]:::done
|
||||
b2["Worker Agent Creation"]
|
||||
b3["MCP Tools"]:::done
|
||||
end
|
||||
subgraph wa["Worker Agent"]
|
||||
c1["Human-in-the-Loop"]:::done
|
||||
c2["Callback Handlers"]:::done
|
||||
c3["Intervention Points"]:::done
|
||||
c4["Streaming Interface"]
|
||||
end
|
||||
subgraph cred["Credentials"]
|
||||
d1["Setup Process"]:::done
|
||||
d2["Pluggable Sources"]:::done
|
||||
d3["Enterprise Secrets"]
|
||||
d4["Integration Tools"]:::done
|
||||
end
|
||||
subgraph tools["Tools"]
|
||||
e1["File Use"]:::done
|
||||
e2["Memory STM/LTM"]:::done
|
||||
e3["Web Search/Scraper"]:::done
|
||||
e4["CSV/PDF"]:::done
|
||||
e5["Excel/Email"]
|
||||
end
|
||||
subgraph core["Core"]
|
||||
f1["Eval System"]
|
||||
f2["Pydantic Validation"]:::done
|
||||
f3["Documentation"]:::done
|
||||
f4["Adaptiveness"]
|
||||
f5["Sample Agents"]
|
||||
end
|
||||
end
|
||||
|
||||
subgraph Expansion
|
||||
direction LR
|
||||
subgraph intel["Intelligence"]
|
||||
g1["Guardrails"]
|
||||
g2["Streaming Mode"]
|
||||
g3["Image Generation"]
|
||||
g4["Semantic Search"]
|
||||
end
|
||||
subgraph mem["Memory Iteration"]
|
||||
h1["Message Model & Sessions"]
|
||||
h2["Storage Migration"]
|
||||
h3["Context Building"]
|
||||
h4["Proactive Compaction"]
|
||||
h5["Token Tracking"]
|
||||
end
|
||||
subgraph evt["Event System"]
|
||||
i1["Event Bus for Nodes"]
|
||||
end
|
||||
subgraph cas["Coding Agent Support"]
|
||||
j1["Claude Code"]
|
||||
j2["Cursor"]
|
||||
j3["Opencode"]
|
||||
j4["Antigravity"]
|
||||
end
|
||||
subgraph plat["Platform"]
|
||||
k1["JavaScript/TypeScript SDK"]
|
||||
k2["Custom Tool Integrator"]
|
||||
k3["Windows Support"]
|
||||
end
|
||||
subgraph dep["Deployment"]
|
||||
l1["Self-Hosted"]
|
||||
l2["Cloud Services"]
|
||||
l3["CI/CD Pipeline"]
|
||||
end
|
||||
subgraph tmpl["Templates"]
|
||||
m1["Sales Agent"]
|
||||
m2["Marketing Agent"]
|
||||
m3["Analytics Agent"]
|
||||
m4["Training Agent"]
|
||||
m5["Smart Form Agent"]
|
||||
end
|
||||
end
|
||||
|
||||
classDef done fill:#9e9e9e,color:#fff,stroke:#757575
|
||||
```
|
||||
|
||||
---
|
||||
@@ -26,19 +99,19 @@ timeline
|
||||
- [ ] **Node-Based Architecture (Agent as a node)**
|
||||
- [x] Object schema definition
|
||||
- [x] Node wrapper SDK
|
||||
- [ ] Shared memory access
|
||||
- [x] Shared memory access
|
||||
- [ ] Default monitoring hooks
|
||||
- [ ] Tool access layer
|
||||
- [x] Tool access layer
|
||||
- [x] LLM integration layer (Natively supports all mainstream LLMs through LiteLLM)
|
||||
- [x] Anthropic
|
||||
- [x] OpenAI
|
||||
- [x] Google
|
||||
- [ ] **Communication protocol between nodes**
|
||||
- [ ] **[Coding Agent] Goal Creation Session** (separate from coding session)
|
||||
- [ ] Instruction back and forth
|
||||
- [x] **Communication protocol between nodes**
|
||||
- [x] **[Coding Agent] Goal Creation Session** (separate from coding session)
|
||||
- [x] Instruction back and forth
|
||||
- [x] Goal Object schema definition
|
||||
- [ ] Being able to generate the test cases
|
||||
- [ ] Test case validation for worker agent (Outcome driven)
|
||||
- [x] Being able to generate the test cases
|
||||
- [x] Test case validation for worker agent (Outcome driven)
|
||||
- [ ] **[Coding Agent] Worker Agent Creation**
|
||||
- [x] Coding Agent tools
|
||||
- [ ] Use Template Agent as a start
|
||||
@@ -46,21 +119,62 @@ timeline
|
||||
- [ ] **[Worker Agent] Human-in-the-Loop**
|
||||
- [x] Worker Agents request with questions and options
|
||||
- [x] Callback Handler System to receive events throughout execution
|
||||
- [ ] Tool-Based Intervention Points (tool to pause execution and request human input)
|
||||
- [x] Tool-Based Intervention Points (tool to pause execution and request human input)
|
||||
- [x] Multiple entrypoint for different event source (e.g. Human input, webhook)
|
||||
- [ ] Streaming Interface for Real-time Monitoring
|
||||
- [ ] Request State Management
|
||||
- [x] Request State Management
|
||||
|
||||
### Credential Management
|
||||
- [x] **Credentials Setup Process**
|
||||
- [x] Install Credential MCP
|
||||
- [x] **Pluggable Credential Sources**
|
||||
- [x] **Abstraction & Local Sources**
|
||||
- [x] Introduce `CredentialSource` base class
|
||||
- [x] Refactor existing logic into `EnvVarSource`
|
||||
- [x] Implementation of Source Priority Chain mechanism
|
||||
- [ ] Foundation unit tests
|
||||
- [ ] **Enterprise Secret Managers**
|
||||
- [x] `VaultSource` (HashiCorp Vault)
|
||||
- [ ] `AWSSecretsSource` (AWS Secrets Manager)
|
||||
- [ ] `AzureKeyVaultSource` (Azure Key Vault)
|
||||
- [ ] Management of optional provider dependencies
|
||||
- [ ] **Advanced Features**
|
||||
- [x] Credential expiration and auto-refresh
|
||||
- [ ] Audit logging for compliance/tracking
|
||||
- [ ] Per-environment configuration support
|
||||
- [ ] **Documentation & DX**
|
||||
- [ ] Comprehensive source documentation
|
||||
- [ ] Example configurations for all providers
|
||||
- [x] **Integration as tools coverage**
|
||||
- [x] Gsuite Tools
|
||||
- [x] Social Media
|
||||
- [ ] Twitter(X)
|
||||
- [x] Github
|
||||
- [ ] Instagram
|
||||
- [ ] SAAS
|
||||
- [ ] Hubspot
|
||||
- [ ] Slack
|
||||
- [ ] Teams
|
||||
- [ ] Zoom
|
||||
- [ ] Stripe
|
||||
- [ ] Salesforce
|
||||
|
||||
> [!IMPORTANT]
|
||||
> **Community Contribution Wanted**: We appreciate help from the community to expand the "Integration as tools" capability. Open an issue for any integration you would like Hive to support!
|
||||
|
||||
### Essential Tools
|
||||
- [x] **File Use Tool Kit**
|
||||
- [ ] **Memory Tools**
|
||||
- [x] **Memory Tools**
|
||||
- [x] STM Layer Tool (state-based short-term memory)
|
||||
- [x] LTM Layer Tool (RLM - long-term memory)
|
||||
- [ ] **Infrastructure Tools**
|
||||
- [x] Runtime Log Tool (logs for coding agent)
|
||||
- [ ] Audit Trail Tool (decision timeline generation)
|
||||
- [ ] Web Search
|
||||
- [ ] Web Scraper
|
||||
- [x] Web Search
|
||||
- [x] Web Scraper
|
||||
- [x] CSV tools
|
||||
- [x] PDF tools
|
||||
- [ ] Excel tools
|
||||
- [ ] Email Tools
|
||||
- [ ] Recipe for "Add your own tools"
|
||||
|
||||
### Memory & File System
|
||||
@@ -75,20 +189,25 @@ timeline
|
||||
- [ ] User-driven log analysis (OSS approach)
|
||||
|
||||
### Data Validation
|
||||
- [ ] Natively Support data validation of LLMs output with Pydantic
|
||||
- [x] Natively Support data validation of LLMs output with Pydantic
|
||||
|
||||
### Developer Experience
|
||||
- [ ] **Debugging mode**
|
||||
- [ ] **Documentation**
|
||||
- [ ] Quick start guide
|
||||
- [ ] Goal creation guide
|
||||
- [ ] Agent creation guide
|
||||
- [ ] GitHub Page setup
|
||||
- [ ] README with examples
|
||||
- [ ] Contributing guidelines
|
||||
- [ ] **Distribution**
|
||||
- [ ] PyPI package
|
||||
- [ ] Docker image on Docker Hub
|
||||
- [ ] **MVP Features**
|
||||
- [ ] Debugging mode
|
||||
- [ ] CLI tools for memory management
|
||||
- [ ] CLI tools for credential management
|
||||
- [ ] **MVP Resources & Documentation**
|
||||
- [x] Quick start guide
|
||||
- [x] Goal creation guide
|
||||
- [x] Agent creation guide
|
||||
- [x] GitHub Page setup
|
||||
- [x] README with examples
|
||||
- [x] Contributing guidelines
|
||||
- [ ] Introduction Video
|
||||
|
||||
### Adaptiveness
|
||||
- [ ] Runtime data feedback loop
|
||||
- [ ] Instant Developer Feedback for improvement
|
||||
|
||||
### Sample Agents
|
||||
- [ ] Knowledge Agent
|
||||
@@ -106,9 +225,35 @@ timeline
|
||||
|
||||
### Agent Capability
|
||||
- [ ] Streaming mode support
|
||||
- [ ] Image Generation support
|
||||
- [ ] Ability to understand end-user-provided images and flat files
|
||||
|
||||
### Cross-Platform
|
||||
- [ ] JavaScript / TypeScript Version SDK
|
||||
### Event-loop For Nodes (Opencode-style)
|
||||
- [ ] **Event bus**
|
||||
|
||||
### Memory System Iteration
|
||||
- [ ] **Message Model & Session Management**
|
||||
- [ ] Introduce `Message` class with structured content types
|
||||
- [ ] Implement `Session` classes for conversation state
|
||||
- [ ] **Storage Migration**
|
||||
- [ ] Implement granular per-message file persistence (`/message/[agentID]/...`)
|
||||
- [ ] Migrate from monolithic run storage
|
||||
- [ ] **Context Building & Conversation Loop**
|
||||
- [ ] Implement `Message.stream(sessionID)`
|
||||
- [ ] Update `LLMNode.execute()` for full context building
|
||||
- [ ] Implement `Message.toModelMessages()` conversion
|
||||
- [ ] **Proactive Compaction**
|
||||
- [ ] Implement proactive overflow detection
|
||||
- [ ] Develop backward-scanning pruning strategy (e.g., clearing old tool outputs)
|
||||
- [ ] **Enhanced Token Tracking**
|
||||
- [ ] Extend `LLMResponse` to track reasoning and cache tokens
|
||||
- [ ] Integrate granular token metrics into compaction logic
|
||||
|
||||
### Coding Agent Support
|
||||
- [ ] Claude Code
|
||||
- [ ] Cursor
|
||||
- [ ] Opencode
|
||||
- [ ] Antigravity
|
||||
|
||||
### File System Enhancement
|
||||
- [ ] Semantic Search integration
|
||||
@@ -123,7 +268,7 @@ timeline
|
||||
- [ ] Wake-up Tool (resume agent tasks)
|
||||
|
||||
### Deployment (Self-Hosted)
|
||||
- [ ] Docker container standardization
|
||||
- [ ] Worker agent docker container standardization
|
||||
- [ ] Headless backend execution
|
||||
- [ ] Exposed API for frontend attachment
|
||||
- [ ] Local monitoring & observability
|
||||
@@ -148,3 +293,7 @@ timeline
|
||||
- [ ] Analytics Agent
|
||||
- [ ] Training Agent
|
||||
- [ ] Smart Entry / Form Agent (self-evolution emphasis)
|
||||
|
||||
### Cross-Platform
|
||||
- [ ] JavaScript / TypeScript Version SDK
|
||||
- [ ] Better Windows support
|
||||
|
||||
+2
-2
@@ -3,12 +3,12 @@
|
||||
"agent-builder": {
|
||||
"command": "python",
|
||||
"args": ["-m", "framework.mcp.agent_builder_server"],
|
||||
"cwd": "/home/timothy/oss/hive/core"
|
||||
"cwd": "core"
|
||||
},
|
||||
"tools": {
|
||||
"command": "python",
|
||||
"args": ["-m", "aden_tools.mcp_server", "--stdio"],
|
||||
"cwd": "/home/timothy/oss/hive/tools"
|
||||
"cwd": "tools"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+6
-10
@@ -132,24 +132,20 @@ runtime.end_run(success=True, narrative="Successfully processed all data")
|
||||
|
||||
The framework includes a goal-based testing framework for validating agent behavior.
|
||||
|
||||
Tests are generated using MCP tools (`generate_constraint_tests`, `generate_success_tests`) which return guidelines. Claude writes tests directly using the Write tool based on these guidelines.
|
||||
|
||||
```bash
|
||||
# Generate tests from a goal definition
|
||||
python -m framework test-generate goal.json
|
||||
|
||||
# Interactively approve generated tests
|
||||
python -m framework test-approve <goal_id>
|
||||
|
||||
# Run tests against an agent
|
||||
python -m framework test-run <agent_path> --parallel 4
|
||||
python -m framework test-run <agent_path> --goal <goal_id> --parallel 4
|
||||
|
||||
# Debug failed tests
|
||||
python -m framework test-debug <goal_id> <test_id>
|
||||
python -m framework test-debug <agent_path> <test_name>
|
||||
|
||||
# List tests by status
|
||||
# List tests for a goal
|
||||
python -m framework test-list <goal_id>
|
||||
```
|
||||
|
||||
For detailed testing workflows, see the [testing-agent skill](.claude/skills/testing-agent/SKILL.md).
|
||||
For detailed testing workflows, see the [testing-agent skill](../.claude/skills/testing-agent/SKILL.md).
|
||||
|
||||
### Analyzing Agent Behavior with Builder
|
||||
|
||||
|
||||
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Minimal Manual Agent Example
|
||||
----------------------------
|
||||
This example demonstrates how to build and run an agent programmatically
|
||||
without using the Claude Code CLI or external LLM APIs.
|
||||
|
||||
It uses 'function' nodes to define logic in pure Python, making it perfect
|
||||
for understanding the core runtime loop:
|
||||
Setup -> Graph definition -> Execution -> Result
|
||||
|
||||
Run with:
|
||||
PYTHONPATH=core python core/examples/manual_agent.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from framework.graph import EdgeCondition, EdgeSpec, Goal, GraphSpec, NodeSpec
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
|
||||
# 1. Define Node Logic (Pure Python Functions)
|
||||
def greet(name: str) -> str:
    """Return a friendly greeting addressed to *name*."""
    greeting = f"Hello, {name}!"
    return greeting
|
||||
|
||||
|
||||
def uppercase(greeting: str) -> str:
    """Return *greeting* with every character converted to uppercase."""
    shouted = greeting.upper()
    return shouted
|
||||
|
||||
|
||||
async def main():
    """Build a two-node greeting agent by hand and run it once.

    Walks through the full lifecycle: goal -> node specs -> edge -> graph ->
    runtime/executor setup -> function registration -> execution -> reporting.
    """
    print("🚀 Setting up Manual Agent...")

    # 2. Goal: every agent needs one, with at least one success criterion.
    goal = Goal(
        id="greet-user",
        name="Greet User",
        description="Generate a friendly uppercase greeting",
        success_criteria=[
            {
                "id": "greeting_generated",
                "description": "Greeting produced",
                "metric": "custom",
                "target": "any",
            }
        ],
    )

    # 3. Nodes: each step in the process, backed by a registered function.
    greeter_node = NodeSpec(
        id="greeter",
        name="Greeter",
        description="Generates a simple greeting",
        node_type="function",
        # NOTE(review): registration below keys on the node id ("greeter"),
        # not on this function name — confirm which key the executor resolves.
        function="greet",
        input_keys=["name"],
        output_keys=["greeting"],
    )

    upper_node = NodeSpec(
        id="uppercaser",
        name="Uppercaser",
        description="Converts greeting to uppercase",
        node_type="function",
        function="uppercase",
        input_keys=["greeting"],
        output_keys=["final_greeting"],
    )

    # 4. Edge: run the uppercaser only after the greeter succeeds.
    success_edge = EdgeSpec(
        id="greet-to-upper",
        source="greeter",
        target="uppercaser",
        condition=EdgeCondition.ON_SUCCESS,
    )

    # 5. Graph: the blueprint wiring nodes and edges together.
    graph = GraphSpec(
        id="greeting-agent",
        goal_id="greet-user",
        entry_node="greeter",
        terminal_nodes=["uppercaser"],
        nodes=[greeter_node, upper_node],
        edges=[success_edge],
    )

    # 6. Runtime holds state/memory; the executor walks the graph.
    from pathlib import Path

    runtime = Runtime(storage_path=Path("./agent_logs"))
    executor = GraphExecutor(runtime=runtime)

    # 7. Bind the names referenced by the NodeSpecs to actual callables.
    executor.register_function("greeter", greet)
    executor.register_function("uppercaser", uppercase)

    # 8. Run the agent end to end.
    print("▶ Executing agent with input: name='Alice'...")

    result = await executor.execute(graph=graph, goal=goal, input_data={"name": "Alice"})

    # 9. Report the outcome (guard clause keeps the happy path unindented).
    if not result.success:
        print(f"\n❌ Failed: {result.error}")
        return

    print("\n✅ Success!")
    print(f"Path taken: {' -> '.join(result.path)}")
    print(f"Final output: {result.output.get('final_greeting')}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Tip: uncomment to surface the framework's internal decision flow.
    # logging.basicConfig(level=logging.INFO)
    asyncio.run(main())
|
||||
@@ -37,9 +37,9 @@ async def example_1_programmatic_registration():
|
||||
print(f"\nAvailable tools: {list(tools.keys())}")
|
||||
|
||||
# Run the agent with MCP tools available
|
||||
result = await runner.run({
|
||||
"objective": "Search for 'Claude AI' and summarize the top 3 results"
|
||||
})
|
||||
result = await runner.run(
|
||||
{"objective": "Search for 'Claude AI' and summarize the top 3 results"}
|
||||
)
|
||||
|
||||
print(f"\nAgent result: {result}")
|
||||
|
||||
@@ -79,10 +79,7 @@ async def example_3_config_file():
|
||||
# Copy example config (in practice, you'd place this in your agent folder)
|
||||
import shutil
|
||||
|
||||
shutil.copy(
|
||||
"examples/mcp_servers.json",
|
||||
test_agent_path / "mcp_servers.json"
|
||||
)
|
||||
shutil.copy("examples/mcp_servers.json", test_agent_path / "mcp_servers.json")
|
||||
|
||||
# Load agent - MCP servers will be auto-discovered
|
||||
runner = AgentRunner.load(test_agent_path)
|
||||
@@ -102,10 +99,10 @@ async def example_4_custom_agent_with_mcp_tools():
|
||||
"""Example 4: Build custom agent that uses MCP tools"""
|
||||
print("\n=== Example 4: Custom Agent with MCP Tools ===\n")
|
||||
|
||||
from framework.builder.workflow import WorkflowBuilder
|
||||
from framework.builder.workflow import GraphBuilder
|
||||
|
||||
# Create a workflow builder
|
||||
builder = WorkflowBuilder()
|
||||
builder = GraphBuilder()
|
||||
|
||||
# Define goal
|
||||
builder.set_goal(
|
||||
@@ -115,7 +112,9 @@ async def example_4_custom_agent_with_mcp_tools():
|
||||
)
|
||||
|
||||
# Add success criteria
|
||||
builder.add_success_criterion("search-results", "Successfully retrieve at least 3 web search results")
|
||||
builder.add_success_criterion(
|
||||
"search-results", "Successfully retrieve at least 3 web search results"
|
||||
)
|
||||
builder.add_success_criterion("summary", "Provide a clear, concise summary of the findings")
|
||||
|
||||
# Add nodes that will use MCP tools
|
||||
|
||||
@@ -32,10 +32,8 @@ from framework.schemas.run import Problem, Run, RunSummary
|
||||
# Testing framework
|
||||
from framework.testing import (
|
||||
ApprovalStatus,
|
||||
ConstraintTestGenerator,
|
||||
DebugTool,
|
||||
ErrorCategory,
|
||||
SuccessCriteriaTestGenerator,
|
||||
Test,
|
||||
TestResult,
|
||||
TestStorage,
|
||||
@@ -68,7 +66,5 @@ __all__ = [
|
||||
"TestStorage",
|
||||
"ApprovalStatus",
|
||||
"ErrorCategory",
|
||||
"ConstraintTestGenerator",
|
||||
"SuccessCriteriaTestGenerator",
|
||||
"DebugTool",
|
||||
]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Allow running as python -m framework"""
|
||||
"""Allow running as ``python -m framework``, which powers the ``hive`` console entry point."""
|
||||
|
||||
from framework.cli import main
|
||||
|
||||
|
||||
@@ -327,7 +327,8 @@ class BuilderQuery:
|
||||
"target": node_id,
|
||||
"reason": f"Node has {failure_rate:.1%} failure rate",
|
||||
"recommendation": (
|
||||
f"Review and improve node '{node_id}' - high failure rate suggests prompt or tool issues"
|
||||
f"Review and improve node '{node_id}' - "
|
||||
"high failure rate suggests prompt or tool issues"
|
||||
),
|
||||
"priority": "high" if failure_rate > 0.3 else "medium",
|
||||
}
|
||||
@@ -353,7 +354,9 @@ class BuilderQuery:
|
||||
"type": "architecture",
|
||||
"target": goal_id,
|
||||
"reason": f"Goal success rate is only {patterns.success_rate:.1%}",
|
||||
"recommendation": ("Consider restructuring the agent graph or improving goal definition"),
|
||||
"recommendation": (
|
||||
"Consider restructuring the agent graph or improving goal definition"
|
||||
),
|
||||
"priority": "high",
|
||||
}
|
||||
)
|
||||
@@ -410,12 +413,15 @@ class BuilderQuery:
|
||||
if alternatives:
|
||||
alt_desc = alternatives[0].description
|
||||
chosen_desc = chosen.description if chosen else "unknown"
|
||||
suggestions.append(f"Consider alternative: '{alt_desc}' instead of '{chosen_desc}'")
|
||||
suggestions.append(
|
||||
f"Consider alternative: '{alt_desc}' instead of '{chosen_desc}'"
|
||||
)
|
||||
|
||||
# Check for missing context
|
||||
if not decision.input_context:
|
||||
suggestions.append(
|
||||
f"Decision '{decision.intent}' had no input context - ensure relevant data is passed"
|
||||
f"Decision '{decision.intent}' had no input context - "
|
||||
"ensure relevant data is passed"
|
||||
)
|
||||
|
||||
# Check for constraint issues
|
||||
@@ -476,7 +482,8 @@ class BuilderQuery:
|
||||
for i, (d1, d2) in enumerate(zip(run1.decisions, run2.decisions, strict=False)):
|
||||
if d1.chosen_option_id != d2.chosen_option_id:
|
||||
differences.append(
|
||||
f"Diverged at decision {i}: chose '{d1.chosen_option_id}' vs '{d2.chosen_option_id}'"
|
||||
f"Diverged at decision {i}: "
|
||||
f"chose '{d1.chosen_option_id}' vs '{d2.chosen_option_id}'"
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
@@ -762,7 +762,9 @@ class GraphBuilder:
|
||||
"tests": len(self.session.test_cases),
|
||||
"tests_passed": sum(1 for t in self.session.test_results if t.passed),
|
||||
"approvals": len(self.session.approvals),
|
||||
"pending_validation": (self._pending_validation.model_dump() if self._pending_validation else None),
|
||||
"pending_validation": self._pending_validation.model_dump()
|
||||
if self._pending_validation
|
||||
else None,
|
||||
}
|
||||
|
||||
def show(self) -> str:
|
||||
|
||||
+48
-15
@@ -1,29 +1,62 @@
|
||||
"""
|
||||
Command-line interface for Goal Agent.
|
||||
Command-line interface for Aden Hive.
|
||||
|
||||
Usage:
|
||||
python -m core run exports/my-agent --input '{"key": "value"}'
|
||||
python -m core info exports/my-agent
|
||||
python -m core validate exports/my-agent
|
||||
python -m core list exports/
|
||||
python -m core dispatch exports/ --input '{"key": "value"}'
|
||||
python -m core shell exports/my-agent
|
||||
hive run exports/my-agent --input '{"key": "value"}'
|
||||
hive info exports/my-agent
|
||||
hive validate exports/my-agent
|
||||
hive list exports/
|
||||
hive dispatch exports/ --input '{"key": "value"}'
|
||||
hive shell exports/my-agent
|
||||
|
||||
Testing commands:
|
||||
python -m core test-generate goal.json
|
||||
python -m core test-approve <goal_id>
|
||||
python -m core test-run <agent_path> --goal <goal_id>
|
||||
python -m core test-debug <goal_id> <test_id>
|
||||
python -m core test-list <goal_id>
|
||||
python -m core test-stats <goal_id>
|
||||
hive test-run <agent_path> --goal <goal_id>
|
||||
hive test-debug <goal_id> <test_id>
|
||||
hive test-list <goal_id>
|
||||
hive test-stats <goal_id>
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _configure_paths():
    """Make agents under exports/ importable without manual PYTHONPATH setup.

    Locates the project root either relative to this file (which lives in
    core/framework/) or, failing that, from the current working directory,
    then prepends the exports/ and core/ directories to sys.path when they
    exist on disk.
    """
    # Strategy 1: walk up from this file — core/framework/cli.py -> project root
    # (works when installed via `pip install -e core/`).
    project_root = Path(__file__).resolve().parent.parent.parent

    # Strategy 2: fall back to CWD when the layout doesn't look like the repo.
    if not ((project_root / "exports").is_dir() or (project_root / "core").is_dir()):
        project_root = Path.cwd()

    # Prepend exports/ (agents become importable as top-level packages) and
    # core/ (framework package for non-editable-install scenarios), skipping
    # missing directories and entries already on sys.path.
    for candidate in (project_root / "exports", project_root / "core"):
        candidate_str = str(candidate)
        if candidate.is_dir() and candidate_str not in sys.path:
            sys.path.insert(0, candidate_str)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Goal Agent - Build and run goal-driven agents")
|
||||
_configure_paths()
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="hive",
|
||||
description="Aden Hive - Build and run goal-driven agents",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="claude-haiku-4-5-20251001",
|
||||
@@ -37,7 +70,7 @@ def main():
|
||||
|
||||
register_commands(subparsers)
|
||||
|
||||
# Register testing commands (test-generate, test-approve, test-run, test-debug, etc.)
|
||||
# Register testing commands (test-run, test-debug, test-list, test-stats)
|
||||
from framework.testing.cli import register_testing_commands
|
||||
|
||||
register_testing_commands(subparsers)
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Credential Store - Production-ready credential management for Hive.
|
||||
|
||||
This module provides secure credential storage with:
|
||||
- Key-vault structure: Credentials as objects with multiple keys
|
||||
- Template-based usage: {{cred.key}} patterns for injection
|
||||
- Bipartisan model: Store stores values, tools define usage
|
||||
- Provider system: Extensible lifecycle management (refresh, validate)
|
||||
- Multiple backends: Encrypted files, env vars, HashiCorp Vault
|
||||
|
||||
Quick Start:
|
||||
from core.framework.credentials import CredentialStore, CredentialObject
|
||||
|
||||
# Create store with encrypted storage
|
||||
store = CredentialStore.with_encrypted_storage() # defaults to ~/.hive/credentials
|
||||
|
||||
# Get a credential
|
||||
api_key = store.get("brave_search")
|
||||
|
||||
# Resolve templates in headers
|
||||
headers = store.resolve_headers({
|
||||
"Authorization": "Bearer {{github_oauth.access_token}}"
|
||||
})
|
||||
|
||||
# Save a new credential
|
||||
store.save_credential(CredentialObject(
|
||||
id="my_api",
|
||||
keys={"api_key": CredentialKey(name="api_key", value=SecretStr("xxx"))}
|
||||
))
|
||||
|
||||
For OAuth2 support:
|
||||
from core.framework.credentials.oauth2 import BaseOAuth2Provider, OAuth2Config
|
||||
|
||||
For Aden server sync:
|
||||
from core.framework.credentials.aden import (
|
||||
AdenCredentialClient,
|
||||
AdenClientConfig,
|
||||
AdenSyncProvider,
|
||||
)
|
||||
|
||||
For Vault integration:
|
||||
from core.framework.credentials.vault import HashiCorpVaultStorage
|
||||
"""
|
||||
|
||||
from .models import (
|
||||
CredentialDecryptionError,
|
||||
CredentialError,
|
||||
CredentialKey,
|
||||
CredentialKeyNotFoundError,
|
||||
CredentialNotFoundError,
|
||||
CredentialObject,
|
||||
CredentialRefreshError,
|
||||
CredentialType,
|
||||
CredentialUsageSpec,
|
||||
CredentialValidationError,
|
||||
)
|
||||
from .provider import (
|
||||
BearerTokenProvider,
|
||||
CredentialProvider,
|
||||
StaticProvider,
|
||||
)
|
||||
from .storage import (
|
||||
CompositeStorage,
|
||||
CredentialStorage,
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
InMemoryStorage,
|
||||
)
|
||||
from .store import CredentialStore
|
||||
from .template import TemplateResolver
|
||||
|
||||
# Aden sync components (lazy import to avoid httpx dependency when not needed)
|
||||
# Usage: from core.framework.credentials.aden import AdenSyncProvider
|
||||
# Or: from core.framework.credentials import AdenSyncProvider
|
||||
try:
|
||||
from .aden import (
|
||||
AdenCachedStorage,
|
||||
AdenClientConfig,
|
||||
AdenCredentialClient,
|
||||
AdenSyncProvider,
|
||||
)
|
||||
|
||||
_ADEN_AVAILABLE = True
|
||||
except ImportError:
|
||||
_ADEN_AVAILABLE = False
|
||||
|
||||
__all__ = [
|
||||
# Main store
|
||||
"CredentialStore",
|
||||
# Models
|
||||
"CredentialObject",
|
||||
"CredentialKey",
|
||||
"CredentialType",
|
||||
"CredentialUsageSpec",
|
||||
# Providers
|
||||
"CredentialProvider",
|
||||
"StaticProvider",
|
||||
"BearerTokenProvider",
|
||||
# Storage backends
|
||||
"CredentialStorage",
|
||||
"EncryptedFileStorage",
|
||||
"EnvVarStorage",
|
||||
"InMemoryStorage",
|
||||
"CompositeStorage",
|
||||
# Template resolution
|
||||
"TemplateResolver",
|
||||
# Exceptions
|
||||
"CredentialError",
|
||||
"CredentialNotFoundError",
|
||||
"CredentialKeyNotFoundError",
|
||||
"CredentialRefreshError",
|
||||
"CredentialValidationError",
|
||||
"CredentialDecryptionError",
|
||||
# Aden sync (optional - requires httpx)
|
||||
"AdenSyncProvider",
|
||||
"AdenCredentialClient",
|
||||
"AdenClientConfig",
|
||||
"AdenCachedStorage",
|
||||
]
|
||||
|
||||
# Track Aden availability for runtime checks
|
||||
ADEN_AVAILABLE = _ADEN_AVAILABLE
|
||||
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
Aden Credential Sync.
|
||||
|
||||
Components for synchronizing credentials with the Aden authentication server.
|
||||
|
||||
The Aden server handles OAuth2 authorization flows and maintains refresh tokens.
|
||||
These components fetch and cache access tokens locally while delegating
|
||||
lifecycle management to Aden.
|
||||
|
||||
Components:
|
||||
- AdenCredentialClient: HTTP client for Aden API
|
||||
- AdenSyncProvider: CredentialProvider that syncs with Aden
|
||||
- AdenCachedStorage: Storage with local cache + Aden fallback
|
||||
|
||||
Quick Start:
|
||||
from core.framework.credentials import CredentialStore
|
||||
from core.framework.credentials.storage import EncryptedFileStorage
|
||||
from core.framework.credentials.aden import (
|
||||
AdenCredentialClient,
|
||||
AdenClientConfig,
|
||||
AdenSyncProvider,
|
||||
)
|
||||
|
||||
# Configure (API key loaded from ADEN_API_KEY env var)
|
||||
client = AdenCredentialClient(AdenClientConfig(
|
||||
base_url=os.environ["ADEN_API_URL"],
|
||||
))
|
||||
|
||||
provider = AdenSyncProvider(client=client)
|
||||
|
||||
store = CredentialStore(
|
||||
storage=EncryptedFileStorage(),
|
||||
providers=[provider],
|
||||
auto_refresh=True,
|
||||
)
|
||||
|
||||
# Initial sync
|
||||
provider.sync_all(store)
|
||||
|
||||
# Use normally
|
||||
token = store.get_key("hubspot", "access_token")
|
||||
|
||||
See docs/aden-credential-sync.md for detailed documentation.
|
||||
"""
|
||||
|
||||
from .client import (
|
||||
AdenAuthenticationError,
|
||||
AdenClientConfig,
|
||||
AdenClientError,
|
||||
AdenCredentialClient,
|
||||
AdenCredentialResponse,
|
||||
AdenIntegrationInfo,
|
||||
AdenNotFoundError,
|
||||
AdenRateLimitError,
|
||||
AdenRefreshError,
|
||||
)
|
||||
from .provider import AdenSyncProvider
|
||||
from .storage import AdenCachedStorage
|
||||
|
||||
__all__ = [
|
||||
# Client
|
||||
"AdenCredentialClient",
|
||||
"AdenClientConfig",
|
||||
"AdenCredentialResponse",
|
||||
"AdenIntegrationInfo",
|
||||
# Client errors
|
||||
"AdenClientError",
|
||||
"AdenAuthenticationError",
|
||||
"AdenNotFoundError",
|
||||
"AdenRateLimitError",
|
||||
"AdenRefreshError",
|
||||
# Provider
|
||||
"AdenSyncProvider",
|
||||
# Storage
|
||||
"AdenCachedStorage",
|
||||
]
|
||||
@@ -0,0 +1,466 @@
|
||||
"""
|
||||
Aden Credential Client.
|
||||
|
||||
HTTP client for communicating with the Aden authentication server.
|
||||
The Aden server handles OAuth2 authorization flows and token management.
|
||||
This client fetches tokens and delegates refresh operations to Aden.
|
||||
|
||||
Usage:
|
||||
# API key loaded from ADEN_API_KEY environment variable by default
|
||||
client = AdenCredentialClient(AdenClientConfig(
|
||||
base_url="https://api.adenhq.com",
|
||||
))
|
||||
|
||||
# Or explicitly provide the API key
|
||||
client = AdenCredentialClient(AdenClientConfig(
|
||||
base_url="https://api.adenhq.com",
|
||||
api_key="your-api-key",
|
||||
))
|
||||
|
||||
# Fetch a credential
|
||||
response = client.get_credential("hubspot")
|
||||
if response:
|
||||
print(f"Token expires at: {response.expires_at}")
|
||||
|
||||
# Request a refresh
|
||||
refreshed = client.request_refresh("hubspot")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdenClientError(Exception):
    """Base exception for all Aden client failures."""


class AdenAuthenticationError(AdenClientError):
    """The agent API key was rejected (invalid or revoked)."""


class AdenNotFoundError(AdenClientError):
    """The requested integration does not exist on the server."""


class AdenRefreshError(AdenClientError):
    """A token refresh was attempted and failed.

    Carries optional re-authorization details returned by the server so the
    caller can direct the user back through the OAuth flow.
    """

    def __init__(
        self,
        message: str,
        requires_reauthorization: bool = False,
        reauthorization_url: str | None = None,
    ):
        super().__init__(message)
        # True when the server says the grant is dead and the user must
        # re-authorize interactively.
        self.requires_reauthorization = requires_reauthorization
        # Where to send the user for re-authorization, when provided.
        self.reauthorization_url = reauthorization_url


class AdenRateLimitError(AdenClientError):
    """The server rate-limited this client; retry after a delay."""

    def __init__(self, message: str, retry_after: int = 60):
        super().__init__(message)
        # Seconds to wait before retrying (from the Retry-After header).
        self.retry_after = retry_after
|
||||
|
||||
|
||||
@dataclass
class AdenClientConfig:
    """Configuration for the Aden API client."""

    # Base URL of the Aden server (e.g. 'https://api.adenhq.com').
    base_url: str

    # Agent API key for authenticating with Aden; when omitted it is loaded
    # from the ADEN_API_KEY environment variable in __post_init__.
    api_key: str | None = None

    # Optional tenant ID for multi-tenant deployments.
    tenant_id: str | None = None

    # Request timeout in seconds.
    timeout: float = 30.0

    # Number of retry attempts for transient failures.
    retry_attempts: int = 3

    # Base delay between retries in seconds (exponential backoff).
    retry_delay: float = 1.0

    def __post_init__(self) -> None:
        """Fill in the API key from the environment; reject a missing/empty key."""
        if self.api_key is None:
            self.api_key = os.environ.get("ADEN_API_KEY")
        # Note: an explicitly-passed empty string is also rejected here.
        if not self.api_key:
            raise ValueError(
                "Aden API key not provided. Either pass api_key to AdenClientConfig "
                "or set the ADEN_API_KEY environment variable."
            )
|
||||
|
||||
|
||||
@dataclass
class AdenCredentialResponse:
    """Credential payload returned by the Aden server."""

    # Unique identifier for the integration (e.g. 'hubspot').
    integration_id: str

    # Type of integration (e.g. 'hubspot', 'github', 'slack').
    integration_type: str

    # The access token to use for API calls.
    access_token: str

    # Token type (usually 'Bearer').
    token_type: str = "Bearer"

    # When the access token expires (UTC), if the server reported it.
    expires_at: datetime | None = None

    # OAuth2 scopes granted to this token.
    scopes: list[str] = field(default_factory=list)

    # Additional integration-specific metadata.
    metadata: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(
        cls, data: dict[str, Any], integration_id: str | None = None
    ) -> AdenCredentialResponse:
        """Build a response from an API payload dictionary.

        Falls back to the payload's 'alias' (then 'provider') when no
        explicit integration_id is given.
        """
        # 'Z' suffix is normalized to an explicit UTC offset for fromisoformat.
        parsed_expiry = None
        if raw_expiry := data.get("expires_at"):
            parsed_expiry = datetime.fromisoformat(raw_expiry.replace("Z", "+00:00"))

        email = data.get("email")
        return cls(
            integration_id=integration_id or data.get("alias", data.get("provider", "")),
            integration_type=data.get("provider", ""),
            access_token=data["access_token"],
            token_type=data.get("token_type", "Bearer"),
            expires_at=parsed_expiry,
            scopes=data.get("scopes", []),
            metadata={"email": email} if email else {},
        )
|
||||
|
||||
|
||||
@dataclass
class AdenIntegrationInfo:
    """Summary of an integration available on the Aden server."""

    integration_id: str
    integration_type: str
    # One of "active", "requires_reauth", "expired" (server-reported).
    status: str
    expires_at: datetime | None = None

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> AdenIntegrationInfo:
        """Build an info record from an API payload dictionary."""
        # 'Z' suffix is normalized to an explicit UTC offset for fromisoformat.
        parsed_expiry = None
        if raw_expiry := data.get("expires_at"):
            parsed_expiry = datetime.fromisoformat(raw_expiry.replace("Z", "+00:00"))

        return cls(
            integration_id=data["integration_id"],
            # Fall back to the id itself when the provider field is absent.
            integration_type=data.get("provider", data["integration_id"]),
            status=data.get("status", "unknown"),
            expires_at=parsed_expiry,
        )
|
||||
|
||||
|
||||
class AdenCredentialClient:
    """
    HTTP client for the Aden credential server.

    Handles communication with the Aden authentication server, including
    fetching credentials, requesting refreshes, and reporting usage
    statistics.

    The client automatically handles:
    - Retries with exponential backoff for transient connection failures
    - Error classification (auth, not found, rate limit, refresh failure)
    - Request headers for authentication and tenant isolation

    Usage:
        # API key loaded from ADEN_API_KEY environment variable
        config = AdenClientConfig(
            base_url="https://api.adenhq.com",
        )

        client = AdenCredentialClient(config)

        # Fetch a credential
        cred = client.get_credential("hubspot")
        if cred:
            headers = {"Authorization": f"Bearer {cred.access_token}"}

        # List all integrations
        integrations = client.list_integrations()
        for info in integrations:
            print(f"{info.integration_id}: {info.status}")

        # Clean up
        client.close()
    """

    def __init__(self, config: AdenClientConfig):
        """
        Initialize the Aden client.

        Args:
            config: Client configuration including base URL and API key.
        """
        self.config = config
        # Created lazily on first request so construction never performs I/O.
        self._client: httpx.Client | None = None

    def _get_client(self) -> httpx.Client:
        """Get or lazily create the underlying HTTP client."""
        if self._client is None:
            headers = {
                "Authorization": f"Bearer {self.config.api_key}",
                "Content-Type": "application/json",
                "User-Agent": "hive-credential-store/1.0",
            }

            if self.config.tenant_id:
                # Tenant isolation header for multi-tenant deployments.
                headers["X-Tenant-ID"] = self.config.tenant_id

            self._client = httpx.Client(
                base_url=self.config.base_url,
                timeout=self.config.timeout,
                headers=headers,
            )

        return self._client

    def _request_with_retry(
        self,
        method: str,
        path: str,
        **kwargs: Any,
    ) -> httpx.Response:
        """
        Make a request with retry logic.

        Retries connect/timeout failures with exponential backoff.
        Classified API errors (auth, not-found, rate-limit, refresh) are
        raised immediately and never retried.

        Args:
            method: HTTP method (e.g., 'GET', 'POST').
            path: Request path relative to the configured base URL.
            **kwargs: Extra arguments forwarded to httpx.Client.request.

        Returns:
            The successful httpx response.

        Raises:
            AdenAuthenticationError: On HTTP 401.
            AdenNotFoundError: On HTTP 404.
            AdenRateLimitError: On HTTP 429.
            AdenRefreshError: On HTTP 400 with a 'refresh_failed' payload.
            AdenClientError: If all retry attempts fail to connect.
        """
        client = self._get_client()
        last_error: Exception | None = None

        for attempt in range(self.config.retry_attempts):
            try:
                response = client.request(method, path, **kwargs)

                # Handle specific error codes
                if response.status_code == 401:
                    raise AdenAuthenticationError("Agent API key is invalid or revoked")

                if response.status_code == 404:
                    raise AdenNotFoundError(f"Integration not found: {path}")

                if response.status_code == 429:
                    retry_after = int(response.headers.get("Retry-After", 60))
                    raise AdenRateLimitError(
                        "Rate limited by Aden server",
                        retry_after=retry_after,
                    )

                if response.status_code == 400:
                    # The 400 body may not be JSON (e.g. a proxy error page);
                    # in that case fall through to raise_for_status below
                    # instead of crashing on the parse.
                    try:
                        data = response.json()
                    except ValueError:
                        data = {}
                    if data.get("error") == "refresh_failed":
                        raise AdenRefreshError(
                            data.get("message", "Token refresh failed"),
                            requires_reauthorization=data.get("requires_reauthorization", False),
                            reauthorization_url=data.get("reauthorization_url"),
                        )

                # Success or other error
                response.raise_for_status()
                return response

            except (httpx.ConnectError, httpx.TimeoutException) as e:
                last_error = e
                if attempt < self.config.retry_attempts - 1:
                    # Exponential backoff: delay, 2*delay, 4*delay, ...
                    delay = self.config.retry_delay * (2**attempt)
                    logger.warning(
                        f"Aden request failed (attempt {attempt + 1}), retrying in {delay}s: {e}"
                    )
                    time.sleep(delay)
                else:
                    raise AdenClientError(f"Failed to connect to Aden server: {e}") from e

            except (
                AdenAuthenticationError,
                AdenNotFoundError,
                AdenRefreshError,
                AdenRateLimitError,
            ):
                # Don't retry these errors - they are permanent API outcomes.
                raise

        # Should not reach here, but just in case
        raise AdenClientError(
            f"Request failed after {self.config.retry_attempts} attempts"
        ) from last_error

    def get_credential(self, integration_id: str) -> AdenCredentialResponse | None:
        """
        Fetch the current credential for an integration.

        The Aden server may refresh the token internally if it's expired
        before returning it.

        Args:
            integration_id: The integration identifier (e.g., 'hubspot').

        Returns:
            Credential response with access token, or None if not found.

        Raises:
            AdenAuthenticationError: If API key is invalid.
            AdenClientError: For connection failures.
        """
        try:
            response = self._request_with_retry("GET", f"/v1/credentials/{integration_id}")
        except AdenNotFoundError:
            # Not-found is an expected outcome here, not an error.
            return None
        data = response.json()
        return AdenCredentialResponse.from_dict(data, integration_id=integration_id)

    def request_refresh(self, integration_id: str) -> AdenCredentialResponse:
        """
        Request the Aden server to refresh the token.

        Use this when the local store detects an expired or near-expiry token.
        The Aden server handles the actual OAuth2 refresh token flow.

        Args:
            integration_id: The integration identifier.

        Returns:
            Credential response with new access token.

        Raises:
            AdenRefreshError: If refresh fails (may require re-authorization).
            AdenNotFoundError: If integration not found.
            AdenAuthenticationError: If API key is invalid.
            AdenRateLimitError: If rate limited.
        """
        response = self._request_with_retry("POST", f"/v1/credentials/{integration_id}/refresh")
        data = response.json()
        return AdenCredentialResponse.from_dict(data, integration_id=integration_id)

    def list_integrations(self) -> list[AdenIntegrationInfo]:
        """
        List all integrations available for this agent/tenant.

        Returns:
            List of integration info objects.

        Raises:
            AdenAuthenticationError: If API key is invalid.
            AdenClientError: For connection failures.
        """
        response = self._request_with_retry("GET", "/v1/credentials")
        data = response.json()
        return [AdenIntegrationInfo.from_dict(item) for item in data.get("integrations", [])]

    def validate_token(self, integration_id: str) -> dict[str, Any]:
        """
        Check if a token is still valid without fetching it.

        Args:
            integration_id: The integration identifier.

        Returns:
            Dict with 'valid' bool and optional 'expires_at', 'reason',
            'requires_reauthorization', 'reauthorization_url'.

        Raises:
            AdenNotFoundError: If integration not found.
            AdenAuthenticationError: If API key is invalid.
        """
        response = self._request_with_retry("GET", f"/v1/credentials/{integration_id}/validate")
        return response.json()

    def report_usage(
        self,
        integration_id: str,
        operation: str,
        status: str = "success",
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """
        Report credential usage statistics to Aden.

        This is optional and used for analytics/billing. Failures are
        logged and swallowed - reporting must never break the caller.

        Args:
            integration_id: The integration identifier.
            operation: Operation name (e.g., 'api_call').
            status: Operation status ('success', 'error').
            metadata: Additional operation metadata.
        """
        # Local import: the module top is known to import only `datetime`.
        from datetime import UTC

        try:
            self._request_with_retry(
                "POST",
                f"/v1/credentials/{integration_id}/usage",
                json={
                    "operation": operation,
                    "status": status,
                    # Timezone-aware UTC timestamp; datetime.utcnow() is
                    # naive and deprecated since Python 3.12. The replace()
                    # keeps the wire format's trailing 'Z'.
                    "timestamp": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
                    "metadata": metadata or {},
                },
            )
        except Exception as e:
            # Usage reporting is best-effort, don't fail on errors
            logger.warning(f"Failed to report usage for '{integration_id}': {e}")

    def health_check(self) -> dict[str, Any]:
        """
        Check Aden server health and connectivity.

        Returns:
            Dict with 'status', 'version', 'timestamp', and optionally 'error'.
            On success a 'latency_ms' measurement is added.
        """
        try:
            client = self._get_client()
            response = client.get("/health")
            if response.status_code == 200:
                data = response.json()
                data["latency_ms"] = response.elapsed.total_seconds() * 1000
                return data
            return {
                "status": "degraded",
                "error": f"Unexpected status code: {response.status_code}",
            }
        except Exception as e:
            # Health checks must never raise; report the failure instead.
            return {
                "status": "unhealthy",
                "error": str(e),
            }

    def close(self) -> None:
        """Close the HTTP client and release resources."""
        if self._client:
            self._client.close()
            self._client = None

    def __enter__(self) -> AdenCredentialClient:
        """Context manager entry."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Context manager exit; closes the HTTP client."""
        self.close()
|
||||
@@ -0,0 +1,415 @@
|
||||
"""
|
||||
Aden Sync Provider.
|
||||
|
||||
Provider that synchronizes credentials with the Aden authentication server.
|
||||
The Aden server is the authoritative source for OAuth2 tokens - this provider
|
||||
fetches and caches tokens locally while delegating refresh operations to Aden.
|
||||
|
||||
Usage:
|
||||
from core.framework.credentials import CredentialStore
|
||||
from core.framework.credentials.storage import EncryptedFileStorage
|
||||
from core.framework.credentials.aden import (
|
||||
AdenCredentialClient,
|
||||
AdenClientConfig,
|
||||
AdenSyncProvider,
|
||||
)
|
||||
|
||||
# Configure client (API key loaded from ADEN_API_KEY env var)
|
||||
client = AdenCredentialClient(AdenClientConfig(
|
||||
base_url=os.environ["ADEN_API_URL"],
|
||||
))
|
||||
|
||||
# Create provider
|
||||
provider = AdenSyncProvider(client=client)
|
||||
|
||||
# Create store
|
||||
store = CredentialStore(
|
||||
storage=EncryptedFileStorage(),
|
||||
providers=[provider],
|
||||
auto_refresh=True,
|
||||
)
|
||||
|
||||
# Initial sync from Aden
|
||||
provider.sync_all(store)
|
||||
|
||||
# Use normally - auto-refreshes via Aden when needed
|
||||
token = store.get_key("hubspot", "access_token")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from ..models import CredentialKey, CredentialObject, CredentialRefreshError, CredentialType
|
||||
from ..provider import CredentialProvider
|
||||
from .client import (
|
||||
AdenClientError,
|
||||
AdenCredentialClient,
|
||||
AdenCredentialResponse,
|
||||
AdenRefreshError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..store import CredentialStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdenSyncProvider(CredentialProvider):
    """
    Credential provider backed by the Aden authentication server.

    Aden owns the OAuth2 authorization flows and the refresh tokens. This
    provider fetches access tokens from Aden, delegates every refresh to
    it, caches results in the local credential store, and can optionally
    report usage statistics back.

    Because client secrets and refresh tokens stay on the Aden server,
    this arrangement gives centralized audit logging and clean
    multi-tenant support.

    Usage:
        client = AdenCredentialClient(AdenClientConfig(
            base_url="https://api.adenhq.com",
            api_key=os.environ["ADEN_API_KEY"],
        ))

        provider = AdenSyncProvider(client=client)

        store = CredentialStore(
            storage=EncryptedFileStorage(),
            providers=[provider],
            auto_refresh=True,
        )
    """

    def __init__(
        self,
        client: AdenCredentialClient,
        provider_id: str = "aden_sync",
        refresh_buffer_minutes: int = 5,
        report_usage: bool = False,
    ):
        """
        Initialize the Aden sync provider.

        Args:
            client: Configured Aden API client.
            provider_id: Unique identifier for this provider instance;
                useful for multi-tenant setups (e.g., 'aden_tenant_123').
            refresh_buffer_minutes: Minutes before expiry at which a
                refresh is triggered (default 5).
            report_usage: Whether to report usage statistics to Aden.
        """
        self._client = client
        self._provider_id = provider_id
        self._refresh_buffer = timedelta(minutes=refresh_buffer_minutes)
        self._report_usage = report_usage

    @property
    def provider_id(self) -> str:
        """Unique identifier for this provider."""
        return self._provider_id

    @property
    def supported_types(self) -> list[CredentialType]:
        """Credential types this provider can manage."""
        return [CredentialType.OAUTH2, CredentialType.BEARER_TOKEN]

    def can_handle(self, credential: CredentialObject) -> bool:
        """
        Return True when this provider manages the given credential.

        A credential is handled when its type is supported (OAUTH2 or
        BEARER_TOKEN) and it is either linked to this provider via
        provider_id or carries the '_aden_managed' metadata flag.
        """
        if credential.credential_type not in self.supported_types:
            return False

        # Explicit link to this provider wins.
        if credential.provider_id == self.provider_id:
            return True

        # Otherwise look for the Aden-managed marker key.
        flag = credential.keys.get("_aden_managed")
        return bool(flag and flag.value.get_secret_value() == "true")

    def refresh(self, credential: CredentialObject) -> CredentialObject:
        """
        Refresh a credential by requesting a new token from Aden.

        The Aden server performs the actual OAuth2 refresh-token flow;
        this method only consumes the result. When Aden is unreachable
        but the cached token is still valid, the cached credential is
        returned instead.

        Args:
            credential: The credential to refresh.

        Returns:
            Updated credential with new access token.

        Raises:
            CredentialRefreshError: If refresh fails.
        """
        try:
            refreshed = self._client.request_refresh(credential.id)
            credential = self._update_credential_from_aden(credential, refreshed)
            logger.info(f"Refreshed credential '{credential.id}' via Aden server")

            if self._report_usage:
                self._client.report_usage(
                    integration_id=credential.id,
                    operation="token_refresh",
                    status="success",
                )

            return credential

        except AdenRefreshError as e:
            logger.error(f"Aden refresh failed for '{credential.id}': {e}")

            if e.requires_reauthorization:
                # The user must redo the OAuth consent flow.
                raise CredentialRefreshError(
                    f"Integration '{credential.id}' requires re-authorization. "
                    f"Visit: {e.reauthorization_url or 'your Aden dashboard'}"
                ) from e

            raise CredentialRefreshError(
                f"Failed to refresh credential '{credential.id}': {e}"
            ) from e

        except AdenClientError as e:
            logger.error(f"Aden client error for '{credential.id}': {e}")

            # Server unreachable: fall back to the cached token while it
            # is still within its lifetime.
            token_key = credential.keys.get("access_token")
            cached_ok = (
                token_key is not None
                and token_key.expires_at is not None
                and datetime.now(UTC) < token_key.expires_at
            )
            if cached_ok:
                logger.warning(f"Aden unavailable, using cached token for '{credential.id}'")
                return credential

            raise CredentialRefreshError(
                f"Aden server unavailable and token expired for '{credential.id}'"
            ) from e

    def validate(self, credential: CredentialObject) -> bool:
        """
        Validate a credential via Aden server introspection.

        Falls back to a local expiry check when Aden is unreachable.

        Args:
            credential: The credential to validate.

        Returns:
            True if the credential is valid.
        """
        try:
            verdict = self._client.validate_token(credential.id)
        except AdenClientError:
            # Local fallback: trust the recorded expiry, if any.
            token_key = credential.keys.get("access_token")
            if token_key is None:
                return False
            if token_key.expires_at is None:
                # No expiration recorded - assume valid.
                return True
            return datetime.now(UTC) < token_key.expires_at
        return verdict.get("valid", False)

    def should_refresh(self, credential: CredentialObject) -> bool:
        """
        Decide whether a credential should be refreshed now.

        Args:
            credential: The credential to check.

        Returns:
            True when the access_token is expired or inside the refresh
            buffer; False when there is no token or no expiry recorded.
        """
        token_key = credential.keys.get("access_token")
        if token_key is None or token_key.expires_at is None:
            return False

        # Refresh once we are within the buffer of the expiry time.
        return datetime.now(UTC) >= token_key.expires_at - self._refresh_buffer

    def fetch_from_aden(self, integration_id: str) -> CredentialObject | None:
        """
        Fetch a credential directly from the Aden server.

        Intended for initial population or when the local cache misses.

        Args:
            integration_id: The integration identifier (e.g., 'hubspot').

        Returns:
            CredentialObject if found, None otherwise.

        Raises:
            AdenClientError: For connection failures.
        """
        response = self._client.get_credential(integration_id)
        if response is None:
            return None
        return self._aden_response_to_credential(response)

    def sync_all(self, store: CredentialStore) -> int:
        """
        Sync every active integration from Aden into the local store.

        Per-integration failures are logged and skipped; a failure to
        even list integrations yields a count of zero.

        Args:
            store: The credential store to populate.

        Returns:
            Number of credentials synced.
        """
        synced = 0

        try:
            integrations = self._client.list_integrations()
        except AdenClientError as e:
            logger.error(f"Failed to list integrations from Aden: {e}")
            return synced

        for info in integrations:
            if info.status != "active":
                logger.warning(
                    f"Skipping integration '{info.integration_id}': status={info.status}"
                )
                continue

            try:
                cred = self.fetch_from_aden(info.integration_id)
                if cred:
                    store.save_credential(cred)
                    synced += 1
                    logger.info(f"Synced credential '{info.integration_id}' from Aden")
            except Exception as e:
                # Best-effort: one bad integration must not abort the sync.
                logger.warning(f"Failed to sync '{info.integration_id}': {e}")

        return synced

    def report_credential_usage(
        self,
        credential: CredentialObject,
        operation: str,
        status: str = "success",
        metadata: dict | None = None,
    ) -> None:
        """
        Report credential usage to the Aden server.

        No-op unless the provider was constructed with report_usage=True.

        Args:
            credential: The credential that was used.
            operation: Operation name (e.g., 'api_call').
            status: Operation status ('success', 'error').
            metadata: Additional metadata.
        """
        if not self._report_usage:
            return
        self._client.report_usage(
            integration_id=credential.id,
            operation=operation,
            status=status,
            metadata=metadata or {},
        )

    def _update_credential_from_aden(
        self,
        credential: CredentialObject,
        aden_response: AdenCredentialResponse,
    ) -> CredentialObject:
        """Apply an Aden response onto an existing credential in place."""
        keys = credential.keys

        # New access token (with its expiry, if any).
        keys["access_token"] = CredentialKey(
            name="access_token",
            value=SecretStr(aden_response.access_token),
            expires_at=aden_response.expires_at,
        )

        # Granted scopes, space-separated per OAuth2 convention.
        if aden_response.scopes:
            keys["scope"] = CredentialKey(
                name="scope",
                value=SecretStr(" ".join(aden_response.scopes)),
            )

        # Marker so can_handle() recognizes this credential later.
        keys["_aden_managed"] = CredentialKey(
            name="_aden_managed",
            value=SecretStr("true"),
        )

        # Remember which kind of integration produced the token.
        keys["_integration_type"] = CredentialKey(
            name="_integration_type",
            value=SecretStr(aden_response.integration_type),
        )

        credential.last_refreshed = datetime.now(UTC)
        credential.provider_id = self.provider_id

        return credential

    def _aden_response_to_credential(
        self,
        aden_response: AdenCredentialResponse,
    ) -> CredentialObject:
        """Build a brand-new CredentialObject from an Aden response."""
        keys: dict[str, CredentialKey] = {}
        keys["access_token"] = CredentialKey(
            name="access_token",
            value=SecretStr(aden_response.access_token),
            expires_at=aden_response.expires_at,
        )
        keys["_aden_managed"] = CredentialKey(
            name="_aden_managed",
            value=SecretStr("true"),
        )
        keys["_integration_type"] = CredentialKey(
            name="_integration_type",
            value=SecretStr(aden_response.integration_type),
        )

        if aden_response.scopes:
            keys["scope"] = CredentialKey(
                name="scope",
                value=SecretStr(" ".join(aden_response.scopes)),
            )

        return CredentialObject(
            id=aden_response.integration_id,
            credential_type=CredentialType.OAUTH2,
            keys=keys,
            provider_id=self.provider_id,
            auto_refresh=True,
        )
|
||||
@@ -0,0 +1,307 @@
|
||||
"""
|
||||
Aden Cached Storage.
|
||||
|
||||
Storage backend that combines local cache with Aden server fallback.
|
||||
Provides offline resilience by caching credentials locally while
|
||||
keeping them synchronized with the Aden server.
|
||||
|
||||
Usage:
|
||||
from core.framework.credentials import CredentialStore
|
||||
from core.framework.credentials.storage import EncryptedFileStorage
|
||||
from core.framework.credentials.aden import (
|
||||
AdenCredentialClient,
|
||||
AdenClientConfig,
|
||||
AdenSyncProvider,
|
||||
AdenCachedStorage,
|
||||
)
|
||||
|
||||
# Configure
|
||||
client = AdenCredentialClient(AdenClientConfig(
|
||||
base_url=os.environ["ADEN_API_URL"],
|
||||
api_key=os.environ["ADEN_API_KEY"],
|
||||
))
|
||||
provider = AdenSyncProvider(client=client)
|
||||
|
||||
# Create cached storage
|
||||
storage = AdenCachedStorage(
|
||||
local_storage=EncryptedFileStorage(),
|
||||
aden_provider=provider,
|
||||
cache_ttl_seconds=300, # Re-check Aden every 5 minutes
|
||||
)
|
||||
|
||||
# Create store
|
||||
store = CredentialStore(
|
||||
storage=storage,
|
||||
providers=[provider],
|
||||
auto_refresh=True,
|
||||
)
|
||||
|
||||
# Credentials automatically fetched from Aden on first access
|
||||
# Cached locally for 5 minutes
|
||||
# Falls back to cache if Aden is unreachable
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..storage import CredentialStorage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import CredentialObject
|
||||
from .provider import AdenSyncProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdenCachedStorage(CredentialStorage):
    """
    Storage with local cache and Aden server fallback.

    This storage provides:
    - **Reads**: Try local cache first, fallback to Aden if stale/missing
    - **Writes**: Always write to local cache
    - **Offline resilience**: Uses cached credentials when Aden is unreachable

    The cache TTL determines how long to trust local credentials before
    checking with the Aden server for updates. This balances:
    - Performance (fewer network calls)
    - Freshness (tokens stay current)
    - Resilience (works during brief outages)

    Usage:
        storage = AdenCachedStorage(
            local_storage=EncryptedFileStorage(),
            aden_provider=provider,
            cache_ttl_seconds=300,  # 5 minutes
        )

        store = CredentialStore(
            storage=storage,
            providers=[provider],
        )

        # First access fetches from Aden
        # Subsequent accesses use cache until TTL expires
        token = store.get_key("hubspot", "access_token")
    """

    def __init__(
        self,
        local_storage: CredentialStorage,
        aden_provider: AdenSyncProvider,
        cache_ttl_seconds: int = 300,
        prefer_local: bool = True,
    ):
        """
        Initialize Aden-cached storage.

        Args:
            local_storage: Local storage backend for caching (e.g., EncryptedFileStorage).
            aden_provider: Provider for fetching from Aden server.
            cache_ttl_seconds: How long to trust local cache before checking Aden.
                Default is 300 seconds (5 minutes).
            prefer_local: If True, use local cache when available and fresh.
                If False, always check Aden first.
        """
        self._local = local_storage
        self._aden_provider = aden_provider
        self._cache_ttl = timedelta(seconds=cache_ttl_seconds)
        self._prefer_local = prefer_local
        # In-memory record of when each credential was last written locally.
        # Not persisted: after a restart every entry is treated as stale.
        self._cache_timestamps: dict[str, datetime] = {}

    def save(self, credential: CredentialObject) -> None:
        """
        Save credential to local cache and stamp its freshness time.

        Args:
            credential: The credential to save.
        """
        self._local.save(credential)
        self._cache_timestamps[credential.id] = datetime.now(UTC)
        logger.debug(f"Cached credential '{credential.id}'")

    def load(self, credential_id: str) -> CredentialObject | None:
        """
        Load credential from cache, with Aden fallback.

        If prefer_local=True (default):
            1. Return the local copy when it exists and is within TTL
            2. Otherwise fetch from Aden and refresh the local cache
            3. If Aden fails (or doesn't know it), fall back to the stale copy

        If prefer_local=False:
            1. Always try Aden first, caching the response
            2. Fall back to the local copy only on failure

        Args:
            credential_id: The credential identifier.

        Returns:
            CredentialObject if found, None otherwise.
        """
        local_cred = self._local.load(credential_id)

        # If we prefer local and have a fresh cache, use it.
        if self._prefer_local and local_cred and self._is_cache_fresh(credential_id):
            logger.debug(f"Using cached credential '{credential_id}'")
            return local_cred

        # Try to fetch from Aden.
        try:
            aden_cred = self._aden_provider.fetch_from_aden(credential_id)
            if aden_cred:
                # Update local cache (save() also refreshes the timestamp).
                self.save(aden_cred)
                logger.debug(f"Fetched credential '{credential_id}' from Aden")
                return aden_cred
        except Exception as e:
            logger.warning(f"Failed to fetch '{credential_id}' from Aden: {e}")

        # Fall back to the (possibly stale) local copy if Aden failed or
        # did not know the credential.
        if local_cred:
            logger.info(f"Using stale cached credential '{credential_id}'")
            return local_cred

        # Fix: the previous code returned `local_cred` here with a comment
        # claiming it "may be None" - at this point it is *always* None.
        # Make the not-found result explicit.
        return None

    def delete(self, credential_id: str) -> bool:
        """
        Delete credential from local cache.

        Note: This does NOT delete the credential from the Aden server.
        It only removes the local cache entry.

        Args:
            credential_id: The credential identifier.

        Returns:
            True if credential existed and was deleted.
        """
        self._cache_timestamps.pop(credential_id, None)
        return self._local.delete(credential_id)

    def list_all(self) -> list[str]:
        """
        List credentials from local cache.

        Returns:
            List of credential IDs in local cache.
        """
        return self._local.list_all()

    def exists(self, credential_id: str) -> bool:
        """
        Check if credential exists in local cache.

        Args:
            credential_id: The credential identifier.

        Returns:
            True if credential exists locally.
        """
        return self._local.exists(credential_id)

    def _is_cache_fresh(self, credential_id: str) -> bool:
        """
        Check if local cache is still fresh (within TTL).

        Args:
            credential_id: The credential identifier.

        Returns:
            True if cache is fresh, False if stale or not cached.
        """
        cached_at = self._cache_timestamps.get(credential_id)
        if cached_at is None:
            return False
        return datetime.now(UTC) - cached_at < self._cache_ttl

    def invalidate_cache(self, credential_id: str) -> None:
        """
        Invalidate cache for a specific credential.

        The next load() call will fetch from Aden regardless of TTL.

        Args:
            credential_id: The credential identifier.
        """
        self._cache_timestamps.pop(credential_id, None)
        logger.debug(f"Invalidated cache for '{credential_id}'")

    def invalidate_all(self) -> None:
        """Invalidate all cache entries."""
        self._cache_timestamps.clear()
        logger.debug("Invalidated all cache entries")

    def sync_all_from_aden(self) -> int:
        """
        Sync all credentials from Aden server to local cache.

        Fetches the list of available integrations from Aden and
        updates the local cache with current tokens. Per-integration
        failures are logged and skipped.

        Returns:
            Number of credentials synced.
        """
        synced = 0

        try:
            # NOTE(review): reaches into the provider's private `_client`;
            # consider exposing a public listing on AdenSyncProvider.
            integrations = self._aden_provider._client.list_integrations()

            for info in integrations:
                if info.status != "active":
                    logger.warning(
                        f"Skipping integration '{info.integration_id}': status={info.status}"
                    )
                    continue

                try:
                    cred = self._aden_provider.fetch_from_aden(info.integration_id)
                    if cred:
                        self.save(cred)
                        synced += 1
                        logger.info(f"Synced credential '{info.integration_id}' from Aden")
                except Exception as e:
                    logger.warning(f"Failed to sync '{info.integration_id}': {e}")

        except Exception as e:
            # Best-effort: listing failure yields a zero count, not a crash.
            logger.error(f"Failed to list integrations from Aden: {e}")

        return synced

    def get_cache_info(self) -> dict[str, dict]:
        """
        Get cache status information for all credentials.

        Returns:
            Dict mapping credential_id to cache info (cached_at, is_fresh,
            ttl_remaining_seconds).
        """
        now = datetime.now(UTC)
        info: dict[str, dict] = {}

        for cred_id in self.list_all():
            cached_at = self._cache_timestamps.get(cred_id)
            if cached_at:
                ttl_remaining = (cached_at + self._cache_ttl - now).total_seconds()
                info[cred_id] = {
                    "cached_at": cached_at.isoformat(),
                    "is_fresh": ttl_remaining > 0,
                    "ttl_remaining_seconds": max(0, ttl_remaining),
                }
            else:
                # Present locally but never stamped (e.g. pre-existing file).
                info[cred_id] = {
                    "cached_at": None,
                    "is_fresh": False,
                    "ttl_remaining_seconds": 0,
                }

        return info
|
||||
@@ -0,0 +1 @@
|
||||
"""Tests for Aden credential sync components."""
|
||||
@@ -0,0 +1,670 @@
|
||||
"""
|
||||
Tests for Aden credential sync components.
|
||||
|
||||
Tests cover:
|
||||
- AdenCredentialClient: HTTP client for Aden API
|
||||
- AdenSyncProvider: Provider that syncs with Aden
|
||||
- AdenCachedStorage: Storage with local cache + Aden fallback
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from framework.credentials import (
|
||||
CredentialKey,
|
||||
CredentialObject,
|
||||
CredentialStore,
|
||||
CredentialType,
|
||||
InMemoryStorage,
|
||||
)
|
||||
from framework.credentials.aden import (
|
||||
AdenCachedStorage,
|
||||
AdenClientConfig,
|
||||
AdenClientError,
|
||||
AdenCredentialClient,
|
||||
AdenCredentialResponse,
|
||||
AdenIntegrationInfo,
|
||||
AdenRefreshError,
|
||||
AdenSyncProvider,
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
def aden_config():
    """Aden client config pointed at a fake server, with fast retries."""
    cfg = AdenClientConfig(
        base_url="https://api.test-aden.com",
        api_key="test-api-key",
        tenant_id="test-tenant",
        timeout=5.0,
        retry_attempts=2,
        retry_delay=0.1,
    )
    return cfg
|
||||
|
||||
|
||||
@pytest.fixture
def mock_client(aden_config):
    """Mock AdenCredentialClient that carries the test config."""
    mocked = Mock(spec=AdenCredentialClient)
    mocked.config = aden_config
    return mocked
|
||||
|
||||
|
||||
@pytest.fixture
def aden_response():
    """Sample credential payload as Aden would return it for HubSpot."""
    expiry = datetime.now(UTC) + timedelta(hours=1)
    return AdenCredentialResponse(
        integration_id="hubspot",
        integration_type="hubspot",
        access_token="test-access-token",
        token_type="Bearer",
        expires_at=expiry,
        scopes=["crm.objects.contacts.read", "crm.objects.contacts.write"],
        metadata={"portal_id": "12345"},
    )
|
||||
|
||||
|
||||
@pytest.fixture
def provider(mock_client):
    """AdenSyncProvider wired to the mocked client."""
    options = dict(
        client=mock_client,
        provider_id="test_aden",
        refresh_buffer_minutes=5,
        report_usage=False,
    )
    return AdenSyncProvider(**options)
|
||||
|
||||
|
||||
@pytest.fixture
def local_storage():
    """Fresh in-memory storage backend for each test."""
    storage = InMemoryStorage()
    return storage
|
||||
|
||||
|
||||
@pytest.fixture
def cached_storage(local_storage, provider):
    """AdenCachedStorage layered over the in-memory backend."""
    # Short TTL (60s) so tests can simulate staleness with small deltas.
    return AdenCachedStorage(
        local_storage=local_storage,
        aden_provider=provider,
        cache_ttl_seconds=60,
        prefer_local=True,
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# AdenCredentialResponse Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAdenCredentialResponse:
    """Tests for AdenCredentialResponse dataclass."""

    def test_from_dict_basic(self):
        """Test creating response from dict."""
        # Minimal payload: only identification + token fields supplied.
        data = {
            "integration_id": "github",
            "integration_type": "github",
            "access_token": "ghp_xxxxx",
        }

        response = AdenCredentialResponse.from_dict(data)

        assert response.integration_id == "github"
        assert response.integration_type == "github"
        assert response.access_token == "ghp_xxxxx"
        # Omitted fields fall back to their defaults.
        assert response.token_type == "Bearer"
        assert response.expires_at is None
        assert response.scopes == []

    def test_from_dict_full(self):
        """Test creating response with all fields."""
        data = {
            "integration_id": "hubspot",
            "integration_type": "hubspot",
            "access_token": "token123",
            "token_type": "Bearer",
            # ISO-8601 string; from_dict is expected to parse it to a datetime.
            "expires_at": "2026-01-28T15:30:00Z",
            "scopes": ["read", "write"],
            "metadata": {"key": "value"},
        }

        response = AdenCredentialResponse.from_dict(data)

        assert response.integration_id == "hubspot"
        assert response.access_token == "token123"
        assert response.expires_at is not None
        assert response.scopes == ["read", "write"]
        assert response.metadata == {"key": "value"}
|
||||
|
||||
|
||||
class TestAdenIntegrationInfo:
    """Tests for AdenIntegrationInfo dataclass."""

    def test_from_dict(self):
        """Test creating integration info from dict."""
        data = {
            "integration_id": "slack",
            "integration_type": "slack",
            "status": "active",
            # ISO-8601 string; from_dict is expected to parse it to a datetime.
            "expires_at": "2026-02-01T00:00:00Z",
        }

        info = AdenIntegrationInfo.from_dict(data)

        assert info.integration_id == "slack"
        assert info.integration_type == "slack"
        assert info.status == "active"
        assert info.expires_at is not None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# AdenSyncProvider Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAdenSyncProvider:
    """Tests for AdenSyncProvider."""

    def test_provider_id(self, provider):
        """Test provider ID."""
        assert provider.provider_id == "test_aden"

    def test_supported_types(self, provider):
        """Test supported credential types."""
        assert CredentialType.OAUTH2 in provider.supported_types
        assert CredentialType.BEARER_TOKEN in provider.supported_types

    def test_can_handle_oauth2(self, provider):
        """Test can_handle returns True for OAUTH2 credentials with matching provider_id."""
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={},
            provider_id="test_aden",
        )

        assert provider.can_handle(cred) is True

    def test_can_handle_aden_managed(self, provider):
        """Test can_handle returns True for Aden-managed credentials."""
        # The "_aden_managed" marker key flags a credential as owned by Aden
        # even when its provider_id does not match.
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={
                "_aden_managed": CredentialKey(
                    name="_aden_managed",
                    value=SecretStr("true"),
                )
            },
        )

        assert provider.can_handle(cred) is True

    def test_can_handle_wrong_type(self, provider):
        """Test can_handle returns False for unsupported types."""
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.API_KEY,
            keys={},
        )

        assert provider.can_handle(cred) is False

    def test_refresh_success(self, provider, mock_client, aden_response):
        """Test successful credential refresh."""
        mock_client.request_refresh.return_value = aden_response

        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("old-token"),
                )
            },
            provider_id="test_aden",
        )

        refreshed = provider.refresh(cred)

        # Token replaced and credential marked as Aden-managed.
        assert refreshed.keys["access_token"].value.get_secret_value() == "test-access-token"
        assert refreshed.keys["_aden_managed"].value.get_secret_value() == "true"
        assert refreshed.last_refreshed is not None
        mock_client.request_refresh.assert_called_once_with("hubspot")

    def test_refresh_requires_reauth(self, provider, mock_client):
        """Test refresh that requires re-authorization."""
        mock_client.request_refresh.side_effect = AdenRefreshError(
            "Token revoked",
            requires_reauthorization=True,
            reauthorization_url="https://aden.com/reauth",
        )

        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={},
        )

        from framework.credentials import CredentialRefreshError

        # The Aden-specific error is surfaced as the framework's refresh error.
        with pytest.raises(CredentialRefreshError) as exc_info:
            provider.refresh(cred)

        assert "re-authorization" in str(exc_info.value).lower()

    def test_refresh_aden_unavailable_cached_valid(self, provider, mock_client):
        """Test refresh falls back to cache when Aden is unavailable and token is valid."""
        mock_client.request_refresh.side_effect = AdenClientError("Connection failed")

        # Token expires in 1 hour - still valid
        future = datetime.now(UTC) + timedelta(hours=1)
        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("cached-token"),
                    expires_at=future,
                )
            },
        )

        # Should return the cached credential instead of failing
        result = provider.refresh(cred)

        assert result.keys["access_token"].value.get_secret_value() == "cached-token"

    def test_should_refresh_expired(self, provider):
        """Test should_refresh returns True for expired token."""
        past = datetime.now(UTC) - timedelta(hours=1)
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("token"),
                    expires_at=past,
                )
            },
        )

        assert provider.should_refresh(cred) is True

    def test_should_refresh_within_buffer(self, provider):
        """Test should_refresh returns True when within buffer."""
        # Expires in 3 minutes (buffer is 5 minutes)
        soon = datetime.now(UTC) + timedelta(minutes=3)
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("token"),
                    expires_at=soon,
                )
            },
        )

        assert provider.should_refresh(cred) is True

    def test_should_refresh_still_valid(self, provider):
        """Test should_refresh returns False for valid token."""
        future = datetime.now(UTC) + timedelta(hours=1)
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("token"),
                    expires_at=future,
                )
            },
        )

        assert provider.should_refresh(cred) is False

    def test_fetch_from_aden(self, provider, mock_client, aden_response):
        """Test fetching credential from Aden."""
        mock_client.get_credential.return_value = aden_response

        cred = provider.fetch_from_aden("hubspot")

        assert cred is not None
        assert cred.id == "hubspot"
        assert cred.keys["access_token"].value.get_secret_value() == "test-access-token"
        # Credentials fetched from Aden opt into automatic refresh.
        assert cred.auto_refresh is True

    def test_fetch_from_aden_not_found(self, provider, mock_client):
        """Test fetch returns None when not found."""
        mock_client.get_credential.return_value = None

        cred = provider.fetch_from_aden("nonexistent")

        assert cred is None

    def test_sync_all(self, provider, mock_client, aden_response):
        """Test syncing all credentials."""
        mock_client.list_integrations.return_value = [
            AdenIntegrationInfo(
                integration_id="hubspot",
                integration_type="hubspot",
                status="active",
            ),
            AdenIntegrationInfo(
                integration_id="github",
                integration_type="github",
                status="requires_reauth",  # Should be skipped
            ),
        ]
        mock_client.get_credential.return_value = aden_response

        store = CredentialStore(storage=InMemoryStorage())
        synced = provider.sync_all(store)

        assert synced == 1  # Only active one was synced
        assert store.get_credential("hubspot") is not None

    def test_validate_via_aden(self, provider, mock_client):
        """Test validation via Aden introspection."""
        mock_client.validate_token.return_value = {"valid": True}

        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={},
        )

        assert provider.validate(cred) is True

    def test_validate_fallback_to_local(self, provider, mock_client):
        """Test validation falls back to local check when Aden fails."""
        mock_client.validate_token.side_effect = AdenClientError("Failed")

        # A locally unexpired token should still validate.
        future = datetime.now(UTC) + timedelta(hours=1)
        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("token"),
                    expires_at=future,
                )
            },
        )

        assert provider.validate(cred) is True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# AdenCachedStorage Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAdenCachedStorage:
    """Tests for AdenCachedStorage."""

    # NOTE: several tests manipulate the private `_cache_timestamps` dict
    # directly to simulate fresh/stale cache entries without sleeping.

    def test_save_updates_cache_timestamp(self, cached_storage):
        """Test save updates cache timestamp."""
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("token"),
                )
            },
        )

        cached_storage.save(cred)

        assert "test" in cached_storage._cache_timestamps
        assert cached_storage.exists("test")

    def test_load_from_fresh_cache(self, cached_storage, local_storage):
        """Test load returns cached credential when fresh."""
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("cached-token"),
                )
            },
        )

        # Save to both local storage and update timestamp
        local_storage.save(cred)
        cached_storage._cache_timestamps["test"] = datetime.now(UTC)

        loaded = cached_storage.load("test")

        assert loaded is not None
        assert loaded.keys["access_token"].value.get_secret_value() == "cached-token"

    def test_load_from_aden_when_stale(
        self, cached_storage, local_storage, provider, mock_client, aden_response
    ):
        """Test load fetches from Aden when cache is stale."""
        # Create stale cached credential
        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("stale-token"),
                )
            },
        )
        local_storage.save(cred)

        # Set cache timestamp to be stale (2 minutes ago, TTL is 60 seconds)
        cached_storage._cache_timestamps["hubspot"] = datetime.now(UTC) - timedelta(minutes=2)

        # Mock Aden response
        mock_client.get_credential.return_value = aden_response

        loaded = cached_storage.load("hubspot")

        # The stale local token is replaced by the fresh Aden one.
        assert loaded is not None
        assert loaded.keys["access_token"].value.get_secret_value() == "test-access-token"

    def test_load_falls_back_to_stale_when_aden_fails(
        self, cached_storage, local_storage, provider, mock_client
    ):
        """Test load falls back to stale cache when Aden fails."""
        # Create stale cached credential
        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("stale-token"),
                )
            },
        )
        local_storage.save(cred)
        cached_storage._cache_timestamps["hubspot"] = datetime.now(UTC) - timedelta(minutes=2)

        # Aden fails
        mock_client.get_credential.side_effect = AdenClientError("Connection failed")

        loaded = cached_storage.load("hubspot")

        # Availability beats freshness: stale data is better than no data.
        assert loaded is not None
        assert loaded.keys["access_token"].value.get_secret_value() == "stale-token"

    def test_delete_removes_cache_timestamp(self, cached_storage, local_storage):
        """Test delete removes cache timestamp."""
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={},
        )
        cached_storage.save(cred)

        assert "test" in cached_storage._cache_timestamps

        cached_storage.delete("test")

        assert "test" not in cached_storage._cache_timestamps
        assert not cached_storage.exists("test")

    def test_invalidate_cache(self, cached_storage, local_storage):
        """Test invalidate_cache removes timestamp."""
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={},
        )
        cached_storage.save(cred)

        cached_storage.invalidate_cache("test")

        assert "test" not in cached_storage._cache_timestamps
        # Credential still exists in local storage
        assert local_storage.exists("test")

    def test_invalidate_all(self, cached_storage):
        """Test invalidate_all clears all timestamps."""
        for i in range(3):
            cached_storage._cache_timestamps[f"test_{i}"] = datetime.now(UTC)

        cached_storage.invalidate_all()

        assert len(cached_storage._cache_timestamps) == 0

    def test_is_cache_fresh(self, cached_storage):
        """Test _is_cache_fresh logic."""
        # Fresh cache
        cached_storage._cache_timestamps["fresh"] = datetime.now(UTC)
        assert cached_storage._is_cache_fresh("fresh") is True

        # Stale cache
        cached_storage._cache_timestamps["stale"] = datetime.now(UTC) - timedelta(minutes=5)
        assert cached_storage._is_cache_fresh("stale") is False

        # No cache
        assert cached_storage._is_cache_fresh("nonexistent") is False

    def test_get_cache_info(self, cached_storage, local_storage):
        """Test get_cache_info returns status for all credentials."""
        # Add some credentials
        for name in ["fresh", "stale"]:
            cred = CredentialObject(
                id=name,
                credential_type=CredentialType.OAUTH2,
                keys={},
            )
            local_storage.save(cred)

        cached_storage._cache_timestamps["fresh"] = datetime.now(UTC)
        cached_storage._cache_timestamps["stale"] = datetime.now(UTC) - timedelta(minutes=5)

        info = cached_storage.get_cache_info()

        assert "fresh" in info
        assert info["fresh"]["is_fresh"] is True
        assert info["fresh"]["ttl_remaining_seconds"] > 0

        assert "stale" in info
        assert info["stale"]["is_fresh"] is False
        # Remaining TTL is clamped to zero rather than going negative.
        assert info["stale"]["ttl_remaining_seconds"] == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAdenIntegration:
    """Integration tests for Aden sync components."""

    def test_full_workflow(self, mock_client, aden_response):
        """Test full workflow: sync, get, refresh."""
        # Setup
        mock_client.list_integrations.return_value = [
            AdenIntegrationInfo(
                integration_id="hubspot",
                integration_type="hubspot",
                status="active",
            ),
        ]
        mock_client.get_credential.return_value = aden_response
        mock_client.request_refresh.return_value = AdenCredentialResponse(
            integration_id="hubspot",
            integration_type="hubspot",
            access_token="refreshed-token",
            expires_at=datetime.now(UTC) + timedelta(hours=2),
            scopes=[],
        )

        provider = AdenSyncProvider(client=mock_client)
        storage = InMemoryStorage()
        store = CredentialStore(
            storage=storage,
            providers=[provider],
            auto_refresh=True,
        )

        # Initial sync
        synced = provider.sync_all(store)
        assert synced == 1

        # Get credential
        cred = store.get_credential("hubspot")
        assert cred is not None
        assert cred.keys["access_token"].value.get_secret_value() == "test-access-token"

        # Simulate expiration
        cred.keys["access_token"] = CredentialKey(
            name="access_token",
            value=SecretStr("test-access-token"),
            expires_at=datetime.now(UTC) - timedelta(hours=1),  # Expired
        )
        storage.save(cred)

        # Refresh should be triggered
        refreshed = provider.refresh(cred)
        assert refreshed.keys["access_token"].value.get_secret_value() == "refreshed-token"

    def test_cached_storage_with_store(self, mock_client, aden_response):
        """Test AdenCachedStorage with CredentialStore."""
        mock_client.get_credential.return_value = aden_response

        provider = AdenSyncProvider(client=mock_client)
        local_storage = InMemoryStorage()
        cached_storage = AdenCachedStorage(
            local_storage=local_storage,
            aden_provider=provider,
            cache_ttl_seconds=300,
        )

        # First load fetches from Aden
        cred = cached_storage.load("hubspot")
        assert cred is not None
        mock_client.get_credential.assert_called_once()

        # Second load uses cache
        # (reset the mock so assert_not_called is meaningful)
        mock_client.get_credential.reset_mock()
        cred2 = cached_storage.load("hubspot")
        assert cred2 is not None
        mock_client.get_credential.assert_not_called()
|
||||
@@ -0,0 +1,293 @@
|
||||
"""
|
||||
Core data models for the credential store.
|
||||
|
||||
This module defines the key-vault structure where credentials are objects
|
||||
containing one or more keys (e.g., api_key, access_token, refresh_token).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class CredentialType(str, Enum):
    """Kinds of credentials the store can manage.

    Inherits from ``str`` so members compare equal to their raw values
    and serialize naturally.
    """

    API_KEY = "api_key"  # simple API key (e.g., Brave Search, OpenAI)
    OAUTH2 = "oauth2"  # OAuth2 with refresh token support
    BASIC_AUTH = "basic_auth"  # username/password pair
    BEARER_TOKEN = "bearer_token"  # JWT or bearer token without refresh
    CUSTOM = "custom"  # user-defined credential type
|
||||
|
||||
|
||||
class CredentialKey(BaseModel):
    """
    One named secret inside a credential object.

    Example: the 'api_key' entry of a 'brave_search' credential.

    Attributes:
        name: Key name (e.g., 'api_key', 'access_token')
        value: Secret value (SecretStr keeps it out of logs and repr)
        expires_at: Optional expiration time; None means it never expires
        metadata: Additional key-specific metadata
    """

    name: str
    value: SecretStr
    expires_at: datetime | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)

    model_config = {"extra": "allow"}

    @property
    def is_expired(self) -> bool:
        """True once `expires_at` has passed; always False when unset."""
        # NOTE(review): assumes expires_at is timezone-aware; a naive
        # datetime would raise TypeError on comparison — confirm callers.
        expiry = self.expires_at
        return expiry is not None and datetime.now(UTC) >= expiry

    def get_secret_value(self) -> str:
        """Unwrap and return the raw secret (use sparingly)."""
        return self.value.get_secret_value()
|
||||
|
||||
|
||||
class CredentialObject(BaseModel):
    """
    A credential made up of one or more named keys.

    This is the key-vault structure: each credential holds a mapping of
    key name -> CredentialKey (access_token, refresh_token, ...), plus
    lifecycle and usage bookkeeping.

    Example:
        CredentialObject(
            id="github_oauth",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(name="access_token", value=SecretStr("ghp_xxx")),
                "refresh_token": CredentialKey(name="refresh_token", value=SecretStr("ghr_xxx")),
            },
            provider_id="oauth2"
        )

    Attributes:
        id: Unique identifier (e.g., 'brave_search', 'github_oauth')
        credential_type: Type of credential (API_KEY, OAUTH2, etc.)
        keys: Dictionary of key name to CredentialKey
        provider_id: ID of provider responsible for lifecycle management
        auto_refresh: Whether to automatically refresh when expired
    """

    id: str = Field(description="Unique identifier (e.g., 'brave_search', 'github_oauth')")
    credential_type: CredentialType = CredentialType.API_KEY
    keys: dict[str, CredentialKey] = Field(default_factory=dict)

    # Lifecycle management
    provider_id: str | None = Field(
        default=None,
        description="ID of provider responsible for lifecycle (e.g., 'oauth2', 'static')",
    )
    last_refreshed: datetime | None = None
    auto_refresh: bool = False

    # Usage tracking
    last_used: datetime | None = None
    use_count: int = 0

    # Metadata
    description: str = ""
    tags: list[str] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=_utc_now)
    updated_at: datetime = Field(default_factory=_utc_now)

    model_config = {"extra": "allow"}

    def get_key(self, key_name: str) -> str | None:
        """
        Return a key's secret value.

        Args:
            key_name: Name of the key to retrieve

        Returns:
            The key's secret value, or None if the key does not exist
        """
        entry = self.keys.get(key_name)
        return None if entry is None else entry.get_secret_value()

    def set_key(
        self,
        key_name: str,
        value: str,
        expires_at: datetime | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """
        Create or replace a key, bumping `updated_at`.

        Args:
            key_name: Name of the key
            value: Secret value
            expires_at: Optional expiration time
            metadata: Optional key-specific metadata
        """
        self.keys[key_name] = CredentialKey(
            name=key_name,
            value=SecretStr(value),
            expires_at=expires_at,
            metadata=metadata or {},
        )
        self.updated_at = datetime.now(UTC)

    def has_key(self, key_name: str) -> bool:
        """Return True when `key_name` is present."""
        return key_name in self.keys

    @property
    def needs_refresh(self) -> bool:
        """True when at least one key has expired."""
        return any(entry.is_expired for entry in self.keys.values())

    @property
    def is_valid(self) -> bool:
        """True when the credential holds at least one non-expired key."""
        # any(...) over an empty mapping is False, covering the no-keys case.
        return any(not entry.is_expired for entry in self.keys.values())

    def record_usage(self) -> None:
        """Record a use of this credential (timestamp + counter)."""
        self.last_used = datetime.now(UTC)
        self.use_count += 1

    def get_default_key(self) -> str | None:
        """
        Return the most natural key's value.

        Priority: 'value' > 'api_key' > 'access_token' > first key
        (insertion order).

        Returns:
            The default key's value, or None if no keys exist
        """
        preferred = (n for n in ("value", "api_key", "access_token") if n in self.keys)
        chosen = next(preferred, None)
        if chosen is None:
            chosen = next(iter(self.keys), None)
        return None if chosen is None else self.get_key(chosen)
|
||||
|
||||
|
||||
class CredentialUsageSpec(BaseModel):
    """
    Specification for how a tool uses credentials.

    This implements the "bipartisan" model where the credential store
    just stores values, and tools define how those values are used
    in HTTP requests (headers, query params, body).

    Example:
        CredentialUsageSpec(
            credential_id="brave_search",
            required_keys=["api_key"],
            headers={"X-Subscription-Token": "{{api_key}}"}
        )

        CredentialUsageSpec(
            credential_id="github_oauth",
            required_keys=["access_token"],
            headers={"Authorization": "Bearer {{access_token}}"}
        )

    Attributes:
        credential_id: ID of credential to use
        required_keys: Keys that must be present
        headers: Header templates with {{key}} placeholders
        query_params: Query parameter templates
        body_fields: Request body field templates
    """

    credential_id: str = Field(description="ID of credential to use (e.g., 'brave_search')")
    required_keys: list[str] = Field(default_factory=list, description="Keys that must be present")

    # Injection templates (bipartisan model)
    # Each template value may contain {{key}} placeholders that the
    # consumer substitutes with the credential's key values.
    headers: dict[str, str] = Field(
        default_factory=dict,
        description="Header templates (e.g., {'Authorization': 'Bearer {{access_token}}'})",
    )
    query_params: dict[str, str] = Field(
        default_factory=dict,
        description="Query param templates (e.g., {'api_key': '{{api_key}}'})",
    )
    body_fields: dict[str, str] = Field(
        default_factory=dict,
        description="Request body field templates",
    )

    # Metadata
    # NOTE(review): `required` presumably means the credential is mandatory
    # for the tool to run — confirm against consumers of this spec.
    required: bool = True
    description: str = ""
    help_url: str = ""

    model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class CredentialError(Exception):
    """Base exception for credential-related errors."""
|
||||
|
||||
|
||||
class CredentialNotFoundError(CredentialError):
    """Raised when a referenced credential doesn't exist."""
|
||||
|
||||
|
||||
class CredentialKeyNotFoundError(CredentialError):
    """Raised when a referenced key doesn't exist in a credential."""
|
||||
|
||||
|
||||
class CredentialRefreshError(CredentialError):
    """Raised when credential refresh fails."""
|
||||
|
||||
|
||||
class CredentialValidationError(CredentialError):
    """Raised when credential validation fails."""
|
||||
|
||||
|
||||
class CredentialDecryptionError(CredentialError):
    """Raised when credential decryption fails."""
|
||||
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
OAuth2 support for the credential store.
|
||||
|
||||
This module provides OAuth2 credential management with:
|
||||
- Token types and configuration (OAuth2Token, OAuth2Config)
|
||||
- Generic OAuth2 provider (BaseOAuth2Provider)
|
||||
- Token lifecycle management (TokenLifecycleManager)
|
||||
|
||||
Quick Start:
|
||||
from core.framework.credentials import CredentialStore
|
||||
from core.framework.credentials.oauth2 import BaseOAuth2Provider, OAuth2Config
|
||||
|
||||
# Configure OAuth2 provider
|
||||
provider = BaseOAuth2Provider(OAuth2Config(
|
||||
token_url="https://oauth2.example.com/token",
|
||||
client_id="your-client-id",
|
||||
client_secret="your-client-secret",
|
||||
default_scopes=["read", "write"],
|
||||
))
|
||||
|
||||
# Create store with OAuth2 provider
|
||||
store = CredentialStore.with_encrypted_storage(
|
||||
providers=[provider] # defaults to ~/.hive/credentials
|
||||
)
|
||||
|
||||
# Get token using client credentials
|
||||
token = provider.client_credentials_grant()
|
||||
|
||||
# Save to store
|
||||
from core.framework.credentials import CredentialObject, CredentialKey, CredentialType
|
||||
from pydantic import SecretStr
|
||||
|
||||
    # Build the key map first so refresh_token is only added when present
    # (a None value inside `keys` would fail CredentialKey validation).
    keys = {
        "access_token": CredentialKey(
            name="access_token",
            value=SecretStr(token.access_token),
            expires_at=token.expires_at,
        ),
    }
    if token.refresh_token:
        keys["refresh_token"] = CredentialKey(
            name="refresh_token",
            value=SecretStr(token.refresh_token),
        )

    store.save_credential(CredentialObject(
        id="my_api",
        credential_type=CredentialType.OAUTH2,
        keys=keys,
        provider_id="oauth2",
        auto_refresh=True,
    ))
|
||||
|
||||
For advanced lifecycle management:
|
||||
from core.framework.credentials.oauth2 import TokenLifecycleManager
|
||||
|
||||
manager = TokenLifecycleManager(
|
||||
provider=provider,
|
||||
credential_id="my_api",
|
||||
store=store,
|
||||
)
|
||||
|
||||
# Get valid token (auto-refreshes if needed)
|
||||
token = manager.sync_get_valid_token()
|
||||
headers = manager.get_request_headers()
|
||||
"""
|
||||
|
||||
from .base_provider import BaseOAuth2Provider
|
||||
from .hubspot_provider import HubSpotOAuth2Provider
|
||||
from .lifecycle import TokenLifecycleManager, TokenRefreshResult
|
||||
from .provider import (
|
||||
OAuth2Config,
|
||||
OAuth2Error,
|
||||
OAuth2Token,
|
||||
RefreshTokenInvalidError,
|
||||
TokenExpiredError,
|
||||
TokenPlacement,
|
||||
)
|
||||
|
||||
# Public API of the credentials.oauth2 package, grouped by role.
# Keep in sync with the imports above.
__all__ = [
    # Types
    "OAuth2Token",
    "OAuth2Config",
    "TokenPlacement",
    # Providers
    "BaseOAuth2Provider",
    "HubSpotOAuth2Provider",
    # Lifecycle
    "TokenLifecycleManager",
    "TokenRefreshResult",
    # Errors
    "OAuth2Error",
    "TokenExpiredError",
    "RefreshTokenInvalidError",
]
|
||||
@@ -0,0 +1,486 @@
|
||||
"""
|
||||
Base OAuth2 provider implementation.
|
||||
|
||||
This module provides a generic OAuth2 provider that works with standard
|
||||
OAuth2 servers. OSS users can extend this class for custom providers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from ..models import CredentialObject, CredentialRefreshError, CredentialType
|
||||
from ..provider import CredentialProvider
|
||||
from .provider import (
|
||||
OAuth2Config,
|
||||
OAuth2Error,
|
||||
OAuth2Token,
|
||||
TokenPlacement,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseOAuth2Provider(CredentialProvider):
    """
    Generic OAuth2 provider implementation.

    Works with standard OAuth2 servers (RFC 6749). Override methods for
    provider-specific behavior.

    Supported grant types:
    - Client Credentials: For server-to-server authentication
    - Refresh Token: For refreshing expired access tokens
    - Authorization Code: For user-authorized access (requires callback handling)

    OSS users can extend this class for custom providers:

        class GitHubOAuth2Provider(BaseOAuth2Provider):
            def __init__(self, client_id: str, client_secret: str):
                super().__init__(OAuth2Config(
                    token_url="https://github.com/login/oauth/access_token",
                    authorization_url="https://github.com/login/oauth/authorize",
                    client_id=client_id,
                    client_secret=client_secret,
                    default_scopes=["repo", "user"],
                ))

            def exchange_code(self, code: str, redirect_uri: str, **kwargs) -> OAuth2Token:
                # GitHub returns data as form-encoded by default
                # Override to handle this
                ...

    Example usage:
        provider = BaseOAuth2Provider(OAuth2Config(
            token_url="https://oauth2.example.com/token",
            client_id="my-client-id",
            client_secret="my-client-secret",
        ))

        # Get token using client credentials
        token = provider.client_credentials_grant()

        # Refresh an expired token
        new_token = provider.refresh_access_token(old_token.refresh_token)
    """

    def __init__(self, config: OAuth2Config, provider_id: str = "oauth2"):
        """
        Initialize the OAuth2 provider.

        Args:
            config: OAuth2 configuration
            provider_id: Unique identifier for this provider instance
        """
        self.config = config
        self._provider_id = provider_id
        # Lazily-created httpx.Client; typed Any so httpx stays optional.
        self._client: Any | None = None

    @property
    def provider_id(self) -> str:
        return self._provider_id

    @property
    def supported_types(self) -> list[CredentialType]:
        return [CredentialType.OAUTH2, CredentialType.BEARER_TOKEN]

    def _get_client(self) -> Any:
        """Get or create HTTP client."""
        if self._client is None:
            try:
                import httpx

                self._client = httpx.Client(timeout=self.config.request_timeout)
            except ImportError as e:
                raise ImportError(
                    "OAuth2 provider requires 'httpx'. Install with: pip install httpx"
                ) from e
        return self._client

    def _close_client(self) -> None:
        """Close the HTTP client."""
        if self._client is not None:
            self._client.close()
            self._client = None

    def __del__(self) -> None:
        """Cleanup HTTP client on deletion."""
        # __del__ may run during interpreter shutdown when module globals are
        # already torn down; never let cleanup errors propagate from here.
        try:
            self._close_client()
        except Exception:  # pragma: no cover - best-effort cleanup
            pass

    # --- Grant Types ---

    def get_authorization_url(
        self,
        state: str,
        redirect_uri: str,
        scopes: list[str] | None = None,
        **kwargs: Any,
    ) -> str:
        """
        Generate authorization URL for user consent (Authorization Code flow).

        Args:
            state: Anti-CSRF state parameter (should be random and verified)
            redirect_uri: Callback URL to receive the authorization code
            scopes: Requested scopes (defaults to config.default_scopes)
            **kwargs: Additional provider-specific parameters

        Returns:
            URL to redirect user for authorization

        Raises:
            ValueError: If authorization_url is not configured
        """
        if not self.config.authorization_url:
            raise ValueError("authorization_url not configured for this provider")

        params = {
            "client_id": self.config.client_id,
            "redirect_uri": redirect_uri,
            "response_type": "code",
            "state": state,
            "scope": " ".join(scopes or self.config.default_scopes),
            **kwargs,
        }

        return f"{self.config.authorization_url}?{urlencode(params)}"

    def exchange_code(
        self,
        code: str,
        redirect_uri: str,
        **kwargs: Any,
    ) -> OAuth2Token:
        """
        Exchange authorization code for tokens (Authorization Code flow).

        Args:
            code: Authorization code from callback
            redirect_uri: Same redirect_uri used in authorization request
            **kwargs: Additional provider-specific parameters

        Returns:
            OAuth2Token with access_token and optional refresh_token

        Raises:
            OAuth2Error: If token exchange fails
        """
        data = {
            "grant_type": "authorization_code",
            "client_id": self.config.client_id,
            "client_secret": self.config.client_secret,
            "code": code,
            "redirect_uri": redirect_uri,
            **self.config.extra_token_params,
            **kwargs,
        }

        return self._token_request(data)

    def client_credentials_grant(
        self,
        scopes: list[str] | None = None,
        **kwargs: Any,
    ) -> OAuth2Token:
        """
        Obtain token using client credentials (Client Credentials flow).

        This is for server-to-server authentication where no user is involved.

        Args:
            scopes: Requested scopes (defaults to config.default_scopes)
            **kwargs: Additional provider-specific parameters

        Returns:
            OAuth2Token (typically without refresh_token)

        Raises:
            OAuth2Error: If token request fails
        """
        data = {
            "grant_type": "client_credentials",
            "client_id": self.config.client_id,
            "client_secret": self.config.client_secret,
            **self.config.extra_token_params,
            **kwargs,
        }

        if scopes or self.config.default_scopes:
            data["scope"] = " ".join(scopes or self.config.default_scopes)

        return self._token_request(data)

    def refresh_access_token(
        self,
        refresh_token: str,
        scopes: list[str] | None = None,
        **kwargs: Any,
    ) -> OAuth2Token:
        """
        Refresh an expired access token (Refresh Token flow).

        Args:
            refresh_token: The refresh token
            scopes: Scopes to request (defaults to original scopes)
            **kwargs: Additional provider-specific parameters

        Returns:
            New OAuth2Token (may include new refresh_token)

        Raises:
            OAuth2Error: If refresh fails
            RefreshTokenInvalidError: If refresh token is revoked/invalid
        """
        data = {
            "grant_type": "refresh_token",
            "client_id": self.config.client_id,
            "client_secret": self.config.client_secret,
            "refresh_token": refresh_token,
            **self.config.extra_token_params,
            **kwargs,
        }

        if scopes:
            data["scope"] = " ".join(scopes)

        return self._token_request(data)

    def revoke_token(
        self,
        token: str,
        token_type_hint: str = "access_token",
    ) -> bool:
        """
        Revoke a token (RFC 7009).

        Args:
            token: The token to revoke
            token_type_hint: "access_token" or "refresh_token"

        Returns:
            True if revocation succeeded
        """
        if not self.config.revocation_url:
            logger.warning("revocation_url not configured, cannot revoke token")
            return False

        try:
            client = self._get_client()
            response = client.post(
                self.config.revocation_url,
                data={
                    "token": token,
                    "token_type_hint": token_type_hint,
                    "client_id": self.config.client_id,
                    "client_secret": self.config.client_secret,
                },
                headers={"Accept": "application/json", **self.config.extra_headers},
            )
            # RFC 7009: 200 indicates success (even if token was already invalid)
            return response.status_code == 200
        except Exception as e:
            # Best-effort: revocation failure is reported, never raised.
            logger.error("Token revocation failed: %s", e)
            return False

    # --- CredentialProvider Interface ---

    def refresh(self, credential: CredentialObject) -> CredentialObject:
        """
        Refresh a credential using its refresh token.

        Implements CredentialProvider.refresh().

        Args:
            credential: The credential to refresh

        Returns:
            Updated credential with new access_token

        Raises:
            CredentialRefreshError: If refresh fails
        """
        refresh_tok = credential.get_key("refresh_token")
        if not refresh_tok:
            raise CredentialRefreshError(f"Credential '{credential.id}' has no refresh_token")

        try:
            new_token = self.refresh_access_token(refresh_tok)
        except OAuth2Error as e:
            # "invalid_grant" (RFC 6749 §5.2) means the refresh token itself
            # is dead; surface that the user must re-authorize.
            if e.error == "invalid_grant":
                raise CredentialRefreshError(
                    f"Refresh token for '{credential.id}' is invalid or revoked. "
                    "Re-authorization required."
                ) from e
            raise CredentialRefreshError(f"Failed to refresh '{credential.id}': {e}") from e

        # Update credential
        credential.set_key("access_token", new_token.access_token, expires_at=new_token.expires_at)

        # Update refresh token if a new one was issued
        if new_token.refresh_token and new_token.refresh_token != refresh_tok:
            credential.set_key("refresh_token", new_token.refresh_token)

        credential.last_refreshed = datetime.now(UTC)
        logger.info("Refreshed OAuth2 credential '%s'", credential.id)

        return credential

    def validate(self, credential: CredentialObject) -> bool:
        """
        Validate that credential has a valid (non-expired) access_token.

        Args:
            credential: The credential to validate

        Returns:
            True if credential has valid access_token
        """
        access_key = credential.keys.get("access_token")
        if access_key is None:
            return False
        return not access_key.is_expired

    def should_refresh(self, credential: CredentialObject) -> bool:
        """
        Check if credential should be refreshed.

        Returns True if access_token is expired or within 5 minutes of expiry.
        """
        access_key = credential.keys.get("access_token")
        if access_key is None:
            return False

        # No expiry recorded -> treat as non-expiring; never proactively refresh.
        if access_key.expires_at is None:
            return False

        buffer = timedelta(minutes=5)
        return datetime.now(UTC) >= (access_key.expires_at - buffer)

    def revoke(self, credential: CredentialObject) -> bool:
        """
        Revoke all tokens in a credential.

        Args:
            credential: The credential to revoke

        Returns:
            True if all revocations succeeded
        """
        success = True

        # Revoke access token
        access_token = credential.get_key("access_token")
        if access_token:
            if not self.revoke_token(access_token, "access_token"):
                success = False

        # Revoke refresh token
        refresh_token = credential.get_key("refresh_token")
        if refresh_token:
            if not self.revoke_token(refresh_token, "refresh_token"):
                success = False

        return success

    # --- Token Request Helpers ---

    def _token_request(self, data: dict[str, Any]) -> OAuth2Token:
        """
        Make a token request to the OAuth2 server.

        Args:
            data: Form data for the token request

        Returns:
            OAuth2Token from the response

        Raises:
            OAuth2Error: If request fails or returns an error
        """
        client = self._get_client()

        headers = {
            "Accept": "application/json",
            "Content-Type": "application/x-www-form-urlencoded",
            **self.config.extra_headers,
        }

        response = client.post(self.config.token_url, data=data, headers=headers)

        # Parse response
        content_type = response.headers.get("content-type", "")
        if "application/json" in content_type:
            response_data = response.json()
        else:
            # Some providers (like GitHub) may return form-encoded
            response_data = self._parse_form_response(response.text)

        # Check for error
        if response.status_code != 200 or "error" in response_data:
            error = response_data.get("error", "unknown_error")
            description = response_data.get("error_description", response.text)
            raise OAuth2Error(
                error=error, description=description, status_code=response.status_code
            )

        return OAuth2Token.from_token_response(response_data)

    def _parse_form_response(self, text: str) -> dict[str, str]:
        """Parse form-encoded response (some providers use this instead of JSON)."""
        from urllib.parse import parse_qs

        parsed = parse_qs(text)
        return {k: v[0] if len(v) == 1 else v for k, v in parsed.items()}

    # --- Token Formatting for Requests ---

    def format_for_request(self, token: OAuth2Token) -> dict[str, Any]:
        """
        Format token for use in HTTP requests per config.token_placement.

        Args:
            token: The OAuth2 token

        Returns:
            Dict with 'headers', 'params', or 'data' keys as appropriate
        """
        placement = self.config.token_placement

        if placement == TokenPlacement.HEADER_BEARER:
            return {"headers": {"Authorization": f"{token.token_type} {token.access_token}"}}

        elif placement == TokenPlacement.HEADER_CUSTOM:
            header_name = self.config.custom_header_name or "X-Access-Token"
            return {"headers": {header_name: token.access_token}}

        elif placement == TokenPlacement.QUERY_PARAM:
            return {"params": {self.config.query_param_name: token.access_token}}

        elif placement == TokenPlacement.BODY_PARAM:
            return {"data": {"access_token": token.access_token}}

        return {}

    def format_credential_for_request(self, credential: CredentialObject) -> dict[str, Any]:
        """
        Format a credential for use in HTTP requests.

        Args:
            credential: The credential containing access_token

        Returns:
            Dict with 'headers', 'params', or 'data' keys as appropriate
        """
        access_token = credential.get_key("access_token")
        if not access_token:
            return {}

        # credential.keys maps names to CredentialKey objects, so read the
        # plaintext value via get_key() — passing the key object itself as
        # token_type would render "CredentialKey(...) <token>" headers.
        token = OAuth2Token(
            access_token=access_token,
            token_type=credential.get_key("token_type") or "Bearer",
        )

        return self.format_for_request(token)
|
||||
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
HubSpot-specific OAuth2 provider.
|
||||
|
||||
Pre-configured for HubSpot's OAuth2 endpoints and CRM scopes.
|
||||
Extends BaseOAuth2Provider for HubSpot-specific behavior.
|
||||
|
||||
Usage:
|
||||
provider = HubSpotOAuth2Provider(
|
||||
client_id="your-client-id",
|
||||
client_secret="your-client-secret",
|
||||
)
|
||||
|
||||
# Use with credential store
|
||||
store = CredentialStore(
|
||||
storage=EncryptedFileStorage(), # defaults to ~/.hive/credentials
|
||||
providers=[provider],
|
||||
)
|
||||
|
||||
See: https://developers.hubspot.com/docs/api/oauth-quickstart-guide
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from ..models import CredentialObject, CredentialType
|
||||
from .base_provider import BaseOAuth2Provider
|
||||
from .provider import OAuth2Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# HubSpot OAuth2 endpoints
|
||||
# Token exchange endpoint (POST, form-encoded).
HUBSPOT_TOKEN_URL = "https://api.hubapi.com/oauth/v1/token"
# User-facing consent page for the Authorization Code flow.
HUBSPOT_AUTHORIZATION_URL = "https://app.hubspot.com/oauth/authorize"

# Default CRM scopes for contacts, companies, and deals.
# Callers can override these via HubSpotOAuth2Provider(scopes=...).
HUBSPOT_DEFAULT_SCOPES = [
    "crm.objects.contacts.read",
    "crm.objects.contacts.write",
    "crm.objects.companies.read",
    "crm.objects.companies.write",
    "crm.objects.deals.read",
    "crm.objects.deals.write",
]
|
||||
|
||||
|
||||
class HubSpotOAuth2Provider(BaseOAuth2Provider):
    """
    OAuth2 provider pre-wired for HubSpot.

    Supplies HubSpot's token/authorization endpoints and a default set of
    CRM scopes (contacts, companies, deals), and validates credentials with
    a cheap live call against the HubSpot API.

    Example:
        provider = HubSpotOAuth2Provider(
            client_id="your-hubspot-client-id",
            client_secret="your-hubspot-client-secret",
            scopes=["crm.objects.contacts.read"],  # Override default scopes
        )
    """

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        scopes: list[str] | None = None,
    ):
        # Build the HubSpot-specific config inline and hand it to the
        # generic base implementation.
        super().__init__(
            OAuth2Config(
                token_url=HUBSPOT_TOKEN_URL,
                authorization_url=HUBSPOT_AUTHORIZATION_URL,
                client_id=client_id,
                client_secret=client_secret,
                default_scopes=scopes or HUBSPOT_DEFAULT_SCOPES,
            ),
            provider_id="hubspot_oauth2",
        )

    @property
    def supported_types(self) -> list[CredentialType]:
        # HubSpot issues full OAuth2 credentials only.
        return [CredentialType.OAUTH2]

    def validate(self, credential: CredentialObject) -> bool:
        """
        Validate a HubSpot credential with a lightweight live API call.

        Fetches one contact; a 200 response means the access token works.
        Any transport or API failure is reported as invalid rather than
        raised.
        """
        token = credential.get_key("access_token")
        if not token:
            return False

        try:
            resp = self._get_client().get(
                "https://api.hubapi.com/crm/v3/objects/contacts",
                headers={
                    "Authorization": f"Bearer {token}",
                    "Accept": "application/json",
                },
                params={"limit": "1"},
            )
            return resp.status_code == 200
        except Exception:
            return False

    def _parse_token_response(self, response_data: dict[str, Any]) -> Any:
        """Turn HubSpot's token response payload into an OAuth2Token."""
        from .provider import OAuth2Token

        return OAuth2Token.from_token_response(response_data)
|
||||
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
Token lifecycle management for OAuth2 credentials.
|
||||
|
||||
This module provides the TokenLifecycleManager which coordinates
|
||||
automatic token refresh with the credential store.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from ..models import CredentialKey, CredentialObject, CredentialType
|
||||
from .base_provider import BaseOAuth2Provider
|
||||
from .provider import OAuth2Token
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..store import CredentialStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class TokenRefreshResult:
    """Result of a token refresh operation."""

    # Whether the refresh produced a usable token.
    success: bool
    # The freshly issued token when success is True, otherwise None.
    token: OAuth2Token | None = None
    # Human-readable failure reason when success is False.
    error: str | None = None
    # True when the refresh token itself is invalid/revoked, i.e. the user
    # must go through the authorization flow again.
    needs_reauthorization: bool = False
|
||||
|
||||
|
||||
class TokenLifecycleManager:
    """
    Manages the complete lifecycle of OAuth2 tokens.

    Responsibilities:
    - Coordinate with CredentialStore for persistence
    - Automatically refresh expired tokens
    - Handle refresh failures gracefully
    - Provide callbacks for monitoring

    This class is useful when you need more control over token management
    than the basic auto-refresh in CredentialStore provides.

    Usage:
        manager = TokenLifecycleManager(
            provider=github_provider,
            credential_id="github_oauth",
            store=credential_store,
        )

        # Get valid token (auto-refreshes if needed)
        token = await manager.get_valid_token()

        # Use token
        headers = provider.format_for_request(token)

    Synchronous usage:
        # For synchronous code, use sync_ methods
        token = manager.sync_get_valid_token()
    """

    def __init__(
        self,
        provider: BaseOAuth2Provider,
        credential_id: str,
        store: CredentialStore,
        refresh_buffer_minutes: int = 5,
        on_token_refreshed: Callable[[OAuth2Token], None] | None = None,
        on_refresh_failed: Callable[[str], None] | None = None,
    ):
        """
        Initialize the lifecycle manager.

        Args:
            provider: OAuth2 provider for token operations
            credential_id: ID of the credential in the store
            store: Credential store for persistence
            refresh_buffer_minutes: Minutes before expiry to trigger refresh
            on_token_refreshed: Callback when token is refreshed
            on_refresh_failed: Callback when refresh fails
        """
        self.provider = provider
        self.credential_id = credential_id
        self.store = store
        self.refresh_buffer = timedelta(minutes=refresh_buffer_minutes)
        self.on_token_refreshed = on_token_refreshed
        self.on_refresh_failed = on_refresh_failed

        # In-memory cache for performance
        self._cached_token: OAuth2Token | None = None
        self._cache_time: datetime | None = None

    # --- Async Token Access ---

    async def get_valid_token(self) -> OAuth2Token | None:
        """
        Get a valid access token, refreshing if necessary.

        This is the main entry point for async code.

        Returns:
            Valid OAuth2Token or None if unavailable
        """
        # Check cache first
        if self._cached_token and not self._needs_refresh(self._cached_token):
            return self._cached_token

        # Load from store (refresh is handled here, not by the store)
        credential = self.store.get_credential(self.credential_id, refresh_if_needed=False)
        if credential is None:
            return None

        # Convert to OAuth2Token
        token = self._credential_to_token(credential)
        if token is None:
            return None

        # Refresh if needed
        if self._needs_refresh(token):
            result = await self._async_refresh_token(credential)
            if result.success and result.token:
                token = result.token
            elif result.needs_reauthorization:
                logger.warning("Token for %s needs reauthorization", self.credential_id)
                return None
            else:
                # Use existing token if still technically valid
                if token.is_expired:
                    return None
                logger.warning(
                    "Refresh failed for %s, using existing token", self.credential_id
                )

        self._cached_token = token
        self._cache_time = datetime.now(UTC)
        return token

    async def acquire_token_client_credentials(
        self,
        scopes: list[str] | None = None,
    ) -> OAuth2Token:
        """
        Acquire a new token using client credentials flow.

        For service-to-service authentication.

        Args:
            scopes: Scopes to request

        Returns:
            New OAuth2Token
        """
        # Run in executor to avoid blocking the event loop.
        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated here since Python 3.10.
        loop = asyncio.get_running_loop()
        token = await loop.run_in_executor(
            None, lambda: self.provider.client_credentials_grant(scopes=scopes)
        )

        self._save_token_to_store(token)
        self._cached_token = token
        return token

    async def revoke(self) -> bool:
        """
        Revoke tokens and clear from store.

        Returns:
            True if revocation succeeded

        Note: provider-side revocation is best-effort; the credential is
        deleted from the store regardless.
        """
        credential = self.store.get_credential(self.credential_id, refresh_if_needed=False)
        if credential:
            self.provider.revoke(credential)

        self.store.delete_credential(self.credential_id)
        self._cached_token = None
        return True

    # --- Synchronous Token Access ---

    def sync_get_valid_token(self) -> OAuth2Token | None:
        """
        Synchronous version of get_valid_token().

        For use in synchronous code.
        """
        # Check cache
        if self._cached_token and not self._needs_refresh(self._cached_token):
            return self._cached_token

        # Load from store
        credential = self.store.get_credential(self.credential_id, refresh_if_needed=False)
        if credential is None:
            return None

        token = self._credential_to_token(credential)
        if token is None:
            return None

        # Refresh if needed
        if self._needs_refresh(token):
            result = self._sync_refresh_token(credential)
            if result.success and result.token:
                token = result.token
            elif result.needs_reauthorization:
                logger.warning("Token for %s needs reauthorization", self.credential_id)
                return None
            else:
                # Use existing token if still technically valid
                if token.is_expired:
                    return None
                logger.warning(
                    "Refresh failed for %s, using existing token", self.credential_id
                )

        self._cached_token = token
        self._cache_time = datetime.now(UTC)
        return token

    def sync_acquire_token_client_credentials(
        self,
        scopes: list[str] | None = None,
    ) -> OAuth2Token:
        """Synchronous version of acquire_token_client_credentials()."""
        token = self.provider.client_credentials_grant(scopes=scopes)
        self._save_token_to_store(token)
        self._cached_token = token
        return token

    # --- Helper Methods ---

    def _needs_refresh(self, token: OAuth2Token) -> bool:
        """Check if token needs refresh (within refresh_buffer of expiry)."""
        if token.expires_at is None:
            return False
        return datetime.now(UTC) >= (token.expires_at - self.refresh_buffer)

    def _credential_to_token(self, credential: CredentialObject) -> OAuth2Token | None:
        """Convert a stored credential to an OAuth2Token, or None without access_token."""
        access_token = credential.get_key("access_token")
        if not access_token:
            return None

        expires_at = None
        access_key = credential.keys.get("access_token")
        if access_key:
            expires_at = access_key.expires_at

        return OAuth2Token(
            access_token=access_token,
            token_type="Bearer",
            expires_at=expires_at,
            refresh_token=credential.get_key("refresh_token"),
            scope=credential.get_key("scope"),
        )

    def _save_token_to_store(self, token: OAuth2Token) -> None:
        """Save token to credential store (overwrites the stored credential)."""
        credential = CredentialObject(
            id=self.credential_id,
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr(token.access_token),
                    expires_at=token.expires_at,
                ),
            },
            provider_id=self.provider.provider_id,
            auto_refresh=True,
        )

        if token.refresh_token:
            credential.keys["refresh_token"] = CredentialKey(
                name="refresh_token",
                value=SecretStr(token.refresh_token),
            )

        if token.scope:
            credential.keys["scope"] = CredentialKey(
                name="scope",
                value=SecretStr(token.scope),
            )

        self.store.save_credential(credential)

    async def _async_refresh_token(self, credential: CredentialObject) -> TokenRefreshResult:
        """Async wrapper for token refresh (runs the blocking refresh in an executor)."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: self._sync_refresh_token(credential))

    def _sync_refresh_token(self, credential: CredentialObject) -> TokenRefreshResult:
        """Synchronously refresh token and persist the result."""
        refresh_token = credential.get_key("refresh_token")
        if not refresh_token:
            return TokenRefreshResult(
                success=False,
                error="No refresh token available",
                needs_reauthorization=True,
            )

        try:
            new_token = self.provider.refresh_access_token(refresh_token)

            # RFC 6749 §6: servers MAY omit refresh_token in the refresh
            # response. _save_token_to_store overwrites the credential, so
            # carry the still-valid refresh token forward or it would be lost.
            if not new_token.refresh_token:
                new_token.refresh_token = refresh_token

            # Save to store
            self._save_token_to_store(new_token)

            # Notify callback
            if self.on_token_refreshed:
                self.on_token_refreshed(new_token)

            logger.info("Token refreshed for %s", self.credential_id)
            return TokenRefreshResult(success=True, token=new_token)

        except Exception as e:
            error_msg = str(e)

            # Check for refresh token revocation
            if "invalid_grant" in error_msg.lower():
                return TokenRefreshResult(
                    success=False,
                    error=error_msg,
                    needs_reauthorization=True,
                )

            if self.on_refresh_failed:
                self.on_refresh_failed(error_msg)

            logger.error("Token refresh failed for %s: %s", self.credential_id, e)
            return TokenRefreshResult(success=False, error=error_msg)

    def invalidate_cache(self) -> None:
        """Clear cached token."""
        self._cached_token = None
        self._cache_time = None

    # --- Convenience Methods ---

    def get_request_headers(self) -> dict[str, str]:
        """
        Get headers for HTTP request with current token.

        Returns empty dict if no valid token.
        """
        token = self.sync_get_valid_token()
        if token is None:
            return {}

        result = self.provider.format_for_request(token)
        return result.get("headers", {})

    def get_request_kwargs(self) -> dict:
        """
        Get kwargs for HTTP request (headers, params, etc.).

        Returns empty dict if no valid token.
        """
        token = self.sync_get_valid_token()
        if token is None:
            return {}

        return self.provider.format_for_request(token)
|
||||
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
OAuth2 types and configuration.
|
||||
|
||||
This module defines the core OAuth2 data structures:
|
||||
- OAuth2Token: Represents an access token with metadata
|
||||
- OAuth2Config: Configuration for OAuth2 endpoints
|
||||
- TokenPlacement: Where to place tokens in requests
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class TokenPlacement(str, Enum):
    """Where to place the access token in HTTP requests."""

    # Authorization: Bearer <token> (most common)
    HEADER_BEARER = "header_bearer"

    # Custom header name (e.g., X-Access-Token)
    HEADER_CUSTOM = "header_custom"

    # Query parameter (e.g., ?access_token=<token>)
    QUERY_PARAM = "query_param"

    # Form body parameter
    BODY_PARAM = "body_param"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuth2Token:
|
||||
"""
|
||||
Represents an OAuth2 token with metadata.
|
||||
|
||||
Attributes:
|
||||
access_token: The access token string
|
||||
token_type: Token type (usually "Bearer")
|
||||
expires_at: When the token expires
|
||||
refresh_token: Optional refresh token
|
||||
scope: Granted scopes (space-separated)
|
||||
raw_response: Original token response from server
|
||||
"""
|
||||
|
||||
access_token: str
|
||||
token_type: str = "Bearer"
|
||||
expires_at: datetime | None = None
|
||||
refresh_token: str | None = None
|
||||
scope: str | None = None
|
||||
raw_response: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""
|
||||
Check if token is expired.
|
||||
|
||||
Uses a 5-minute buffer to account for clock skew and
|
||||
request latency.
|
||||
"""
|
||||
if self.expires_at is None:
|
||||
return False
|
||||
buffer = timedelta(minutes=5)
|
||||
return datetime.now(UTC) >= (self.expires_at - buffer)
|
||||
|
||||
@property
|
||||
def can_refresh(self) -> bool:
|
||||
"""Check if token can be refreshed (has refresh_token)."""
|
||||
return self.refresh_token is not None and self.refresh_token.strip() != ""
|
||||
|
||||
@property
|
||||
def expires_in_seconds(self) -> int | None:
|
||||
"""Get seconds until expiration, or None if no expiration."""
|
||||
if self.expires_at is None:
|
||||
return None
|
||||
delta = self.expires_at - datetime.now(UTC)
|
||||
return max(0, int(delta.total_seconds()))
|
||||
|
||||
@classmethod
|
||||
def from_token_response(cls, data: dict[str, Any]) -> OAuth2Token:
|
||||
"""
|
||||
Create OAuth2Token from an OAuth2 token endpoint response.
|
||||
|
||||
Args:
|
||||
data: Token response JSON (access_token, token_type, expires_in, etc.)
|
||||
|
||||
Returns:
|
||||
OAuth2Token instance
|
||||
"""
|
||||
expires_at = None
|
||||
if "expires_in" in data:
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=data["expires_in"])
|
||||
|
||||
return cls(
|
||||
access_token=data["access_token"],
|
||||
token_type=data.get("token_type", "Bearer"),
|
||||
expires_at=expires_at,
|
||||
refresh_token=data.get("refresh_token"),
|
||||
scope=data.get("scope"),
|
||||
raw_response=data,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class OAuth2Config:
    """
    Configuration for an OAuth2 provider.

    Holds everything needed to perform OAuth2 operations against a
    specific provider (GitHub, Google, Salesforce, etc.).

    Attributes:
        token_url: URL for token endpoint (required)
        authorization_url: URL for authorization endpoint (optional, for auth code flow)
        revocation_url: URL for token revocation (optional)
        introspection_url: URL for token introspection (optional)
        client_id: OAuth2 client ID
        client_secret: OAuth2 client secret
        default_scopes: Default scopes to request
        token_placement: How to include the token in API requests
        custom_header_name: Header name when using HEADER_CUSTOM placement
        query_param_name: Query param name when using QUERY_PARAM placement
        extra_token_params: Additional parameters for token requests
        request_timeout: Timeout for HTTP requests in seconds
        extra_headers: Additional headers for token requests

    Example:
        config = OAuth2Config(
            token_url="https://github.com/login/oauth/access_token",
            authorization_url="https://github.com/login/oauth/authorize",
            client_id="your-client-id",
            client_secret="your-client-secret",
            default_scopes=["repo", "user"],
        )
    """

    # Endpoints (only token_url is strictly required)
    token_url: str
    authorization_url: str | None = None
    revocation_url: str | None = None
    introspection_url: str | None = None

    # Client credentials
    client_id: str = ""
    client_secret: str = ""

    # Scopes
    default_scopes: list[str] = field(default_factory=list)

    # Token placement for API calls
    token_placement: TokenPlacement = TokenPlacement.HEADER_BEARER
    custom_header_name: str | None = None
    query_param_name: str = "access_token"

    # Request configuration
    extra_token_params: dict[str, str] = field(default_factory=dict)
    request_timeout: float = 30.0

    # Additional headers for token requests
    extra_headers: dict[str, str] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate configuration; raises ValueError on an unusable config."""
        if not self.token_url:
            raise ValueError("token_url is required")

        # HEADER_CUSTOM is meaningless without a header name to place the token in.
        missing_header_name = (
            self.token_placement == TokenPlacement.HEADER_CUSTOM
            and not self.custom_header_name
        )
        if missing_header_name:
            raise ValueError("custom_header_name is required when using HEADER_CUSTOM placement")
|
||||
|
||||
|
||||
class OAuth2Error(Exception):
    """
    OAuth2 protocol error.

    Attributes:
        error: OAuth2 error code (e.g., 'invalid_grant', 'invalid_client')
        description: Human-readable error description
        status_code: HTTP status code from the response (0 when unknown)
    """

    def __init__(
        self,
        error: str,
        description: str = "",
        status_code: int = 0,
    ):
        self.error = error
        self.description = description
        self.status_code = status_code
        # Message is "code: description" when a description exists, else just the code.
        message = f"{error}: {description}" if description else error
        super().__init__(message)
|
||||
|
||||
|
||||
class TokenExpiredError(OAuth2Error):
    """Raised when a token has expired and cannot be used."""

    def __init__(self, credential_id: str):
        self.credential_id = credential_id
        super().__init__(
            error="token_expired",
            description=f"Token for '{credential_id}' has expired",
        )
|
||||
|
||||
|
||||
class RefreshTokenInvalidError(OAuth2Error):
    """Raised when the refresh token is invalid or revoked."""

    def __init__(self, credential_id: str, reason: str = ""):
        self.credential_id = credential_id
        description = f"Refresh token for '{credential_id}' is invalid"
        if reason:
            description = f"{description}: {reason}"
        super().__init__(error="invalid_grant", description=description)
|
||||
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Provider interface for credential lifecycle management.
|
||||
|
||||
Providers handle credential lifecycle operations:
|
||||
- Refresh: Obtain new tokens when expired
|
||||
- Validate: Check if credentials are still working
|
||||
- Revoke: Invalidate credentials when no longer needed
|
||||
|
||||
OSS users can implement custom providers by subclassing CredentialProvider.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from .models import CredentialObject, CredentialRefreshError, CredentialType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CredentialProvider(ABC):
|
||||
"""
|
||||
Abstract base class for credential providers.
|
||||
|
||||
Providers handle credential lifecycle operations:
|
||||
- refresh(): Obtain new tokens when expired
|
||||
- validate(): Check if credentials are still working
|
||||
- should_refresh(): Determine if a credential needs refresh
|
||||
- revoke(): Invalidate credentials (optional)
|
||||
|
||||
Example custom provider:
|
||||
class MyCustomProvider(CredentialProvider):
|
||||
@property
|
||||
def provider_id(self) -> str:
|
||||
return "my_custom"
|
||||
|
||||
@property
|
||||
def supported_types(self) -> List[CredentialType]:
|
||||
return [CredentialType.CUSTOM]
|
||||
|
||||
def refresh(self, credential: CredentialObject) -> CredentialObject:
|
||||
# Custom refresh logic
|
||||
new_token = my_api.refresh(credential.get_key("api_key"))
|
||||
credential.set_key("access_token", new_token)
|
||||
return credential
|
||||
|
||||
def validate(self, credential: CredentialObject) -> bool:
|
||||
token = credential.get_key("access_token")
|
||||
return my_api.validate(token)
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def provider_id(self) -> str:
|
||||
"""
|
||||
Unique identifier for this provider.
|
||||
|
||||
Examples: 'static', 'oauth2', 'my_custom_auth'
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_types(self) -> list[CredentialType]:
|
||||
"""
|
||||
Credential types this provider can manage.
|
||||
|
||||
Returns:
|
||||
List of CredentialType enums this provider supports
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def refresh(self, credential: CredentialObject) -> CredentialObject:
|
||||
"""
|
||||
Refresh the credential (e.g., use refresh_token to get new access_token).
|
||||
|
||||
This method should:
|
||||
1. Use existing credential data to obtain new values
|
||||
2. Update the credential object with new values
|
||||
3. Set appropriate expiration times
|
||||
4. Update last_refreshed timestamp
|
||||
|
||||
Args:
|
||||
credential: The credential to refresh
|
||||
|
||||
Returns:
|
||||
Updated credential with new values
|
||||
|
||||
Raises:
|
||||
CredentialRefreshError: If refresh fails
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate(self, credential: CredentialObject) -> bool:
|
||||
"""
|
||||
Validate that a credential is still working.
|
||||
|
||||
This might involve:
|
||||
- Checking expiration times
|
||||
- Making a test API call
|
||||
- Validating token signatures
|
||||
|
||||
Args:
|
||||
credential: The credential to validate
|
||||
|
||||
Returns:
|
||||
True if credential is valid, False otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
def should_refresh(self, credential: CredentialObject) -> bool:
|
||||
"""
|
||||
Determine if a credential should be refreshed.
|
||||
|
||||
Default implementation: refresh if any key is expired or within
|
||||
5 minutes of expiry. Override for custom logic.
|
||||
|
||||
Args:
|
||||
credential: The credential to check
|
||||
|
||||
Returns:
|
||||
True if credential should be refreshed
|
||||
"""
|
||||
buffer = timedelta(minutes=5)
|
||||
now = datetime.now(UTC)
|
||||
|
||||
for key in credential.keys.values():
|
||||
if key.expires_at is not None:
|
||||
if key.expires_at <= now + buffer:
|
||||
return True
|
||||
return False
|
||||
|
||||
def revoke(self, credential: CredentialObject) -> bool:
|
||||
"""
|
||||
Revoke a credential (optional operation).
|
||||
|
||||
Not all providers support revocation. The default implementation
|
||||
logs a warning and returns False.
|
||||
|
||||
Args:
|
||||
credential: The credential to revoke
|
||||
|
||||
Returns:
|
||||
True if revocation succeeded, False otherwise
|
||||
"""
|
||||
logger.warning(f"Provider '{self.provider_id}' does not support revocation")
|
||||
return False
|
||||
|
||||
def can_handle(self, credential: CredentialObject) -> bool:
|
||||
"""
|
||||
Check if this provider can handle a credential.
|
||||
|
||||
Args:
|
||||
credential: The credential to check
|
||||
|
||||
Returns:
|
||||
True if this provider can manage the credential
|
||||
"""
|
||||
return credential.credential_type in self.supported_types
|
||||
|
||||
|
||||
class StaticProvider(CredentialProvider):
    """
    Provider for static credentials that never need refresh.

    Suited to plain API keys that do not expire, such as:
    - Brave Search API key
    - OpenAI API key
    - Basic auth credentials

    A static credential is considered valid as soon as it carries at
    least one non-empty key.
    """

    @property
    def provider_id(self) -> str:
        return "static"

    @property
    def supported_types(self) -> list[CredentialType]:
        return [CredentialType.API_KEY, CredentialType.BASIC_AUTH, CredentialType.CUSTOM]

    def refresh(self, credential: CredentialObject) -> CredentialObject:
        """
        No-op: static credentials have nothing to refresh.

        Returns the credential unchanged.
        """
        logger.debug(f"Static credential '{credential.id}' does not need refresh")
        return credential

    def validate(self, credential: CredentialObject) -> bool:
        """
        Validate that the credential has at least one key with a value.

        We cannot verify a static key actually works without making an
        API call, so only existence of a non-blank value is checked.
        """
        for key in credential.keys.values():
            try:
                if (secret := key.get_secret_value()) and secret.strip():
                    return True
            except Exception:
                # A key whose value cannot be read doesn't count as valid.
                continue
        return False

    def should_refresh(self, credential: CredentialObject) -> bool:
        """Static credentials never need refresh."""
        return False
|
||||
|
||||
|
||||
class BearerTokenProvider(CredentialProvider):
    """
    Provider for bearer tokens without refresh capability.

    Use for JWTs or tokens that:
    - Have an expiration time
    - Cannot be refreshed (no refresh token)
    - Must be re-obtained when expired

    Validation is based purely on expiration time.
    """

    @property
    def provider_id(self) -> str:
        return "bearer_token"

    @property
    def supported_types(self) -> list[CredentialType]:
        return [CredentialType.BEARER_TOKEN]

    def refresh(self, credential: CredentialObject) -> CredentialObject:
        """
        Bearer tokens without refresh capability cannot be refreshed.

        Raises:
            CredentialRefreshError: Always, as refresh is not supported
        """
        raise CredentialRefreshError(
            f"Bearer token '{credential.id}' cannot be refreshed. "
            "Obtain a new token and save it to the credential store."
        )

    def validate(self, credential: CredentialObject) -> bool:
        """
        Validate based on expiration time.

        Returns True if a token key exists and is not expired.
        """
        token_key = credential.keys.get("access_token") or credential.keys.get("token")
        return token_key is not None and not token_key.is_expired

    def should_refresh(self, credential: CredentialObject) -> bool:
        """
        Check if the token is expired or within 5 minutes of expiring.

        Note: even though this returns True for expired tokens, refresh()
        will fail — this lets the store know the credential needs attention.
        """
        cutoff = datetime.now(UTC) + timedelta(minutes=5)
        for name in ("access_token", "token"):
            key = credential.keys.get(name)
            if key and key.expires_at and key.expires_at <= cutoff:
                return True
        return False
|
||||
@@ -0,0 +1,518 @@
|
||||
"""
|
||||
Storage backends for the credential store.
|
||||
|
||||
This module provides abstract and concrete storage implementations:
|
||||
- CredentialStorage: Abstract base class
|
||||
- EncryptedFileStorage: Fernet-encrypted JSON files (default for production)
|
||||
- EnvVarStorage: Environment variable reading (backward compatibility)
|
||||
- InMemoryStorage: For testing
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from .models import CredentialDecryptionError, CredentialKey, CredentialObject, CredentialType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CredentialStorage(ABC):
    """
    Abstract storage backend for credentials.

    Implementations must provide save, load, delete, list_all, and exists methods.
    All implementations should handle serialization of SecretStr values securely.
    Implementations in this module (file, env var, in-memory, composite) follow
    the convention that save() overwrites an existing credential with the same id.
    """

    @abstractmethod
    def save(self, credential: CredentialObject) -> None:
        """
        Save a credential to storage.

        Args:
            credential: The credential object to save
        """
        pass

    @abstractmethod
    def load(self, credential_id: str) -> CredentialObject | None:
        """
        Load a credential from storage.

        Args:
            credential_id: The ID of the credential to load

        Returns:
            CredentialObject if found, None otherwise
        """
        pass

    @abstractmethod
    def delete(self, credential_id: str) -> bool:
        """
        Delete a credential from storage.

        Args:
            credential_id: The ID of the credential to delete

        Returns:
            True if the credential existed and was deleted, False otherwise
        """
        pass

    @abstractmethod
    def list_all(self) -> list[str]:
        """
        List all credential IDs in storage.

        Returns:
            List of credential IDs
        """
        pass

    @abstractmethod
    def exists(self, credential_id: str) -> bool:
        """
        Check if a credential exists in storage.

        Args:
            credential_id: The ID to check

        Returns:
            True if credential exists, False otherwise
        """
        pass
|
||||
|
||||
|
||||
class EncryptedFileStorage(CredentialStorage):
    """
    Encrypted file-based credential storage.

    Uses Fernet symmetric encryption (AES-128-CBC + HMAC) for at-rest encryption.
    Each credential is stored as a separate encrypted JSON file. Credential files
    are created with owner-only (0600) permissions and directories with 0700, so
    other local users cannot read the ciphertext or the unencrypted index.

    Directory structure:
        {base_path}/
            credentials/
                {credential_id}.enc  # Encrypted credential JSON
            metadata/
                index.json           # Index of all credentials (unencrypted)

    The encryption key is read from the HIVE_CREDENTIAL_KEY environment variable.
    If not set, a new key is generated (and must be persisted for data recovery).

    Example:
        storage = EncryptedFileStorage("~/.hive/credentials")
        storage.save(credential)
        credential = storage.load("brave_search")
    """

    DEFAULT_PATH = "~/.hive/credentials"

    def __init__(
        self,
        base_path: str | Path | None = None,
        encryption_key: bytes | None = None,
        key_env_var: str = "HIVE_CREDENTIAL_KEY",
    ):
        """
        Initialize encrypted storage.

        Args:
            base_path: Directory for credential files. Defaults to ~/.hive/credentials.
            encryption_key: 32-byte Fernet key. If None, reads from env var.
            key_env_var: Environment variable containing encryption key

        Raises:
            ImportError: If the 'cryptography' package is not installed
        """
        try:
            from cryptography.fernet import Fernet
        except ImportError as e:
            raise ImportError(
                "Encrypted storage requires 'cryptography'. Install with: pip install cryptography"
            ) from e

        self.base_path = Path(base_path or self.DEFAULT_PATH).expanduser()
        self._ensure_dirs()
        self._key_env_var = key_env_var

        # Key resolution order: explicit argument > environment variable > generated.
        if encryption_key:
            self._key = encryption_key
        else:
            key_str = os.environ.get(key_env_var)
            if key_str:
                self._key = key_str.encode()
            else:
                # Generate new key; data saved under it is unrecoverable
                # unless the operator persists the key as instructed below.
                self._key = Fernet.generate_key()
                logger.warning(
                    f"Generated new encryption key. To persist credentials across restarts, "
                    f"set {key_env_var}={self._key.decode()}"
                )

        self._fernet = Fernet(self._key)

    def _ensure_dirs(self) -> None:
        """Create directory structure with owner-only (0700) permissions."""
        # 0o700 keeps credential material readable only by the owning user.
        (self.base_path / "credentials").mkdir(mode=0o700, parents=True, exist_ok=True)
        (self.base_path / "metadata").mkdir(mode=0o700, parents=True, exist_ok=True)

    def _cred_path(self, credential_id: str) -> Path:
        """Get the file path for a credential."""
        # Sanitize credential_id to prevent path traversal
        safe_id = credential_id.replace("/", "_").replace("\\", "_").replace("..", "_")
        return self.base_path / "credentials" / f"{safe_id}.enc"

    def save(self, credential: CredentialObject) -> None:
        """Encrypt and save credential with owner-only (0600) file permissions."""
        # Serialize credential
        data = self._serialize_credential(credential)
        json_bytes = json.dumps(data, default=str).encode()

        # Encrypt
        encrypted = self._fernet.encrypt(json_bytes)

        # Create the file with mode 0600 so the ciphertext is never
        # world-readable, even transiently during the first write.
        cred_path = self._cred_path(credential.id)
        fd = os.open(cred_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "wb") as f:
            f.write(encrypted)
        # O_CREAT mode only applies to newly created files; tighten
        # permissions on pre-existing files too.
        os.chmod(cred_path, 0o600)

        # Update index
        self._update_index(credential.id, "save", credential.credential_type.value)
        logger.debug(f"Saved encrypted credential '{credential.id}'")

    def load(self, credential_id: str) -> CredentialObject | None:
        """Load and decrypt credential.

        Returns:
            CredentialObject if found, None otherwise

        Raises:
            CredentialDecryptionError: If the file cannot be decrypted
                (wrong key or corrupted data)
        """
        cred_path = self._cred_path(credential_id)
        if not cred_path.exists():
            return None

        # Read encrypted data
        with open(cred_path, "rb") as f:
            encrypted = f.read()

        # Decrypt
        try:
            json_bytes = self._fernet.decrypt(encrypted)
            data = json.loads(json_bytes.decode())
        except Exception as e:
            raise CredentialDecryptionError(
                f"Failed to decrypt credential '{credential_id}': {e}"
            ) from e

        # Deserialize
        return self._deserialize_credential(data)

    def delete(self, credential_id: str) -> bool:
        """Delete a credential file. Returns True if it existed."""
        cred_path = self._cred_path(credential_id)
        if cred_path.exists():
            cred_path.unlink()
            self._update_index(credential_id, "delete")
            logger.debug(f"Deleted credential '{credential_id}'")
            return True
        return False

    def list_all(self) -> list[str]:
        """List all credential IDs recorded in the metadata index."""
        index_path = self.base_path / "metadata" / "index.json"
        if not index_path.exists():
            return []
        with open(index_path) as f:
            index = json.load(f)
        return list(index.get("credentials", {}).keys())

    def exists(self, credential_id: str) -> bool:
        """Check if credential file exists on disk."""
        return self._cred_path(credential_id).exists()

    def _serialize_credential(self, credential: CredentialObject) -> dict[str, Any]:
        """Convert credential to JSON-serializable dict, extracting secret values."""
        data = credential.model_dump(mode="json")

        # Extract actual secret values from SecretStr
        for key_name, key_data in data.get("keys", {}).items():
            if "value" in key_data:
                # SecretStr serializes as "**********", need actual value
                actual_key = credential.keys.get(key_name)
                if actual_key:
                    key_data["value"] = actual_key.get_secret_value()

        return data

    def _deserialize_credential(self, data: dict[str, Any]) -> CredentialObject:
        """Reconstruct credential from dict, wrapping values in SecretStr."""
        # Convert plain values back to SecretStr
        for key_data in data.get("keys", {}).values():
            if "value" in key_data and isinstance(key_data["value"], str):
                key_data["value"] = SecretStr(key_data["value"])

        return CredentialObject.model_validate(data)

    def _update_index(
        self,
        credential_id: str,
        operation: str,
        credential_type: str | None = None,
    ) -> None:
        """Update the metadata index after a save or delete."""
        index_path = self.base_path / "metadata" / "index.json"

        if index_path.exists():
            with open(index_path) as f:
                index = json.load(f)
        else:
            index = {"credentials": {}, "version": "1.0"}

        if operation == "save":
            index["credentials"][credential_id] = {
                "updated_at": datetime.now(UTC).isoformat(),
                "type": credential_type,
            }
        elif operation == "delete":
            index["credentials"].pop(credential_id, None)

        index["last_modified"] = datetime.now(UTC).isoformat()

        with open(index_path, "w") as f:
            json.dump(index, f, indent=2)
|
||||
|
||||
|
||||
class EnvVarStorage(CredentialStorage):
    """
    Environment variable-based storage for backward compatibility.

    Maps credential IDs to environment variable patterns and supports
    hot-reload from .env files via python-dotenv.

    This storage is READ-ONLY — credentials cannot be saved at runtime.

    Example:
        storage = EnvVarStorage(
            env_mapping={"brave_search": "BRAVE_SEARCH_API_KEY"},
            dotenv_path=Path(".env")
        )
        credential = storage.load("brave_search")
    """

    def __init__(
        self,
        env_mapping: dict[str, str] | None = None,
        dotenv_path: Path | None = None,
    ):
        """
        Initialize env var storage.

        Args:
            env_mapping: Map of credential_id -> env_var_name
                e.g., {"brave_search": "BRAVE_SEARCH_API_KEY"}
                If not provided, uses {CREDENTIAL_ID}_API_KEY pattern
            dotenv_path: Path to .env file for hot-reload support
        """
        self._env_mapping = env_mapping or {}
        self._dotenv_path = dotenv_path or Path.cwd() / ".env"

    def _get_env_var_name(self, credential_id: str) -> str:
        """Resolve the environment variable name for a credential."""
        try:
            return self._env_mapping[credential_id]
        except KeyError:
            # Default pattern: CREDENTIAL_ID_API_KEY
            return f"{credential_id.upper().replace('-', '_')}_API_KEY"

    def _read_env_value(self, env_var: str) -> str | None:
        """Read value from the process environment or the .env file."""
        # os.environ takes precedence over the .env file
        value = os.environ.get(env_var)
        if value:
            return value

        # Fallback: read from .env file (hot-reload)
        if not self._dotenv_path.exists():
            return None
        try:
            from dotenv import dotenv_values
        except ImportError:
            logger.debug("python-dotenv not installed, skipping .env file")
            return None
        return dotenv_values(self._dotenv_path).get(env_var)

    def save(self, credential: CredentialObject) -> None:
        """Cannot save to environment variables at runtime."""
        raise NotImplementedError(
            "EnvVarStorage is read-only. Set environment variables "
            "externally or use EncryptedFileStorage."
        )

    def load(self, credential_id: str) -> CredentialObject | None:
        """Build an API-key credential from the mapped environment variable."""
        env_var = self._get_env_var_name(credential_id)
        value = self._read_env_value(env_var)
        if not value:
            return None

        return CredentialObject(
            id=credential_id,
            credential_type=CredentialType.API_KEY,
            keys={"api_key": CredentialKey(name="api_key", value=SecretStr(value))},
            description=f"Loaded from {env_var}",
        )

    def delete(self, credential_id: str) -> bool:
        """Cannot delete environment variables at runtime."""
        raise NotImplementedError(
            "EnvVarStorage is read-only. Unset environment variables externally."
        )

    def list_all(self) -> list[str]:
        """List mapped credentials that are actually present in the environment."""
        return [cred_id for cred_id in self._env_mapping if self.exists(cred_id)]

    def exists(self, credential_id: str) -> bool:
        """Check if credential is available in environment."""
        return self._read_env_value(self._get_env_var_name(credential_id)) is not None

    def add_mapping(self, credential_id: str, env_var: str) -> None:
        """
        Add a credential ID to environment variable mapping.

        Args:
            credential_id: The credential identifier
            env_var: The environment variable name
        """
        self._env_mapping[credential_id] = env_var
|
||||
|
||||
|
||||
class InMemoryStorage(CredentialStorage):
    """
    In-memory storage for testing.

    Credentials live in a plain dict and vanish when the process exits.

    Example:
        storage = InMemoryStorage()
        storage.save(credential)
        credential = storage.load("test_cred")
    """

    def __init__(self, initial_data: dict[str, CredentialObject] | None = None):
        """
        Initialize in-memory storage.

        Args:
            initial_data: Optional dict of credential_id -> CredentialObject
        """
        self._data: dict[str, CredentialObject] = initial_data or {}

    def save(self, credential: CredentialObject) -> None:
        """Store (or overwrite) the credential under its id."""
        self._data[credential.id] = credential

    def load(self, credential_id: str) -> CredentialObject | None:
        """Return the stored credential, or None if absent."""
        return self._data.get(credential_id)

    def delete(self, credential_id: str) -> bool:
        """Remove the credential; True when it existed."""
        try:
            del self._data[credential_id]
        except KeyError:
            return False
        return True

    def list_all(self) -> list[str]:
        """Return every stored credential ID."""
        return list(self._data)

    def exists(self, credential_id: str) -> bool:
        """True when the credential ID is present."""
        return credential_id in self._data

    def clear(self) -> None:
        """Drop all stored credentials."""
        self._data.clear()
|
||||
|
||||
|
||||
class CompositeStorage(CredentialStorage):
    """
    Composite storage that reads from multiple backends.

    Useful for layering storages, e.g., encrypted file with env var fallback:
    - Writes and deletes go to the primary storage only
    - Reads check primary first, then each fallback in order

    Example:
        storage = CompositeStorage(
            primary=EncryptedFileStorage("~/.hive/credentials"),
            fallbacks=[EnvVarStorage({"brave_search": "BRAVE_SEARCH_API_KEY"})]
        )
    """

    def __init__(
        self,
        primary: CredentialStorage,
        fallbacks: list[CredentialStorage] | None = None,
    ):
        """
        Initialize composite storage.

        Args:
            primary: Primary storage for writes and first read attempt
            fallbacks: List of fallback storages to check if primary doesn't have credential
        """
        self._primary = primary
        self._fallbacks = fallbacks or []

    def save(self, credential: CredentialObject) -> None:
        """Save to primary storage."""
        self._primary.save(credential)

    def load(self, credential_id: str) -> CredentialObject | None:
        """Return the first hit from primary then fallbacks, else None."""
        for storage in (self._primary, *self._fallbacks):
            credential = storage.load(credential_id)
            if credential is not None:
                return credential
        return None

    def delete(self, credential_id: str) -> bool:
        """Delete from primary storage only."""
        return self._primary.delete(credential_id)

    def list_all(self) -> list[str]:
        """Union of credential IDs across every layer."""
        ids = set(self._primary.list_all())
        for fallback in self._fallbacks:
            ids.update(fallback.list_all())
        return list(ids)

    def exists(self, credential_id: str) -> bool:
        """True when any layer holds the credential."""
        return self._primary.exists(credential_id) or any(
            fallback.exists(credential_id) for fallback in self._fallbacks
        )
|
||||
@@ -0,0 +1,708 @@
|
||||
"""
|
||||
Main credential store orchestrating storage, providers, and template resolution.
|
||||
|
||||
The CredentialStore is the primary interface for credential management, providing:
|
||||
- Multi-backend storage (file, env, vault)
|
||||
- Provider-based lifecycle management (refresh, validate)
|
||||
- Template resolution for {{cred.key}} patterns
|
||||
- Caching with TTL for performance
|
||||
- Thread-safe operations
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from .models import (
|
||||
CredentialKey,
|
||||
CredentialObject,
|
||||
CredentialRefreshError,
|
||||
CredentialUsageSpec,
|
||||
)
|
||||
from .provider import CredentialProvider, StaticProvider
|
||||
from .storage import CredentialStorage, EnvVarStorage, InMemoryStorage
|
||||
from .template import TemplateResolver
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CredentialStore:
    """
    Main credential store orchestrating storage, providers, and template resolution.

    Features:
    - Multi-backend storage (file, env, vault)
    - Provider-based lifecycle management (refresh, validate)
    - Template resolution for {{cred.key}} patterns
    - Caching with TTL for performance
    - Thread-safe operations

    Usage:
        # Basic usage
        store = CredentialStore(
            storage=EncryptedFileStorage("~/.hive/credentials"),
            providers=[OAuth2Provider(), StaticProvider()]
        )

        # Get a credential
        cred = store.get_credential("github_oauth")

        # Resolve templates in headers
        headers = store.resolve_headers({
            "Authorization": "Bearer {{github_oauth.access_token}}"
        })

        # Register a tool's credential requirements
        store.register_usage(CredentialUsageSpec(
            credential_id="brave_search",
            required_keys=["api_key"],
            headers={"X-Subscription-Token": "{{brave_search.api_key}}"}
        ))
    """

    def __init__(
        self,
        storage: CredentialStorage | None = None,
        providers: list[CredentialProvider] | None = None,
        cache_ttl_seconds: int = 300,
        auto_refresh: bool = True,
    ):
        """
        Initialize the credential store.

        Args:
            storage: Storage backend. Defaults to EnvVarStorage for compatibility.
            providers: List of credential providers. Defaults to [StaticProvider()].
            cache_ttl_seconds: How long to cache credentials in memory (default: 5 minutes).
            auto_refresh: Whether to auto-refresh expired credentials on access.
        """
        self._storage = storage or EnvVarStorage()
        self._providers: dict[str, CredentialProvider] = {}
        self._usage_specs: dict[str, CredentialUsageSpec] = {}

        # Cache: credential_id -> (CredentialObject, cached_at)
        # Staleness is checked lazily on read in _get_from_cache.
        self._cache: dict[str, tuple[CredentialObject, datetime]] = {}
        self._cache_ttl = cache_ttl_seconds
        # RLock (not Lock): public methods that hold it may call other
        # locking methods on the same instance.
        self._lock = threading.RLock()

        self._auto_refresh = auto_refresh

        # Register providers
        for provider in providers or [StaticProvider()]:
            self.register_provider(provider)

        # Template resolver (resolves {{cred.key}} patterns against this store)
        self._resolver = TemplateResolver(self)

    # --- Provider Management ---

    def register_provider(self, provider: CredentialProvider) -> None:
        """
        Register a credential provider.

        Args:
            provider: The provider to register
        """
        # Later registrations with the same provider_id silently replace earlier ones.
        self._providers[provider.provider_id] = provider
        logger.debug(f"Registered credential provider: {provider.provider_id}")

    def get_provider(self, provider_id: str) -> CredentialProvider | None:
        """
        Get a provider by ID.

        Args:
            provider_id: The provider identifier

        Returns:
            The provider if found, None otherwise
        """
        return self._providers.get(provider_id)

    def get_provider_for_credential(
        self, credential: CredentialObject
    ) -> CredentialProvider | None:
        """
        Get the appropriate provider for a credential.

        Args:
            credential: The credential to find a provider for

        Returns:
            The provider if found, None otherwise
        """
        # First, check if credential specifies a provider
        if credential.provider_id:
            provider = self._providers.get(credential.provider_id)
            if provider:
                return provider

        # Fall back to finding a provider that supports this type.
        # Iteration order is registration order, so earlier providers win ties.
        for provider in self._providers.values():
            if provider.can_handle(credential):
                return provider

        return None

    # --- Usage Spec Management ---

    def register_usage(self, spec: CredentialUsageSpec) -> None:
        """
        Register how a tool uses credentials.

        Args:
            spec: The usage specification
        """
        self._usage_specs[spec.credential_id] = spec

    def get_usage_spec(self, credential_id: str) -> CredentialUsageSpec | None:
        """
        Get the usage spec for a credential.

        Args:
            credential_id: The credential identifier

        Returns:
            The usage spec if registered, None otherwise
        """
        return self._usage_specs.get(credential_id)

    # --- Credential Access ---

    def get_credential(
        self,
        credential_id: str,
        refresh_if_needed: bool = True,
    ) -> CredentialObject | None:
        """
        Get a credential by ID.

        Args:
            credential_id: The credential identifier
            refresh_if_needed: If True, refresh expired credentials

        Returns:
            CredentialObject or None if not found
        """
        with self._lock:
            # Check cache
            cached = self._get_from_cache(credential_id)
            if cached is not None:
                if refresh_if_needed and self._should_refresh(cached):
                    return self._refresh_credential(cached)
                return cached

            # Load from storage
            credential = self._storage.load(credential_id)
            if credential is None:
                return None

            # Refresh if needed (refresh also persists and re-caches on success)
            if refresh_if_needed and self._should_refresh(credential):
                credential = self._refresh_credential(credential)

            # Cache
            self._add_to_cache(credential)

            return credential

    def get_key(self, credential_id: str, key_name: str) -> str | None:
        """
        Convenience method to get a specific key value.

        Args:
            credential_id: The credential identifier
            key_name: The key within the credential

        Returns:
            The key value or None if not found
        """
        credential = self.get_credential(credential_id)
        if credential is None:
            return None
        return credential.get_key(key_name)

    def get(self, credential_id: str) -> str | None:
        """
        Legacy compatibility: get the primary key value.

        For single-key credentials, returns that key.
        For multi-key, returns 'value', 'api_key', or 'access_token'.

        Args:
            credential_id: The credential identifier

        Returns:
            The primary key value or None
        """
        credential = self.get_credential(credential_id)
        if credential is None:
            return None
        return credential.get_default_key()

    # --- Template Resolution ---

    def resolve(self, template: str) -> str:
        """
        Resolve credential templates in a string.

        Args:
            template: String containing {{cred.key}} patterns

        Returns:
            Template with all references resolved

        Example:
            >>> store.resolve("Bearer {{github.access_token}}")
            "Bearer ghp_xxxxxxxxxxxx"
        """
        return self._resolver.resolve(template)

    def resolve_headers(self, headers: dict[str, str]) -> dict[str, str]:
        """
        Resolve credential templates in headers dictionary.

        Args:
            headers: Dict of header name to template value

        Returns:
            Dict with all templates resolved

        Example:
            >>> store.resolve_headers({
            ...     "Authorization": "Bearer {{github.access_token}}"
            ... })
            {"Authorization": "Bearer ghp_xxx"}
        """
        return self._resolver.resolve_headers(headers)

    def resolve_params(self, params: dict[str, str]) -> dict[str, str]:
        """
        Resolve credential templates in query parameters dictionary.

        Args:
            params: Dict of param name to template value

        Returns:
            Dict with all templates resolved
        """
        return self._resolver.resolve_params(params)

    def resolve_for_usage(self, credential_id: str) -> dict[str, Any]:
        """
        Get resolved request kwargs for a registered usage spec.

        Args:
            credential_id: The credential identifier

        Returns:
            Dict with 'headers', 'params', etc. keys as appropriate

        Raises:
            ValueError: If no usage spec is registered for the credential
        """
        spec = self._usage_specs.get(credential_id)
        if spec is None:
            raise ValueError(f"No usage spec registered for '{credential_id}'")

        result: dict[str, Any] = {}

        # Only keys present on the spec appear in the result, so callers can
        # pass it straight as **kwargs to an HTTP client.
        if spec.headers:
            result["headers"] = self.resolve_headers(spec.headers)

        if spec.query_params:
            result["params"] = self.resolve_params(spec.query_params)

        if spec.body_fields:
            result["data"] = {key: self.resolve(value) for key, value in spec.body_fields.items()}

        return result

    # --- Credential Management ---

    def save_credential(self, credential: CredentialObject) -> None:
        """
        Save a credential to storage.

        Args:
            credential: The credential to save
        """
        with self._lock:
            self._storage.save(credential)
            # Keep the cache coherent with storage under the same lock.
            self._add_to_cache(credential)
        logger.info(f"Saved credential '{credential.id}'")

    def delete_credential(self, credential_id: str) -> bool:
        """
        Delete a credential from storage.

        Args:
            credential_id: The credential identifier

        Returns:
            True if the credential existed and was deleted
        """
        with self._lock:
            # Evict from cache first so a concurrent reader cannot resurrect it.
            self._remove_from_cache(credential_id)
            result = self._storage.delete(credential_id)
        if result:
            logger.info(f"Deleted credential '{credential_id}'")
        return result

    def list_credentials(self) -> list[str]:
        """
        List all available credential IDs.

        Returns:
            List of credential IDs
        """
        return self._storage.list_all()

    def is_available(self, credential_id: str) -> bool:
        """
        Check if a credential is available.

        Args:
            credential_id: The credential identifier

        Returns:
            True if credential exists and is accessible
        """
        return self.get_credential(credential_id, refresh_if_needed=False) is not None

    # --- Validation ---

    def validate_for_usage(self, credential_id: str) -> list[str]:
        """
        Validate that a credential meets its usage spec requirements.

        Args:
            credential_id: The credential identifier

        Returns:
            List of missing keys or errors. Empty list if valid.
        """
        spec = self._usage_specs.get(credential_id)
        if spec is None:
            return []  # No requirements registered

        credential = self.get_credential(credential_id)
        if credential is None:
            return [f"Credential '{credential_id}' not found"]

        errors = []
        for key_name in spec.required_keys:
            if not credential.has_key(key_name):
                errors.append(f"Missing required key '{key_name}'")

        return errors

    def validate_all(self) -> dict[str, list[str]]:
        """
        Validate all registered usage specs.

        Returns:
            Dict mapping credential_id to list of errors.
            Only includes credentials with errors.
        """
        errors = {}
        for cred_id in self._usage_specs.keys():
            cred_errors = self.validate_for_usage(cred_id)
            if cred_errors:
                errors[cred_id] = cred_errors
        return errors

    def validate_credential(self, credential_id: str) -> bool:
        """
        Validate a credential using its provider.

        Args:
            credential_id: The credential identifier

        Returns:
            True if credential is valid
        """
        credential = self.get_credential(credential_id, refresh_if_needed=False)
        if credential is None:
            return False

        provider = self.get_provider_for_credential(credential)
        if provider is None:
            # No provider, assume valid if has keys
            return bool(credential.keys)

        return provider.validate(credential)

    # --- Lifecycle Management ---

    def _should_refresh(self, credential: CredentialObject) -> bool:
        """Check if credential should be refreshed.

        All three gates must agree: store-level auto_refresh, the
        credential's own auto_refresh flag, and the provider's policy.
        """
        if not self._auto_refresh:
            return False

        if not credential.auto_refresh:
            return False

        provider = self.get_provider_for_credential(credential)
        if provider is None:
            return False

        return provider.should_refresh(credential)

    def _refresh_credential(self, credential: CredentialObject) -> CredentialObject:
        """Refresh a credential using its provider.

        On refresh failure the stale credential is returned (and logged)
        rather than raising, so callers degrade gracefully.
        """
        provider = self.get_provider_for_credential(credential)
        if provider is None:
            logger.warning(f"No provider found for credential '{credential.id}'")
            return credential

        try:
            refreshed = provider.refresh(credential)
            refreshed.last_refreshed = datetime.now(UTC)

            # Persist the refreshed credential
            self._storage.save(refreshed)
            self._add_to_cache(refreshed)

            logger.info(f"Refreshed credential '{credential.id}'")
            return refreshed

        except CredentialRefreshError as e:
            logger.error(f"Failed to refresh credential '{credential.id}': {e}")
            return credential

    def refresh_credential(self, credential_id: str) -> CredentialObject | None:
        """
        Manually refresh a credential.

        Args:
            credential_id: The credential identifier

        Returns:
            The refreshed credential, or None if not found

        Raises:
            CredentialRefreshError: If refresh fails
        """
        credential = self.get_credential(credential_id, refresh_if_needed=False)
        if credential is None:
            return None

        return self._refresh_credential(credential)

    # --- Caching ---

    def _get_from_cache(self, credential_id: str) -> CredentialObject | None:
        """Get credential from cache if not expired.

        NOTE(review): callers in this class hold self._lock when calling
        this (it mutates the cache on expiry); keep that invariant.
        """
        if credential_id not in self._cache:
            return None

        credential, cached_at = self._cache[credential_id]
        age = (datetime.now(UTC) - cached_at).total_seconds()

        if age > self._cache_ttl:
            # Evict lazily on read once the TTL is exceeded.
            del self._cache[credential_id]
            return None

        return credential

    def _add_to_cache(self, credential: CredentialObject) -> None:
        """Add credential to cache, stamping it with the current UTC time."""
        self._cache[credential.id] = (credential, datetime.now(UTC))

    def _remove_from_cache(self, credential_id: str) -> None:
        """Remove credential from cache (no-op if absent)."""
        self._cache.pop(credential_id, None)

    def clear_cache(self) -> None:
        """Clear the credential cache."""
        with self._lock:
            self._cache.clear()

    # --- Factory Methods ---

    @classmethod
    def for_testing(
        cls,
        credentials: dict[str, dict[str, str]],
    ) -> CredentialStore:
        """
        Create a credential store for testing with mock credentials.

        Args:
            credentials: Dict mapping credential_id to {key_name: value}
                e.g., {"brave_search": {"api_key": "test-key"}}

        Returns:
            CredentialStore with in-memory credentials

        Example:
            store = CredentialStore.for_testing({
                "brave_search": {"api_key": "test-brave-key"},
                "github_oauth": {
                    "access_token": "test-token",
                    "refresh_token": "test-refresh"
                }
            })
        """
        # Convert test data to CredentialObjects
        cred_objects: dict[str, CredentialObject] = {}

        for cred_id, keys in credentials.items():
            cred_objects[cred_id] = CredentialObject(
                id=cred_id,
                keys={k: CredentialKey(name=k, value=SecretStr(v)) for k, v in keys.items()},
            )

        # auto_refresh disabled: test credentials have no real provider to refresh against.
        return cls(
            storage=InMemoryStorage(cred_objects),
            auto_refresh=False,
        )

    @classmethod
    def with_encrypted_storage(
        cls,
        base_path: str | None = None,
        providers: list[CredentialProvider] | None = None,
        **kwargs: Any,
    ) -> CredentialStore:
        """
        Create a credential store with encrypted file storage.

        Args:
            base_path: Directory for credential files. Defaults to ~/.hive/credentials.
            providers: List of credential providers
            **kwargs: Additional arguments passed to CredentialStore

        Returns:
            CredentialStore with EncryptedFileStorage
        """
        # Imported lazily so the crypto dependency is only needed when used.
        from .storage import EncryptedFileStorage

        return cls(
            storage=EncryptedFileStorage(base_path),
            providers=providers,
            **kwargs,
        )

    @classmethod
    def with_env_storage(
        cls,
        env_mapping: dict[str, str] | None = None,
        providers: list[CredentialProvider] | None = None,
        **kwargs: Any,
    ) -> CredentialStore:
        """
        Create a credential store with environment variable storage.

        Args:
            env_mapping: Map of credential_id -> env_var_name
            providers: List of credential providers
            **kwargs: Additional arguments passed to CredentialStore

        Returns:
            CredentialStore with EnvVarStorage
        """
        return cls(
            storage=EnvVarStorage(env_mapping),
            providers=providers,
            **kwargs,
        )

    @classmethod
    def with_aden_sync(
        cls,
        base_url: str = "https://api.adenhq.com",
        cache_ttl_seconds: int = 300,
        local_path: str | None = None,
        auto_sync: bool = True,
        **kwargs: Any,
    ) -> CredentialStore:
        """
        Create a credential store with Aden server sync.

        Automatically syncs OAuth2 tokens from the Aden authentication server.
        Falls back to local-only storage if ADEN_API_KEY is not set or Aden
        is unreachable.

        Args:
            base_url: Aden server URL (default: https://api.adenhq.com)
            cache_ttl_seconds: How long to cache credentials locally (default: 5 min)
            local_path: Path for local credential storage (default: ~/.hive/credentials)
            auto_sync: Whether to sync all credentials on startup (default: True)
            **kwargs: Additional arguments passed to CredentialStore

        Returns:
            CredentialStore configured with Aden sync

        Example:
            # Simple usage - just set ADEN_API_KEY env var
            store = CredentialStore.with_aden_sync()

            # Get HubSpot token (auto-refreshed via Aden)
            token = store.get_key("hubspot", "access_token")
        """
        import os
        from pathlib import Path

        from .storage import EncryptedFileStorage

        # Determine local storage path
        if local_path is None:
            local_path = str(Path.home() / ".hive" / "credentials")

        local_storage = EncryptedFileStorage(base_path=local_path)

        # Check if Aden is configured
        api_key = os.environ.get("ADEN_API_KEY")
        if not api_key:
            logger.info("ADEN_API_KEY not set, using local-only credential storage")
            return cls(storage=local_storage, **kwargs)

        # Try to setup Aden sync
        try:
            from .aden import (
                AdenCachedStorage,
                AdenClientConfig,
                AdenCredentialClient,
                AdenSyncProvider,
            )

            # Create Aden client
            client = AdenCredentialClient(AdenClientConfig(base_url=base_url))

            # Create sync provider
            provider = AdenSyncProvider(client=client)

            # Use cached storage for offline resilience
            cached_storage = AdenCachedStorage(
                local_storage=local_storage,
                aden_provider=provider,
                cache_ttl_seconds=cache_ttl_seconds,
            )

            store = cls(
                storage=cached_storage,
                providers=[provider],
                auto_refresh=True,
                **kwargs,
            )

            # Initial sync
            if auto_sync:
                synced = provider.sync_all(store)
                logger.info(f"Synced {synced} credentials from Aden server")

            return store

        except ImportError:
            # Aden extras not installed: deliberate best-effort fallback.
            logger.warning("Aden components not available, using local storage")
            return cls(storage=local_storage, **kwargs)

        except Exception as e:
            # Broad by design: any Aden setup/sync failure degrades to local-only.
            logger.warning(f"Failed to setup Aden sync: {e}. Using local storage.")
            return cls(storage=local_storage, **kwargs)
|
||||
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
Template resolution system for credential injection.
|
||||
|
||||
This module handles {{cred.key}} patterns, enabling the bipartisan model
|
||||
where tools specify how credentials are used in HTTP requests.
|
||||
|
||||
Template Syntax:
|
||||
{{credential_id.key_name}} - Access specific key
|
||||
{{credential_id}} - Access default key (value, api_key, or access_token)
|
||||
|
||||
Examples:
|
||||
"Bearer {{github_oauth.access_token}}" -> "Bearer ghp_xxx"
|
||||
"X-API-Key: {{brave_search.api_key}}" -> "X-API-Key: BSAKxxx"
|
||||
"{{brave_search}}" -> "BSAKxxx" (uses default key)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .models import CredentialKeyNotFoundError, CredentialNotFoundError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .store import CredentialStore
|
||||
|
||||
|
||||
class TemplateResolver:
    """
    Resolves credential templates like {{cred.key}} into actual values.

    Usage:
        resolver = TemplateResolver(credential_store)

        # Resolve single template string
        auth_header = resolver.resolve("Bearer {{github_oauth.access_token}}")

        # Resolve all headers at once
        headers = resolver.resolve_headers({
            "Authorization": "Bearer {{github_oauth.access_token}}",
            "X-API-Key": "{{brave_search.api_key}}"
        })
    """

    # Matches {{credential_id}} or {{credential_id.key_name}}
    TEMPLATE_PATTERN = re.compile(r"\{\{([a-zA-Z0-9_-]+)(?:\.([a-zA-Z0-9_-]+))?\}\}")

    def __init__(self, credential_store: CredentialStore):
        """
        Initialize the template resolver.

        Args:
            credential_store: The credential store to resolve references against
        """
        self._store = credential_store

    def resolve(self, template: str, fail_on_missing: bool = True) -> str:
        """
        Resolve all credential references in a template string.

        Args:
            template: String containing {{cred.key}} patterns
            fail_on_missing: If True, raise on missing credentials or keys.
                If False, unresolvable placeholders are left in the output
                verbatim (best-effort resolution).

        Returns:
            Template with all references replaced with actual values

        Raises:
            CredentialNotFoundError: If credential doesn't exist and fail_on_missing=True
            CredentialKeyNotFoundError: If the referenced key doesn't exist (or the
                credential has no keys) and fail_on_missing=True

        Example:
            >>> resolver.resolve("Bearer {{github_oauth.access_token}}")
            "Bearer ghp_xxxxxxxxxxxx"
        """

        def replace_match(match: re.Match) -> str:
            cred_id = match.group(1)
            key_name = match.group(2)  # May be None

            credential = self._store.get_credential(cred_id, refresh_if_needed=True)
            if credential is None:
                if fail_on_missing:
                    raise CredentialNotFoundError(f"Credential '{cred_id}' not found")
                return match.group(0)  # Return original template

            # Get specific key or default
            if key_name:
                value = credential.get_key(key_name)
                if value is None:
                    # BUGFIX: fail_on_missing previously only guarded missing
                    # credentials; missing keys raised unconditionally.
                    if fail_on_missing:
                        raise CredentialKeyNotFoundError(
                            f"Key '{key_name}' not found in credential '{cred_id}'"
                        )
                    return match.group(0)
            else:
                # Use default key
                value = credential.get_default_key()
                if value is None:
                    if fail_on_missing:
                        raise CredentialKeyNotFoundError(f"Credential '{cred_id}' has no keys")
                    return match.group(0)

            # Record usage
            credential.record_usage()

            return value

        return self.TEMPLATE_PATTERN.sub(replace_match, template)

    def resolve_headers(
        self,
        header_templates: dict[str, str],
        fail_on_missing: bool = True,
    ) -> dict[str, str]:
        """
        Resolve templates in a headers dictionary.

        Args:
            header_templates: Dict of header name to template value
            fail_on_missing: If True, raise error on missing credentials/keys

        Returns:
            Dict with all templates resolved to actual values

        Example:
            >>> resolver.resolve_headers({
            ...     "Authorization": "Bearer {{github_oauth.access_token}}",
            ...     "X-API-Key": "{{brave_search.api_key}}"
            ... })
            {"Authorization": "Bearer ghp_xxx", "X-API-Key": "BSAKxxx"}
        """
        return {
            key: self.resolve(value, fail_on_missing) for key, value in header_templates.items()
        }

    def resolve_params(
        self,
        param_templates: dict[str, str],
        fail_on_missing: bool = True,
    ) -> dict[str, str]:
        """
        Resolve templates in a query parameters dictionary.

        Args:
            param_templates: Dict of param name to template value
            fail_on_missing: If True, raise error on missing credentials/keys

        Returns:
            Dict with all templates resolved to actual values
        """
        return {key: self.resolve(value, fail_on_missing) for key, value in param_templates.items()}

    def has_templates(self, text: str) -> bool:
        """
        Check if text contains any credential templates.

        Args:
            text: String to check

        Returns:
            True if text contains {{...}} patterns
        """
        return bool(self.TEMPLATE_PATTERN.search(text))

    def extract_references(self, text: str) -> list[tuple[str, str | None]]:
        """
        Extract all credential references from text.

        Args:
            text: String to extract references from

        Returns:
            List of (credential_id, key_name) tuples.
            key_name is None if only credential_id was specified.

        Example:
            >>> resolver.extract_references("{{github.token}} and {{brave_search.api_key}}")
            [("github", "token"), ("brave_search", "api_key")]
        """
        return [(match.group(1), match.group(2)) for match in self.TEMPLATE_PATTERN.finditer(text)]

    def validate_references(self, text: str) -> list[str]:
        """
        Validate all credential references in text without resolving.

        Args:
            text: String containing template references

        Returns:
            List of error messages for invalid references.
            Empty list if all references are valid.
        """
        errors = []
        references = self.extract_references(text)

        for cred_id, key_name in references:
            credential = self._store.get_credential(cred_id, refresh_if_needed=False)

            if credential is None:
                errors.append(f"Credential '{cred_id}' not found")
                continue

            if key_name:
                if not credential.has_key(key_name):
                    errors.append(f"Key '{key_name}' not found in credential '{cred_id}'")
            elif not credential.keys:
                errors.append(f"Credential '{cred_id}' has no keys")

        return errors

    def get_required_credentials(self, text: str) -> list[str]:
        """
        Get list of credential IDs required by a template string.

        Args:
            text: String containing template references

        Returns:
            List of unique credential IDs referenced in the text,
            in order of first appearance.
        """
        references = self.extract_references(text)
        return list(dict.fromkeys(cred_id for cred_id, _ in references))
|
||||
@@ -0,0 +1 @@
|
||||
"""Tests for the credential store module."""
|
||||
@@ -0,0 +1,707 @@
|
||||
"""
|
||||
Comprehensive tests for the credential store module.
|
||||
|
||||
Tests cover:
|
||||
- Core models (CredentialObject, CredentialKey, CredentialUsageSpec)
|
||||
- Template resolution
|
||||
- Storage backends (InMemoryStorage, EnvVarStorage, EncryptedFileStorage)
|
||||
- Providers (StaticProvider, BearerTokenProvider)
|
||||
- Main CredentialStore
|
||||
- OAuth2 module
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from core.framework.credentials import (
|
||||
CompositeStorage,
|
||||
CredentialKey,
|
||||
CredentialKeyNotFoundError,
|
||||
CredentialNotFoundError,
|
||||
CredentialObject,
|
||||
CredentialStore,
|
||||
CredentialType,
|
||||
CredentialUsageSpec,
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
InMemoryStorage,
|
||||
StaticProvider,
|
||||
TemplateResolver,
|
||||
)
|
||||
from pydantic import SecretStr
|
||||
|
||||
|
||||
class TestCredentialKey:
    """Unit tests for the CredentialKey model."""

    def test_create_basic_key(self):
        """A minimal key exposes its name and secret and never expires."""
        basic = CredentialKey(name="api_key", value=SecretStr("test-value"))
        assert basic.name == "api_key"
        assert basic.get_secret_value() == "test-value"
        assert basic.expires_at is None
        assert not basic.is_expired

    def test_key_with_expiration(self):
        """A key expiring in the future is not yet expired."""
        one_hour_ahead = datetime.now(UTC) + timedelta(hours=1)
        token = CredentialKey(name="token", value=SecretStr("xxx"), expires_at=one_hour_ahead)
        assert not token.is_expired

    def test_expired_key(self):
        """A key whose expiry lies in the past reports as expired."""
        one_hour_ago = datetime.now(UTC) - timedelta(hours=1)
        token = CredentialKey(name="token", value=SecretStr("xxx"), expires_at=one_hour_ago)
        assert token.is_expired

    def test_key_with_metadata(self):
        """Arbitrary metadata supplied at construction is preserved."""
        annotated = CredentialKey(
            name="token",
            value=SecretStr("xxx"),
            metadata={"client_id": "abc", "scope": "read"},
        )
        assert annotated.metadata["client_id"] == "abc"
|
||||
|
||||
class TestCredentialObject:
    """Unit tests covering the CredentialObject model."""

    def test_create_simple_credential(self):
        """An API-key credential exposes its id, type, and key value."""
        credential = CredentialObject(
            id="brave_search",
            credential_type=CredentialType.API_KEY,
            keys={"api_key": CredentialKey(name="api_key", value=SecretStr("test-key"))},
        )
        assert credential.id == "brave_search"
        assert credential.credential_type == CredentialType.API_KEY
        assert credential.get_key("api_key") == "test-key"

    def test_create_multi_key_credential(self):
        """A credential may carry several named keys; misses return None."""
        secrets = {"access_token": "ghp_xxx", "refresh_token": "ghr_xxx"}
        credential = CredentialObject(
            id="github_oauth",
            credential_type=CredentialType.OAUTH2,
            keys={
                name: CredentialKey(name=name, value=SecretStr(secret))
                for name, secret in secrets.items()
            },
        )
        assert credential.get_key("access_token") == "ghp_xxx"
        assert credential.get_key("refresh_token") == "ghr_xxx"
        assert credential.get_key("nonexistent") is None

    def test_set_key(self):
        """set_key adds a key that is then readable."""
        credential = CredentialObject(id="test", keys={})
        credential.set_key("new_key", "new_value")
        assert credential.get_key("new_key") == "new_value"

    def test_set_key_with_expiration(self):
        """set_key records the supplied expiry on the stored key."""
        credential = CredentialObject(id="test", keys={})
        deadline = datetime.now(UTC) + timedelta(hours=1)
        credential.set_key("token", "xxx", expires_at=deadline)
        assert credential.keys["token"].expires_at == deadline

    def test_needs_refresh(self):
        """An expired key flags the whole credential for refresh."""
        stale = datetime.now(UTC) - timedelta(hours=1)
        credential = CredentialObject(
            id="test",
            keys={"token": CredentialKey(name="token", value=SecretStr("xxx"), expires_at=stale)},
        )
        assert credential.needs_refresh

    def test_get_default_key(self):
        """get_default_key prefers api_key, falling back to access_token."""
        with_api_key = CredentialObject(
            id="test",
            keys={"api_key": CredentialKey(name="api_key", value=SecretStr("key-value"))},
        )
        assert with_api_key.get_default_key() == "key-value"

        with_token = CredentialObject(
            id="test",
            keys={
                "access_token": CredentialKey(name="access_token", value=SecretStr("token-value"))
            },
        )
        assert with_token.get_default_key() == "token-value"

    def test_record_usage(self):
        """record_usage bumps the counter and stamps last_used."""
        credential = CredentialObject(id="test", keys={})
        assert credential.use_count == 0
        assert credential.last_used is None

        credential.record_usage()
        assert credential.use_count == 1
        assert credential.last_used is not None
|
||||
|
||||
|
||||
class TestCredentialUsageSpec:
    """Unit tests covering the CredentialUsageSpec model."""

    def test_create_usage_spec(self):
        """A usage spec records its credential id, keys, and header templates."""
        usage = CredentialUsageSpec(
            credential_id="brave_search",
            required_keys=["api_key"],
            headers={"X-Subscription-Token": "{{api_key}}"},
        )
        assert usage.credential_id == "brave_search"
        assert "api_key" in usage.required_keys
        assert "{{api_key}}" in usage.headers.values()
|
||||
|
||||
|
||||
class TestInMemoryStorage:
    """Unit tests for the InMemoryStorage backend."""

    def test_save_and_load(self):
        """A saved credential round-trips through the backend."""
        backend = InMemoryStorage()
        backend.save(
            CredentialObject(
                id="test",
                keys={"key": CredentialKey(name="key", value=SecretStr("value"))},
            )
        )
        restored = backend.load("test")
        assert restored is not None
        assert restored.id == "test"
        assert restored.get_key("key") == "value"

    def test_load_nonexistent(self):
        """Loading an unknown id yields None."""
        assert InMemoryStorage().load("nonexistent") is None

    def test_delete(self):
        """Delete removes the credential and succeeds exactly once."""
        backend = InMemoryStorage()
        backend.save(CredentialObject(id="test", keys={}))

        assert backend.delete("test")
        assert backend.load("test") is None
        assert not backend.delete("test")

    def test_list_all(self):
        """Every stored id appears in the listing."""
        backend = InMemoryStorage()
        for cred_id in ("a", "b"):
            backend.save(CredentialObject(id=cred_id, keys={}))

        listed = backend.list_all()
        assert "a" in listed
        assert "b" in listed

    def test_exists(self):
        """exists() reflects whether an id was saved."""
        backend = InMemoryStorage()
        backend.save(CredentialObject(id="test", keys={}))

        assert backend.exists("test")
        assert not backend.exists("nonexistent")

    def test_clear(self):
        """clear() leaves the backend empty."""
        backend = InMemoryStorage()
        backend.save(CredentialObject(id="test", keys={}))
        backend.clear()

        assert backend.list_all() == []
|
||||
|
||||
|
||||
class TestEnvVarStorage:
    """Unit tests for the read-only EnvVarStorage backend."""

    def test_load_from_env(self):
        """An explicitly mapped env var surfaces as an api_key."""
        with patch.dict(os.environ, {"TEST_API_KEY": "test-value"}):
            backend = EnvVarStorage(env_mapping={"test": "TEST_API_KEY"})
            loaded = backend.load("test")

            assert loaded is not None
            assert loaded.get_key("api_key") == "test-value"

    def test_load_nonexistent(self):
        """An unset env var loads as None."""
        backend = EnvVarStorage(env_mapping={"test": "NONEXISTENT_VAR"})
        assert backend.load("test") is None

    def test_default_env_var_pattern(self):
        """Without a mapping, the id maps to the {ID}_API_KEY convention."""
        with patch.dict(os.environ, {"MY_SERVICE_API_KEY": "value"}):
            loaded = EnvVarStorage().load("my_service")

            assert loaded is not None
            assert loaded.get_key("api_key") == "value"

    def test_save_raises(self):
        """The backend is read-only: save is unsupported."""
        with pytest.raises(NotImplementedError):
            EnvVarStorage().save(CredentialObject(id="test", keys={}))

    def test_delete_raises(self):
        """The backend is read-only: delete is unsupported."""
        with pytest.raises(NotImplementedError):
            EnvVarStorage().delete("test")
|
||||
|
||||
|
||||
class TestEncryptedFileStorage:
    """Unit tests for EncryptedFileStorage."""

    @pytest.fixture
    def temp_dir(self):
        """Yield a throwaway directory for the encrypted files."""
        with tempfile.TemporaryDirectory() as tmpdir:
            yield Path(tmpdir)

    @pytest.fixture
    def storage(self, temp_dir):
        """An EncryptedFileStorage rooted at the temp directory."""
        return EncryptedFileStorage(temp_dir)

    def test_save_and_load(self, storage):
        """A credential survives an encrypt/decrypt round trip."""
        original = CredentialObject(
            id="test",
            credential_type=CredentialType.API_KEY,
            keys={"api_key": CredentialKey(name="api_key", value=SecretStr("secret-value"))},
        )

        storage.save(original)
        restored = storage.load("test")

        assert restored is not None
        assert restored.id == "test"
        assert restored.get_key("api_key") == "secret-value"

    def test_encryption_key_from_env(self, temp_dir):
        """Instances sharing HIVE_CREDENTIAL_KEY can read each other's data."""
        from cryptography.fernet import Fernet

        env_key = Fernet.generate_key().decode()
        with patch.dict(os.environ, {"HIVE_CREDENTIAL_KEY": env_key}):
            writer = EncryptedFileStorage(temp_dir)
            writer.save(
                CredentialObject(
                    id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}
                )
            )

            # A fresh instance with the same env key must decrypt the file.
            reader = EncryptedFileStorage(temp_dir)
            restored = reader.load("test")
            assert restored is not None
            assert restored.get_key("k") == "v"

    def test_list_all(self, storage):
        """Every saved credential id is listed."""
        for cred_id in ("cred1", "cred2"):
            storage.save(CredentialObject(id=cred_id, keys={}))

        listed = storage.list_all()
        assert "cred1" in listed
        assert "cred2" in listed

    def test_delete(self, storage):
        """Deleting removes the credential from disk."""
        storage.save(CredentialObject(id="test", keys={}))
        assert storage.delete("test")
        assert storage.load("test") is None
|
||||
|
||||
|
||||
class TestCompositeStorage:
    """Unit tests for CompositeStorage read-through / write-primary behavior."""

    @staticmethod
    def _single_key_credential(cred_id: str, secret: str) -> CredentialObject:
        """Build a credential whose single key 'k' holds *secret*."""
        return CredentialObject(
            id=cred_id, keys={"k": CredentialKey(name="k", value=SecretStr(secret))}
        )

    def test_read_from_primary(self):
        """When both tiers hold an id, the primary wins."""
        primary = InMemoryStorage()
        primary.save(self._single_key_credential("test", "primary"))

        fallback = InMemoryStorage()
        fallback.save(self._single_key_credential("test", "fallback"))

        composite = CompositeStorage(primary, [fallback])
        loaded = composite.load("test")

        assert loaded.get_key("k") == "primary"

    def test_fallback_when_not_in_primary(self):
        """A miss in primary falls through to the fallback tier."""
        primary = InMemoryStorage()
        fallback = InMemoryStorage()
        fallback.save(self._single_key_credential("test", "fallback"))

        composite = CompositeStorage(primary, [fallback])
        loaded = composite.load("test")

        assert loaded.get_key("k") == "fallback"

    def test_write_to_primary_only(self):
        """Saves land in the primary tier and never in fallbacks."""
        primary = InMemoryStorage()
        fallback = InMemoryStorage()

        composite = CompositeStorage(primary, [fallback])
        composite.save(CredentialObject(id="test", keys={}))

        assert primary.exists("test")
        assert not fallback.exists("test")
|
||||
|
||||
|
||||
class TestStaticProvider:
    """Unit tests for the no-op StaticProvider."""

    def test_provider_id(self):
        """The provider identifies itself as 'static'."""
        assert StaticProvider().provider_id == "static"

    def test_supported_types(self):
        """API keys and custom credentials are both supported."""
        supported = StaticProvider().supported_types
        assert CredentialType.API_KEY in supported
        assert CredentialType.CUSTOM in supported

    def test_refresh_returns_unchanged(self):
        """Refreshing a static credential is a no-op."""
        credential = CredentialObject(
            id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}
        )
        refreshed = StaticProvider().refresh(credential)
        assert refreshed.get_key("k") == "v"

    def test_validate_with_keys(self):
        """A credential holding at least one key validates."""
        credential = CredentialObject(
            id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}
        )
        assert StaticProvider().validate(credential)

    def test_validate_without_keys(self):
        """A keyless credential fails validation."""
        assert not StaticProvider().validate(CredentialObject(id="test", keys={}))

    def test_should_refresh(self):
        """Static credentials never require refresh."""
        assert not StaticProvider().should_refresh(CredentialObject(id="test", keys={}))
|
||||
|
||||
|
||||
class TestTemplateResolver:
    """Unit tests for TemplateResolver."""

    @pytest.fixture
    def store(self):
        """A store preloaded with two credentials."""
        seed = {
            "brave_search": {"api_key": "test-brave-key"},
            "github_oauth": {"access_token": "ghp_xxx", "refresh_token": "ghr_xxx"},
        }
        return CredentialStore.for_testing(seed)

    @pytest.fixture
    def resolver(self, store):
        """A resolver bound to the seeded store."""
        return TemplateResolver(store)

    def test_resolve_simple(self, resolver):
        """A single {{cred.key}} reference is substituted in place."""
        assert resolver.resolve("Bearer {{github_oauth.access_token}}") == "Bearer ghp_xxx"

    def test_resolve_multiple(self, resolver):
        """Several references in one string all resolve."""
        resolved = resolver.resolve("{{github_oauth.access_token}} and {{brave_search.api_key}}")
        assert "ghp_xxx" in resolved
        assert "test-brave-key" in resolved

    def test_resolve_default_key(self, resolver):
        """A bare {{cred}} reference falls back to the credential's default key."""
        assert "test-brave-key" in resolver.resolve("Key: {{brave_search}}")

    def test_resolve_headers(self, resolver):
        """Templates inside a headers dict are resolved value by value."""
        raw = {
            "Authorization": "Bearer {{github_oauth.access_token}}",
            "X-API-Key": "{{brave_search.api_key}}",
        }
        resolved = resolver.resolve_headers(raw)
        assert resolved["Authorization"] == "Bearer ghp_xxx"
        assert resolved["X-API-Key"] == "test-brave-key"

    def test_resolve_missing_credential(self, resolver):
        """Referencing an unknown credential raises."""
        with pytest.raises(CredentialNotFoundError):
            resolver.resolve("{{nonexistent.key}}")

    def test_resolve_missing_key(self, resolver):
        """Referencing an unknown key on a known credential raises."""
        with pytest.raises(CredentialKeyNotFoundError):
            resolver.resolve("{{github_oauth.nonexistent}}")

    def test_has_templates(self, resolver):
        """has_templates detects {{...}} markers and nothing else."""
        assert resolver.has_templates("{{cred.key}}")
        assert resolver.has_templates("Bearer {{token}}")
        assert not resolver.has_templates("no templates here")

    def test_extract_references(self, resolver):
        """extract_references yields (credential_id, key) tuples."""
        found = resolver.extract_references("{{github.token}} and {{brave.key}}")
        assert ("github", "token") in found
        assert ("brave", "key") in found
|
||||
|
||||
|
||||
class TestCredentialStore:
    """Unit tests for the CredentialStore facade."""

    def test_for_testing_factory(self):
        """for_testing() seeds a store whose values are retrievable."""
        seeded = CredentialStore.for_testing({"test": {"api_key": "value"}})

        assert seeded.get("test") == "value"
        assert seeded.get_key("test", "api_key") == "value"

    def test_get_credential(self):
        """get_credential returns the full credential object."""
        seeded = CredentialStore.for_testing({"test": {"key": "value"}})

        credential = seeded.get_credential("test")
        assert credential is not None
        assert credential.get_key("key") == "value"

    def test_get_nonexistent(self):
        """Unknown ids come back as None from every accessor."""
        empty = CredentialStore.for_testing({})
        assert empty.get_credential("nonexistent") is None
        assert empty.get("nonexistent") is None

    def test_save_and_load(self):
        """A credential saved through the store can be read back."""
        store = CredentialStore.for_testing({})
        store.save_credential(
            CredentialObject(id="new", keys={"k": CredentialKey(name="k", value=SecretStr("v"))})
        )

        reloaded = store.get_credential("new")
        assert reloaded is not None
        assert reloaded.get_key("k") == "v"

    def test_delete_credential(self):
        """Deleting removes the credential from the store."""
        seeded = CredentialStore.for_testing({"test": {"k": "v"}})

        assert seeded.delete_credential("test")
        assert seeded.get_credential("test") is None

    def test_list_credentials(self):
        """Every seeded id shows up in the listing."""
        seeded = CredentialStore.for_testing({"a": {"k": "v"}, "b": {"k": "v"}})

        listed = seeded.list_credentials()
        assert "a" in listed
        assert "b" in listed

    def test_is_available(self):
        """is_available mirrors credential presence."""
        seeded = CredentialStore.for_testing({"test": {"k": "v"}})

        assert seeded.is_available("test")
        assert not seeded.is_available("nonexistent")

    def test_resolve_templates(self):
        """The store resolves template strings directly."""
        seeded = CredentialStore.for_testing({"test": {"api_key": "value"}})
        assert seeded.resolve("Key: {{test.api_key}}") == "Key: value"

    def test_resolve_headers(self):
        """The store resolves header dicts directly."""
        seeded = CredentialStore.for_testing({"test": {"token": "xxx"}})
        resolved = seeded.resolve_headers({"Authorization": "Bearer {{test.token}}"})
        assert resolved["Authorization"] == "Bearer xxx"

    def test_register_provider(self):
        """A registered provider is retrievable by its id."""
        store = CredentialStore.for_testing({})
        static = StaticProvider()

        store.register_provider(static)
        assert store.get_provider("static") is static

    def test_register_usage_spec(self):
        """A registered usage spec is retrievable by credential id."""
        store = CredentialStore.for_testing({})
        usage = CredentialUsageSpec(
            credential_id="test",
            required_keys=["api_key"],
            headers={"X-Key": "{{api_key}}"},
        )

        store.register_usage(usage)
        assert store.get_usage_spec("test") is usage

    def test_validate_for_usage(self):
        """A credential holding every required key validates cleanly."""
        seeded = CredentialStore.for_testing({"test": {"api_key": "value"}})
        seeded.register_usage(CredentialUsageSpec(credential_id="test", required_keys=["api_key"]))

        assert seeded.validate_for_usage("test") == []

    def test_validate_for_usage_missing_key(self):
        """A missing required key is named in the error list."""
        seeded = CredentialStore.for_testing({"test": {"other_key": "value"}})
        seeded.register_usage(CredentialUsageSpec(credential_id="test", required_keys=["api_key"]))

        problems = seeded.validate_for_usage("test")
        assert "api_key" in problems[0]

    def test_caching(self):
        """A loaded credential is served from cache after backend deletion."""
        backend = InMemoryStorage()
        caching_store = CredentialStore(storage=backend, cache_ttl_seconds=60)
        backend.save(
            CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))})
        )

        caching_store.get_credential("test")  # populate the cache
        backend.delete("test")

        # Backend copy is gone, but the cached copy is still served.
        assert caching_store.get_credential("test") is not None

    def test_clear_cache(self):
        """clear_cache forces the next read through to the backend."""
        backend = InMemoryStorage()
        store = CredentialStore(storage=backend)

        backend.save(CredentialObject(id="test", keys={}))
        store.get_credential("test")  # populate the cache

        backend.delete("test")
        store.clear_cache()

        assert store.get_credential("test") is None
|
||||
|
||||
|
||||
class TestOAuth2Module:
    """Unit tests for the OAuth2 helper module."""

    def test_oauth2_token_from_response(self):
        """from_token_response maps every standard field and derives an expiry."""
        from core.framework.credentials.oauth2 import OAuth2Token

        parsed = OAuth2Token.from_token_response(
            {
                "access_token": "xxx",
                "token_type": "Bearer",
                "expires_in": 3600,
                "refresh_token": "yyy",
                "scope": "read write",
            }
        )

        assert parsed.access_token == "xxx"
        assert parsed.token_type == "Bearer"
        assert parsed.refresh_token == "yyy"
        assert parsed.scope == "read write"
        assert parsed.expires_at is not None

    def test_token_is_expired(self):
        """is_expired compares expires_at against the current time."""
        from core.framework.credentials.oauth2 import OAuth2Token

        now = datetime.now(UTC)
        fresh = OAuth2Token(access_token="xxx", expires_at=now + timedelta(hours=1))
        assert not fresh.is_expired

        stale = OAuth2Token(access_token="xxx", expires_at=now - timedelta(hours=1))
        assert stale.is_expired

    def test_token_can_refresh(self):
        """can_refresh requires a refresh token to be present."""
        from core.framework.credentials.oauth2 import OAuth2Token

        assert OAuth2Token(access_token="xxx", refresh_token="yyy").can_refresh
        assert not OAuth2Token(access_token="xxx").can_refresh

    def test_oauth2_config_validation(self):
        """OAuth2Config rejects empty token URLs and incomplete placements."""
        from core.framework.credentials.oauth2 import OAuth2Config, TokenPlacement

        valid = OAuth2Config(
            token_url="https://example.com/token", client_id="id", client_secret="secret"
        )
        assert valid.token_url == "https://example.com/token"

        with pytest.raises(ValueError):
            OAuth2Config(token_url="")

        # HEADER_CUSTOM placement is invalid without a custom_header_name.
        with pytest.raises(ValueError):
            OAuth2Config(
                token_url="https://example.com/token",
                token_placement=TokenPlacement.HEADER_CUSTOM,
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly, outside the pytest CLI.
    pytest.main([__file__, "-v"])
|
||||
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
HashiCorp Vault integration for the credential store.
|
||||
|
||||
This module provides enterprise-grade secret management through
|
||||
HashiCorp Vault integration.
|
||||
|
||||
Quick Start:
|
||||
from core.framework.credentials import CredentialStore
|
||||
from core.framework.credentials.vault import HashiCorpVaultStorage
|
||||
|
||||
# Configure Vault storage
|
||||
storage = HashiCorpVaultStorage(
|
||||
url="https://vault.example.com:8200",
|
||||
# token read from VAULT_TOKEN env var
|
||||
mount_point="secret",
|
||||
path_prefix="hive/agents/prod"
|
||||
)
|
||||
|
||||
# Create credential store with Vault backend
|
||||
store = CredentialStore(storage=storage)
|
||||
|
||||
# Use normally - credentials are stored in Vault
|
||||
credential = store.get_credential("my_api")
|
||||
|
||||
Requirements:
|
||||
pip install hvac
|
||||
|
||||
Authentication:
|
||||
Set the VAULT_TOKEN environment variable or pass the token directly:
|
||||
|
||||
export VAULT_TOKEN="hvs.xxxxxxxxxxxxx"
|
||||
|
||||
For production, consider using Vault auth methods:
|
||||
- Kubernetes auth
|
||||
- AppRole auth
|
||||
- AWS IAM auth
|
||||
|
||||
Vault Configuration:
|
||||
Ensure KV v2 secrets engine is enabled:
|
||||
|
||||
vault secrets enable -path=secret kv-v2
|
||||
|
||||
Grant appropriate policies:
|
||||
|
||||
path "secret/data/hive/credentials/*" {
|
||||
capabilities = ["create", "read", "update", "delete", "list"]
|
||||
}
|
||||
path "secret/metadata/hive/credentials/*" {
|
||||
capabilities = ["list", "delete"]
|
||||
}
|
||||
"""
|
||||
|
||||
from .hashicorp import HashiCorpVaultStorage
|
||||
|
||||
__all__ = ["HashiCorpVaultStorage"]
|
||||
@@ -0,0 +1,394 @@
|
||||
"""
|
||||
HashiCorp Vault storage adapter.
|
||||
|
||||
Provides integration with HashiCorp Vault for enterprise secret management.
|
||||
Requires the 'hvac' package: pip install hvac
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from ..models import CredentialKey, CredentialObject, CredentialType
|
||||
from ..storage import CredentialStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HashiCorpVaultStorage(CredentialStorage):
|
||||
"""
|
||||
HashiCorp Vault storage adapter.
|
||||
|
||||
Features:
|
||||
- KV v2 secrets engine support
|
||||
- Namespace support (Enterprise)
|
||||
- Automatic secret versioning
|
||||
- Audit logging via Vault
|
||||
|
||||
The adapter stores credentials in Vault's KV v2 secrets engine with
|
||||
the following structure:
|
||||
|
||||
{mount_point}/data/{path_prefix}/{credential_id}
|
||||
└── data:
|
||||
├── _type: "oauth2"
|
||||
├── access_token: "xxx"
|
||||
├── refresh_token: "yyy"
|
||||
├── _expires_access_token: "2024-01-26T12:00:00"
|
||||
└── _provider_id: "oauth2"
|
||||
|
||||
Example:
|
||||
storage = HashiCorpVaultStorage(
|
||||
url="https://vault.example.com:8200",
|
||||
token="hvs.xxx", # Or use VAULT_TOKEN env var
|
||||
mount_point="secret",
|
||||
path_prefix="hive/credentials"
|
||||
)
|
||||
|
||||
store = CredentialStore(storage=storage)
|
||||
|
||||
# Credentials are now stored in Vault
|
||||
store.save_credential(credential)
|
||||
credential = store.get_credential("my_api")
|
||||
|
||||
Authentication:
|
||||
The adapter uses token-based authentication. The token can be provided:
|
||||
1. Directly via the 'token' parameter
|
||||
2. Via the VAULT_TOKEN environment variable
|
||||
|
||||
For production, consider using:
|
||||
- Kubernetes auth method
|
||||
- AppRole auth method
|
||||
- AWS IAM auth method
|
||||
|
||||
Requirements:
|
||||
pip install hvac
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    url: str,
    token: str | None = None,
    mount_point: str = "secret",
    path_prefix: str = "hive/credentials",
    namespace: str | None = None,
    verify_ssl: bool = True,
):
    """
    Initialize Vault storage.

    Args:
        url: Vault server URL (e.g., https://vault.example.com:8200)
        token: Vault token. If None, reads from VAULT_TOKEN env var
        mount_point: KV secrets engine mount point (default: "secret")
        path_prefix: Path prefix under which all credentials are stored
        namespace: Vault namespace (Enterprise feature)
        verify_ssl: Whether to verify SSL certificates

    Raises:
        ImportError: If hvac is not installed
        ValueError: If no token is available or authentication fails
    """
    # hvac is an optional dependency; fail with an actionable message.
    try:
        import hvac
    except ImportError as exc:
        raise ImportError(
            "HashiCorp Vault support requires 'hvac'. Install with: pip install hvac"
        ) from exc

    self._url = url
    self._token = token or os.environ.get("VAULT_TOKEN")
    self._mount = mount_point
    self._prefix = path_prefix
    self._namespace = namespace

    if not self._token:
        raise ValueError(
            "Vault token required. Set VAULT_TOKEN env var or pass token parameter."
        )

    self._client = hvac.Client(
        url=url, token=self._token, namespace=namespace, verify=verify_ssl
    )

    # Fail fast on a bad token rather than at the first save/load.
    if not self._client.is_authenticated():
        raise ValueError("Vault authentication failed. Check token and server URL.")

    logger.info(f"Connected to HashiCorp Vault at {url}")
|
||||
|
||||
def _path(self, credential_id: str) -> str:
|
||||
"""Build Vault path for credential."""
|
||||
# Sanitize credential_id
|
||||
safe_id = credential_id.replace("/", "_").replace("\\", "_")
|
||||
return f"{self._prefix}/{safe_id}"
|
||||
|
||||
def save(self, credential: CredentialObject) -> None:
    """Save credential to Vault KV v2."""
    vault_path = self._path(credential.id)
    payload = self._serialize_for_vault(credential)

    try:
        self._client.secrets.kv.v2.create_or_update_secret(
            path=vault_path,
            secret=payload,
            mount_point=self._mount,
        )
    except Exception as e:
        logger.error(f"Failed to save credential '{credential.id}' to Vault: {e}")
        raise
    else:
        logger.debug(f"Saved credential '{credential.id}' to Vault at {vault_path}")
|
||||
|
||||
def load(self, credential_id: str) -> CredentialObject | None:
    """Load credential from Vault."""
    vault_path = self._path(credential_id)

    try:
        response = self._client.secrets.kv.v2.read_secret_version(
            path=vault_path,
            mount_point=self._mount,
        )
        secret_data = response["data"]["data"]
        return self._deserialize_from_vault(credential_id, secret_data)
    except Exception as e:
        # hvac reports a missing secret via an exception; treat it as a miss.
        lowered = str(e).lower()
        if "not found" in lowered or "404" in lowered:
            logger.debug(f"Credential '{credential_id}' not found in Vault")
            return None
        logger.error(f"Failed to load credential '{credential_id}' from Vault: {e}")
        raise
|
||||
|
||||
def delete(self, credential_id: str) -> bool:
    """Delete credential from Vault (all versions)."""
    vault_path = self._path(credential_id)

    try:
        self._client.secrets.kv.v2.delete_metadata_and_all_versions(
            path=vault_path,
            mount_point=self._mount,
        )
    except Exception as e:
        # A missing secret is reported as an exception, not a status.
        lowered = str(e).lower()
        if "not found" in lowered or "404" in lowered:
            return False
        logger.error(f"Failed to delete credential '{credential_id}' from Vault: {e}")
        raise
    else:
        logger.debug(f"Deleted credential '{credential_id}' from Vault")
        return True
|
||||
|
||||
def list_all(self) -> list[str]:
    """List all credentials under the prefix."""
    try:
        response = self._client.secrets.kv.v2.list_secrets(
            path=self._prefix,
            mount_point=self._mount,
        )
        entries = response.get("data", {}).get("keys", [])
        # Folder entries come back with a trailing slash; strip it.
        return [entry.rstrip("/") for entry in entries]
    except Exception as e:
        # A prefix that was never written to simply holds nothing.
        lowered = str(e).lower()
        if "not found" in lowered or "404" in lowered:
            return []
        logger.error(f"Failed to list credentials from Vault: {e}")
        raise
|
||||
|
||||
def exists(self, credential_id: str) -> bool:
    """Check if credential exists in Vault."""
    try:
        self._client.secrets.kv.v2.read_secret_version(
            path=self._path(credential_id),
            mount_point=self._mount,
        )
    except Exception:
        # Any failure (missing secret, auth, network) counts as "not there".
        return False
    return True
|
||||
|
||||
def _serialize_for_vault(self, credential: CredentialObject) -> dict[str, Any]:
|
||||
"""Convert credential to Vault secret format."""
|
||||
data: dict[str, Any] = {
|
||||
"_type": credential.credential_type.value,
|
||||
}
|
||||
|
||||
if credential.provider_id:
|
||||
data["_provider_id"] = credential.provider_id
|
||||
|
||||
if credential.description:
|
||||
data["_description"] = credential.description
|
||||
|
||||
if credential.auto_refresh:
|
||||
data["_auto_refresh"] = "true"
|
||||
|
||||
# Store each key
|
||||
for key_name, key in credential.keys.items():
|
||||
data[key_name] = key.get_secret_value()
|
||||
|
||||
if key.expires_at:
|
||||
data[f"_expires_{key_name}"] = key.expires_at.isoformat()
|
||||
|
||||
if key.metadata:
|
||||
data[f"_metadata_{key_name}"] = str(key.metadata)
|
||||
|
||||
return data
|
||||
|
||||
def _deserialize_from_vault(self, credential_id: str, data: dict[str, Any]) -> CredentialObject:
    """Reconstruct credential from Vault secret.

    Inverse of ``_serialize_for_vault``: underscore-prefixed fields are
    popped as metadata; every remaining field becomes a CredentialKey,
    with its optional "_expires_<name>" / "_metadata_<name>" companions.
    """
    # Pop credential-level metadata so only key entries (and their
    # side-channel fields) remain in `data`.
    cred_type = CredentialType(data.pop("_type", "api_key"))
    provider_id = data.pop("_provider_id", None)
    description = data.pop("_description", "")
    auto_refresh = data.pop("_auto_refresh", "") == "true"

    keys: dict[str, CredentialKey] = {}
    for key_name in [name for name in data if not name.startswith("_")]:
        # Optional expiry timestamp; malformed values are silently ignored.
        expires_at = None
        expiry_field = f"_expires_{key_name}"
        if expiry_field in data:
            try:
                expires_at = datetime.fromisoformat(data[expiry_field])
            except (ValueError, TypeError):
                pass

        # Optional metadata dict stored via str(); parse it back with
        # literal_eval, ignoring anything that is not a Python literal.
        metadata: dict[str, Any] = {}
        metadata_field = f"_metadata_{key_name}"
        if metadata_field in data:
            import ast

            try:
                metadata = ast.literal_eval(data[metadata_field])
            except (ValueError, SyntaxError):
                pass

        keys[key_name] = CredentialKey(
            name=key_name,
            value=SecretStr(data[key_name]),
            expires_at=expires_at,
            metadata=metadata,
        )

    return CredentialObject(
        id=credential_id,
        credential_type=cred_type,
        keys=keys,
        provider_id=provider_id,
        description=description,
        auto_refresh=auto_refresh,
    )
|
||||
|
||||
# --- Vault-Specific Operations ---
|
||||
|
||||
def get_secret_metadata(self, credential_id: str) -> dict[str, Any] | None:
    """
    Get Vault metadata for a secret (version info, timestamps, etc.).

    Args:
        credential_id: The credential identifier

    Returns:
        Metadata dict or None if not found
    """
    secret_path = self._path(credential_id)
    try:
        result = self._client.secrets.kv.v2.read_secret_metadata(
            path=secret_path,
            mount_point=self._mount,
        )
        return result.get("data", {})
    except Exception:
        # Any read failure is reported as "no metadata available".
        return None
|
||||
|
||||
def soft_delete(self, credential_id: str, versions: list[int] | None = None) -> bool:
|
||||
"""
|
||||
Soft delete specific versions (can be recovered).
|
||||
|
||||
Args:
|
||||
credential_id: The credential identifier
|
||||
versions: Version numbers to delete. If None, deletes latest.
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
path = self._path(credential_id)
|
||||
|
||||
try:
|
||||
if versions:
|
||||
self._client.secrets.kv.v2.delete_secret_versions(
|
||||
path=path,
|
||||
versions=versions,
|
||||
mount_point=self._mount,
|
||||
)
|
||||
else:
|
||||
self._client.secrets.kv.v2.delete_latest_version_of_secret(
|
||||
path=path,
|
||||
mount_point=self._mount,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Soft delete failed for '{credential_id}': {e}")
|
||||
return False
|
||||
|
||||
def undelete(self, credential_id: str, versions: list[int]) -> bool:
    """
    Recover soft-deleted versions.

    Args:
        credential_id: The credential identifier
        versions: Version numbers to recover

    Returns:
        True if successful
    """
    secret_path = self._path(credential_id)
    try:
        self._client.secrets.kv.v2.undelete_secret_versions(
            path=secret_path,
            versions=versions,
            mount_point=self._mount,
        )
    except Exception as e:
        logger.error(f"Undelete failed for '{credential_id}': {e}")
        return False
    return True
|
||||
|
||||
def load_version(self, credential_id: str, version: int) -> CredentialObject | None:
    """
    Load a specific version of a credential.

    Args:
        credential_id: The credential identifier
        version: Version number to load

    Returns:
        CredentialObject or None
    """
    secret_path = self._path(credential_id)
    try:
        response = self._client.secrets.kv.v2.read_secret_version(
            path=secret_path,
            version=version,
            mount_point=self._mount,
        )
        # KV v2 nests the secret payload under data.data.
        return self._deserialize_from_vault(credential_id, response["data"]["data"])
    except Exception:
        return None
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Graph structures: Goals, Nodes, Edges, and Flexible Execution."""
|
||||
|
||||
from framework.graph.code_sandbox import CodeSandbox, safe_eval, safe_exec
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec
|
||||
from framework.graph.conversation import ConversationStore, Message, NodeConversation
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.flexible_executor import ExecutorConfig, FlexibleGraphExecutor
|
||||
from framework.graph.goal import Constraint, Goal, GoalStatus, SuccessCriterion
|
||||
@@ -42,6 +43,7 @@ __all__ = [
|
||||
# Edge
|
||||
"EdgeSpec",
|
||||
"EdgeCondition",
|
||||
"GraphSpec",
|
||||
# Executor (fixed graph)
|
||||
"GraphExecutor",
|
||||
# Plan (flexible execution)
|
||||
@@ -71,4 +73,8 @@ __all__ = [
|
||||
"CodeSandbox",
|
||||
"safe_exec",
|
||||
"safe_eval",
|
||||
# Conversation
|
||||
"NodeConversation",
|
||||
"ConversationStore",
|
||||
"Message",
|
||||
]
|
||||
|
||||
@@ -170,13 +170,17 @@ class CodeValidator:
|
||||
# Check for dangerous attribute access
|
||||
if isinstance(node, ast.Attribute):
|
||||
if node.attr.startswith("_"):
|
||||
issues.append(f"Access to private attribute '{node.attr}' at line {node.lineno}")
|
||||
issues.append(
|
||||
f"Access to private attribute '{node.attr}' at line {node.lineno}"
|
||||
)
|
||||
|
||||
# Check for exec/eval calls
|
||||
if isinstance(node, ast.Call):
|
||||
if isinstance(node.func, ast.Name):
|
||||
if node.func.id in ("exec", "eval", "compile", "__import__"):
|
||||
issues.append(f"Blocked function call: {node.func.id} at line {node.lineno}")
|
||||
issues.append(
|
||||
f"Blocked function call: {node.func.id} at line {node.lineno}"
|
||||
)
|
||||
|
||||
return issues
|
||||
|
||||
|
||||
@@ -0,0 +1,426 @@
|
||||
"""NodeConversation: Message history management for graph nodes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@dataclass
class Message:
    """A single message in a conversation.

    Attributes:
        seq: Monotonic sequence number.
        role: One of "user", "assistant", or "tool".
        content: Message text.
        tool_use_id: Internal tool-use identifier (output as ``tool_call_id`` in LLM dicts).
        tool_calls: OpenAI-format tool call list for assistant messages.
        is_error: When True and role is "tool", ``to_llm_dict`` prepends "ERROR: " to content.
    """

    seq: int
    role: Literal["user", "assistant", "tool"]
    content: str
    tool_use_id: str | None = None
    tool_calls: list[dict[str, Any]] | None = None
    is_error: bool = False

    def to_llm_dict(self) -> dict[str, Any]:
        """Convert to OpenAI-format message dict."""
        if self.role == "user":
            return {"role": "user", "content": self.content}

        if self.role == "assistant":
            out: dict[str, Any] = {"role": "assistant", "content": self.content}
            # Only attach tool_calls when present and non-empty.
            if self.tool_calls:
                out["tool_calls"] = self.tool_calls
            return out

        # Anything else is treated as a tool result.
        body = f"ERROR: {self.content}" if self.is_error else self.content
        return {"role": "tool", "tool_call_id": self.tool_use_id, "content": body}

    def to_storage_dict(self) -> dict[str, Any]:
        """Serialize all fields for persistence. Omits None/default-False fields."""
        out: dict[str, Any] = {"seq": self.seq, "role": self.role, "content": self.content}
        for field_name, value in (("tool_use_id", self.tool_use_id), ("tool_calls", self.tool_calls)):
            if value is not None:
                out[field_name] = value
        if self.is_error:
            out["is_error"] = self.is_error
        return out

    @classmethod
    def from_storage_dict(cls, data: dict[str, Any]) -> Message:
        """Deserialize from a storage dict (inverse of ``to_storage_dict``)."""
        return cls(
            seq=data["seq"],
            role=data["role"],
            content=data["content"],
            tool_use_id=data.get("tool_use_id"),
            tool_calls=data.get("tool_calls"),
            is_error=data.get("is_error", False),
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationStore protocol (Phase 2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@runtime_checkable
class ConversationStore(Protocol):
    """Protocol for conversation persistence backends."""

    # Persist one message part, keyed by its sequence number.
    async def write_part(self, seq: int, data: dict[str, Any]) -> None: ...

    # Return all stored message parts (NodeConversation.restore rebuilds
    # messages from this list in the order returned).
    async def read_parts(self) -> list[dict[str, Any]]: ...

    # Persist conversation-level metadata (system prompt, limits, output keys).
    async def write_meta(self, data: dict[str, Any]) -> None: ...

    # Return stored metadata, or None if the conversation was never persisted.
    async def read_meta(self) -> dict[str, Any] | None: ...

    # Persist the sequence cursor; callers write {"next_seq": int}.
    async def write_cursor(self, data: dict[str, Any]) -> None: ...

    # Return the stored cursor, or None if absent.
    async def read_cursor(self) -> dict[str, Any] | None: ...

    # Remove parts with sequence numbers below *seq* (callers pass the first
    # seq they want to keep).
    async def delete_parts_before(self, seq: int) -> None: ...

    # Release underlying resources.
    async def close(self) -> None: ...

    # Delete all persisted data for this conversation.
    async def destroy(self) -> None: ...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# NodeConversation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class NodeConversation:
    """Message history for a graph node with optional write-through persistence.

    When *store* is ``None`` the conversation works purely in-memory.
    When a :class:`ConversationStore` is supplied every mutation is
    persisted via write-through (meta is lazily written on the first
    ``_persist`` call).
    """

    def __init__(
        self,
        system_prompt: str = "",
        max_history_tokens: int = 32000,
        compaction_threshold: float = 0.8,
        output_keys: list[str] | None = None,
        store: ConversationStore | None = None,
    ) -> None:
        """Create a conversation.

        Args:
            system_prompt: Prompt text exposed via :attr:`system_prompt`;
                it is NOT included in :meth:`to_llm_messages`.
            max_history_tokens: Token budget used by :meth:`needs_compaction`.
            compaction_threshold: Fraction of the budget at which
                :meth:`needs_compaction` starts returning True.
            output_keys: Keys whose latest values are preserved across
                :meth:`compact`.
            store: Optional persistence backend; ``None`` means in-memory only.
        """
        self._system_prompt = system_prompt
        self._max_history_tokens = max_history_tokens
        self._compaction_threshold = compaction_threshold
        self._output_keys = output_keys
        self._store = store
        # Message log plus a monotonically increasing sequence counter.
        self._messages: list[Message] = []
        self._next_seq: int = 0
        # Meta is written lazily by _persist on the first mutation.
        self._meta_persisted: bool = False

    # --- Properties --------------------------------------------------------

    @property
    def system_prompt(self) -> str:
        """The system prompt supplied at construction time."""
        return self._system_prompt

    @property
    def messages(self) -> list[Message]:
        """Return a defensive copy of the message list."""
        return list(self._messages)

    @property
    def turn_count(self) -> int:
        """Number of conversational turns (one turn = one user message)."""
        return sum(1 for m in self._messages if m.role == "user")

    @property
    def message_count(self) -> int:
        """Total number of messages (all roles)."""
        return len(self._messages)

    @property
    def next_seq(self) -> int:
        """Sequence number the next appended message will receive."""
        return self._next_seq

    # --- Add messages ------------------------------------------------------

    async def add_user_message(self, content: str) -> Message:
        """Append a user message and write it through to the store."""
        msg = Message(seq=self._next_seq, role="user", content=content)
        self._messages.append(msg)
        self._next_seq += 1
        await self._persist(msg)
        return msg

    async def add_assistant_message(
        self,
        content: str,
        tool_calls: list[dict[str, Any]] | None = None,
    ) -> Message:
        """Append an assistant message (optionally carrying tool calls)."""
        msg = Message(
            seq=self._next_seq,
            role="assistant",
            content=content,
            tool_calls=tool_calls,
        )
        self._messages.append(msg)
        self._next_seq += 1
        await self._persist(msg)
        return msg

    async def add_tool_result(
        self,
        tool_use_id: str,
        content: str,
        is_error: bool = False,
    ) -> Message:
        """Append a tool-result message tied to *tool_use_id*."""
        msg = Message(
            seq=self._next_seq,
            role="tool",
            content=content,
            tool_use_id=tool_use_id,
            is_error=is_error,
        )
        self._messages.append(msg)
        self._next_seq += 1
        await self._persist(msg)
        return msg

    # --- Query -------------------------------------------------------------

    def to_llm_messages(self) -> list[dict[str, Any]]:
        """Return messages as OpenAI-format dicts (system prompt excluded)."""
        return [m.to_llm_dict() for m in self._messages]

    def estimate_tokens(self) -> int:
        """Rough token estimate: total characters / 4."""
        total_chars = sum(len(m.content) for m in self._messages)
        return total_chars // 4

    def needs_compaction(self) -> bool:
        """True once the estimate reaches threshold * max_history_tokens."""
        return self.estimate_tokens() >= self._max_history_tokens * self._compaction_threshold

    # --- Output-key extraction ---------------------------------------------

    def _extract_protected_values(self, messages: list[Message]) -> dict[str, str]:
        """Scan assistant messages for output_key values before compaction.

        Iterates most-recent-first. Once a key is found, it's skipped for
        older messages (latest value wins).
        """
        if not self._output_keys:
            return {}

        found: dict[str, str] = {}
        remaining_keys = set(self._output_keys)

        for msg in reversed(messages):
            # Only assistant messages are scanned; stop doing work once
            # every key has been found.
            if msg.role != "assistant" or not remaining_keys:
                continue

            for key in list(remaining_keys):
                value = self._try_extract_key(msg.content, key)
                if value is not None:
                    found[key] = value
                    remaining_keys.discard(key)

        return found

    def _try_extract_key(self, content: str, key: str) -> str | None:
        """Try 4 strategies to extract a key's value from message content."""
        # NOTE(review): local import — presumably avoids a circular import
        # with framework.graph.node; confirm.
        from framework.graph.node import find_json_object

        # 1. Whole message is JSON
        try:
            parsed = json.loads(content)
            if isinstance(parsed, dict) and key in parsed:
                val = parsed[key]
                # Non-string values are re-serialized so the result is str.
                return json.dumps(val) if not isinstance(val, str) else val
        except (json.JSONDecodeError, TypeError):
            pass

        # 2. Embedded JSON via find_json_object
        json_str = find_json_object(content)
        if json_str:
            try:
                parsed = json.loads(json_str)
                if isinstance(parsed, dict) and key in parsed:
                    val = parsed[key]
                    return json.dumps(val) if not isinstance(val, str) else val
            except (json.JSONDecodeError, TypeError):
                pass

        # 3. Colon format: key: value
        match = re.search(rf"\b{re.escape(key)}\s*:\s*(.+)", content)
        if match:
            return match.group(1).strip()

        # 4. Equals format: key = value
        match = re.search(rf"\b{re.escape(key)}\s*=\s*(.+)", content)
        if match:
            return match.group(1).strip()

        return None

    # --- Lifecycle ---------------------------------------------------------

    async def compact(self, summary: str, keep_recent: int = 2) -> None:
        """Replace old messages with a summary, optionally keeping recent ones.

        Args:
            summary: Caller-provided summary text.
            keep_recent: Number of recent messages to preserve (default 2).
                Clamped to [0, len(messages) - 1].
        """
        if not self._messages:
            return

        # Clamp: must discard at least 1 message
        keep_recent = max(0, min(keep_recent, len(self._messages) - 1))

        if keep_recent > 0:
            old_messages = self._messages[:-keep_recent]
            recent_messages = self._messages[-keep_recent:]
        else:
            old_messages = self._messages
            recent_messages = []

        # Extract protected values from messages being discarded
        if self._output_keys:
            protected = self._extract_protected_values(old_messages)
            if protected:
                # Prepend the preserved values to the summary text so they
                # survive the discard.
                lines = ["PRESERVED VALUES (do not lose these):"]
                for k, v in protected.items():
                    lines.append(f"- {k}: {v}")
                lines.append("")
                lines.append("CONVERSATION SUMMARY:")
                lines.append(summary)
                summary = "\n".join(lines)

        # Determine summary seq: slot it just before the oldest kept message,
        # or allocate a fresh seq when nothing is kept.
        if recent_messages:
            summary_seq = recent_messages[0].seq - 1
        else:
            summary_seq = self._next_seq
            self._next_seq += 1

        summary_msg = Message(seq=summary_seq, role="user", content=summary)

        # Persist: drop discarded parts, then write the summary and cursor.
        if self._store:
            delete_before = recent_messages[0].seq if recent_messages else self._next_seq
            await self._store.delete_parts_before(delete_before)
            await self._store.write_part(summary_msg.seq, summary_msg.to_storage_dict())
            await self._store.write_cursor({"next_seq": self._next_seq})

        self._messages = [summary_msg] + recent_messages

    async def clear(self) -> None:
        """Remove all messages, keep system prompt, preserve ``_next_seq``."""
        if self._store:
            await self._store.delete_parts_before(self._next_seq)
            await self._store.write_cursor({"next_seq": self._next_seq})
        self._messages.clear()

    def export_summary(self) -> str:
        """Structured summary with [STATS], [CONFIG], [RECENT_MESSAGES] sections."""
        # Truncate long system prompts for the preview.
        prompt_preview = (
            self._system_prompt[:80] + "..."
            if len(self._system_prompt) > 80
            else self._system_prompt
        )

        lines = [
            "[STATS]",
            f"turns: {self.turn_count}",
            f"messages: {self.message_count}",
            f"estimated_tokens: {self.estimate_tokens()}",
            "",
            "[CONFIG]",
            f"system_prompt: {prompt_preview!r}",
        ]

        if self._output_keys:
            lines.append(f"output_keys: {', '.join(self._output_keys)}")

        lines.append("")
        lines.append("[RECENT_MESSAGES]")
        # Only the last 5 messages are previewed, truncated to 60 chars.
        for m in self._messages[-5:]:
            preview = m.content[:60] + "..." if len(m.content) > 60 else m.content
            lines.append(f"  [{m.role}] {preview}")

        return "\n".join(lines)

    # --- Persistence internals ---------------------------------------------

    async def _persist(self, message: Message) -> None:
        """Write-through a single message. No-op when store is None."""
        if self._store is None:
            return
        # Meta is written once, before the first part.
        if not self._meta_persisted:
            await self._persist_meta()
        await self._store.write_part(message.seq, message.to_storage_dict())
        await self._store.write_cursor({"next_seq": self._next_seq})

    async def _persist_meta(self) -> None:
        """Lazily write conversation metadata to the store (called once)."""
        if self._store is None:
            return
        await self._store.write_meta(
            {
                "system_prompt": self._system_prompt,
                "max_history_tokens": self._max_history_tokens,
                "compaction_threshold": self._compaction_threshold,
                "output_keys": self._output_keys,
            }
        )
        self._meta_persisted = True

    # --- Restore -----------------------------------------------------------

    @classmethod
    async def restore(cls, store: ConversationStore) -> NodeConversation | None:
        """Reconstruct a NodeConversation from a store.

        Returns ``None`` if the store contains no metadata (i.e. the
        conversation was never persisted).
        """
        meta = await store.read_meta()
        if meta is None:
            return None

        conv = cls(
            system_prompt=meta.get("system_prompt", ""),
            max_history_tokens=meta.get("max_history_tokens", 32000),
            compaction_threshold=meta.get("compaction_threshold", 0.8),
            output_keys=meta.get("output_keys"),
            store=store,
        )
        # Meta already exists in the store; don't rewrite it on next mutation.
        conv._meta_persisted = True

        parts = await store.read_parts()
        conv._messages = [Message.from_storage_dict(p) for p in parts]

        cursor = await store.read_cursor()
        if cursor:
            conv._next_seq = cursor["next_seq"]
        elif conv._messages:
            # No cursor persisted; fall back to one past the last stored seq.
            conv._next_seq = conv._messages[-1].seq + 1

        return conv
|
||||
+180
-10
@@ -13,7 +13,7 @@ Edge Types:
|
||||
- always: Always traverse after source completes
|
||||
- on_success: Traverse only if source succeeds
|
||||
- on_failure: Traverse only if source fails
|
||||
- conditional: Traverse based on expression evaluation
|
||||
- conditional: Traverse based on expression evaluation (SAFE SUBSET ONLY)
|
||||
- llm_decide: Let LLM decide based on goal and context (goal-aware routing)
|
||||
|
||||
The llm_decide condition is particularly powerful for goal-driven agents,
|
||||
@@ -26,6 +26,8 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from framework.graph.safe_eval import safe_eval
|
||||
|
||||
|
||||
class EdgeCondition(str, Enum):
|
||||
"""When an edge should be traversed."""
|
||||
@@ -169,8 +171,8 @@ class EdgeSpec(BaseModel):
|
||||
}
|
||||
|
||||
try:
|
||||
# Safe evaluation (in production, use a proper expression evaluator)
|
||||
return bool(eval(self.condition_expr, {"__builtins__": {}}, context))
|
||||
# Safe evaluation using AST-based whitelist
|
||||
return bool(safe_eval(self.condition_expr, context))
|
||||
except Exception as e:
|
||||
# Log the error for debugging
|
||||
import logging
|
||||
@@ -291,13 +293,52 @@ Respond with ONLY a JSON object:
|
||||
return result
|
||||
|
||||
|
||||
class AsyncEntryPointSpec(BaseModel):
    """
    Specification for an asynchronous entry point.

    Used with AgentRuntime for multi-entry-point agents that handle
    concurrent execution streams (e.g., webhook + API handlers).

    Example:
        AsyncEntryPointSpec(
            id="webhook",
            name="Zendesk Webhook Handler",
            entry_node="process-webhook",
            trigger_type="webhook",
            isolation_level="shared",
        )
    """

    id: str = Field(description="Unique identifier for this entry point")
    name: str = Field(description="Human-readable name")
    entry_node: str = Field(description="Node ID to start execution from")
    trigger_type: str = Field(
        default="manual",
        description="How this entry point is triggered: webhook, api, timer, event, manual",
    )
    trigger_config: dict[str, Any] = Field(
        default_factory=dict,
        description="Trigger-specific configuration (e.g., webhook URL, timer interval)",
    )
    isolation_level: str = Field(
        default="shared", description="State isolation: isolated, shared, or synchronized"
    )
    priority: int = Field(default=0, description="Execution priority (higher = more priority)")
    max_concurrent: int = Field(
        default=10, description="Maximum concurrent executions for this entry point"
    )

    # Allow extra fields so specs carrying newer keys still validate.
    model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class GraphSpec(BaseModel):
|
||||
"""
|
||||
Complete specification of an agent graph.
|
||||
|
||||
Contains all nodes, edges, and metadata needed to execute.
|
||||
|
||||
Example:
|
||||
For single-entry-point agents (traditional pattern):
|
||||
GraphSpec(
|
||||
id="calculator-graph",
|
||||
goal_id="calc-001",
|
||||
@@ -306,6 +347,29 @@ class GraphSpec(BaseModel):
|
||||
nodes=[...],
|
||||
edges=[...],
|
||||
)
|
||||
|
||||
For multi-entry-point agents (concurrent streams):
|
||||
GraphSpec(
|
||||
id="support-agent-graph",
|
||||
goal_id="support-001",
|
||||
entry_node="process-webhook", # Default entry
|
||||
async_entry_points=[
|
||||
AsyncEntryPointSpec(
|
||||
id="webhook",
|
||||
name="Zendesk Webhook",
|
||||
entry_node="process-webhook",
|
||||
trigger_type="webhook",
|
||||
),
|
||||
AsyncEntryPointSpec(
|
||||
id="api",
|
||||
name="API Handler",
|
||||
entry_node="process-request",
|
||||
trigger_type="api",
|
||||
),
|
||||
],
|
||||
nodes=[...],
|
||||
edges=[...],
|
||||
)
|
||||
"""
|
||||
|
||||
id: str
|
||||
@@ -318,8 +382,18 @@ class GraphSpec(BaseModel):
|
||||
default_factory=dict,
|
||||
description="Named entry points for resuming execution. Format: {name: node_id}",
|
||||
)
|
||||
terminal_nodes: list[str] = Field(default_factory=list, description="IDs of nodes that end execution")
|
||||
pause_nodes: list[str] = Field(default_factory=list, description="IDs of nodes that pause execution for HITL input")
|
||||
async_entry_points: list[AsyncEntryPointSpec] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Asynchronous entry points for concurrent execution streams (used with AgentRuntime)"
|
||||
),
|
||||
)
|
||||
terminal_nodes: list[str] = Field(
|
||||
default_factory=list, description="IDs of nodes that end execution"
|
||||
)
|
||||
pause_nodes: list[str] = Field(
|
||||
default_factory=list, description="IDs of nodes that pause execution for HITL input"
|
||||
)
|
||||
|
||||
# Components
|
||||
nodes: list[Any] = Field( # NodeSpec, but avoiding circular import
|
||||
@@ -328,12 +402,19 @@ class GraphSpec(BaseModel):
|
||||
edges: list[EdgeSpec] = Field(default_factory=list, description="All edge specifications")
|
||||
|
||||
# Shared memory keys
|
||||
memory_keys: list[str] = Field(default_factory=list, description="Keys available in shared memory")
|
||||
memory_keys: list[str] = Field(
|
||||
default_factory=list, description="Keys available in shared memory"
|
||||
)
|
||||
|
||||
# Default LLM settings
|
||||
default_model: str = "claude-haiku-4-5-20251001"
|
||||
max_tokens: int = 1024
|
||||
|
||||
# Cleanup LLM for JSON extraction fallback (fast/cheap model preferred)
|
||||
# If not set, uses CEREBRAS_API_KEY -> cerebras/llama-3.3-70b or
|
||||
# ANTHROPIC_API_KEY -> claude-3-5-haiku as fallback
|
||||
cleanup_llm_model: str | None = None
|
||||
|
||||
# Execution limits
|
||||
max_steps: int = Field(default=100, description="Maximum node executions before timeout")
|
||||
max_retries_per_node: int = 3
|
||||
@@ -351,6 +432,17 @@ class GraphSpec(BaseModel):
|
||||
return node
|
||||
return None
|
||||
|
||||
def has_async_entry_points(self) -> bool:
|
||||
"""Check if this graph uses async entry points (multi-stream execution)."""
|
||||
return len(self.async_entry_points) > 0
|
||||
|
||||
def get_async_entry_point(self, entry_point_id: str) -> AsyncEntryPointSpec | None:
|
||||
"""Get an async entry point by ID."""
|
||||
for ep in self.async_entry_points:
|
||||
if ep.id == entry_point_id:
|
||||
return ep
|
||||
return None
|
||||
|
||||
def get_outgoing_edges(self, node_id: str) -> list[EdgeSpec]:
|
||||
"""Get all edges leaving a node, sorted by priority."""
|
||||
edges = [e for e in self.edges if e.source == node_id]
|
||||
@@ -360,6 +452,42 @@ class GraphSpec(BaseModel):
|
||||
"""Get all edges entering a node."""
|
||||
return [e for e in self.edges if e.target == node_id]
|
||||
|
||||
def detect_fan_out_nodes(self) -> dict[str, list[str]]:
|
||||
"""
|
||||
Detect nodes that fan-out to multiple targets.
|
||||
|
||||
A fan-out occurs when a node has multiple outgoing edges with the same
|
||||
condition (typically ON_SUCCESS) that should execute in parallel.
|
||||
|
||||
Returns:
|
||||
Dict mapping source_node_id -> list of parallel target_node_ids
|
||||
"""
|
||||
fan_outs: dict[str, list[str]] = {}
|
||||
for node in self.nodes:
|
||||
outgoing = self.get_outgoing_edges(node.id)
|
||||
# Fan-out: multiple edges with ON_SUCCESS condition
|
||||
success_edges = [e for e in outgoing if e.condition == EdgeCondition.ON_SUCCESS]
|
||||
if len(success_edges) > 1:
|
||||
fan_outs[node.id] = [e.target for e in success_edges]
|
||||
return fan_outs
|
||||
|
||||
def detect_fan_in_nodes(self) -> dict[str, list[str]]:
|
||||
"""
|
||||
Detect nodes that receive from multiple sources (fan-in / convergence).
|
||||
|
||||
A fan-in occurs when a node has multiple incoming edges, meaning
|
||||
it should wait for all predecessor branches to complete.
|
||||
|
||||
Returns:
|
||||
Dict mapping target_node_id -> list of source_node_ids
|
||||
"""
|
||||
fan_ins: dict[str, list[str]] = {}
|
||||
for node in self.nodes:
|
||||
incoming = self.get_incoming_edges(node.id)
|
||||
if len(incoming) > 1:
|
||||
fan_ins[node.id] = [e.source for e in incoming]
|
||||
return fan_ins
|
||||
|
||||
def get_entry_point(self, session_state: dict | None = None) -> str:
|
||||
"""
|
||||
Get the appropriate entry point based on session state.
|
||||
@@ -400,6 +528,37 @@ class GraphSpec(BaseModel):
|
||||
if not self.get_node(self.entry_node):
|
||||
errors.append(f"Entry node '{self.entry_node}' not found")
|
||||
|
||||
# Check async entry points
|
||||
seen_entry_ids = set()
|
||||
for entry_point in self.async_entry_points:
|
||||
# Check for duplicate IDs
|
||||
if entry_point.id in seen_entry_ids:
|
||||
errors.append(f"Duplicate async entry point ID: '{entry_point.id}'")
|
||||
seen_entry_ids.add(entry_point.id)
|
||||
|
||||
# Check entry node exists
|
||||
if not self.get_node(entry_point.entry_node):
|
||||
errors.append(
|
||||
f"Async entry point '{entry_point.id}' references "
|
||||
f"missing node '{entry_point.entry_node}'"
|
||||
)
|
||||
|
||||
# Validate isolation level
|
||||
valid_isolation = {"isolated", "shared", "synchronized"}
|
||||
if entry_point.isolation_level not in valid_isolation:
|
||||
errors.append(
|
||||
f"Async entry point '{entry_point.id}' has invalid isolation_level "
|
||||
f"'{entry_point.isolation_level}'. Valid: {valid_isolation}"
|
||||
)
|
||||
|
||||
# Validate trigger type
|
||||
valid_triggers = {"webhook", "api", "timer", "event", "manual"}
|
||||
if entry_point.trigger_type not in valid_triggers:
|
||||
errors.append(
|
||||
f"Async entry point '{entry_point.id}' has invalid trigger_type "
|
||||
f"'{entry_point.trigger_type}'. Valid: {valid_triggers}"
|
||||
)
|
||||
|
||||
# Check terminal nodes exist
|
||||
for term in self.terminal_nodes:
|
||||
if not self.get_node(term):
|
||||
@@ -421,6 +580,10 @@ class GraphSpec(BaseModel):
|
||||
for entry_point_node in self.entry_points.values():
|
||||
to_visit.append(entry_point_node)
|
||||
|
||||
# Add all async entry points as valid starting points
|
||||
for async_entry in self.async_entry_points:
|
||||
to_visit.append(async_entry.entry_node)
|
||||
|
||||
# Traverse from all entry points
|
||||
while to_visit:
|
||||
current = to_visit.pop()
|
||||
@@ -430,11 +593,18 @@ class GraphSpec(BaseModel):
|
||||
for edge in self.get_outgoing_edges(current):
|
||||
to_visit.append(edge.target)
|
||||
|
||||
# Build set of async entry point nodes for quick lookup
|
||||
async_entry_nodes = {ep.entry_node for ep in self.async_entry_points}
|
||||
|
||||
for node in self.nodes:
|
||||
if node.id not in reachable:
|
||||
# Skip this error if the node is a pause node or an entry point target
|
||||
# (pause/resume architecture makes these reachable via session state)
|
||||
if node.id in self.pause_nodes or node.id in self.entry_points.values():
|
||||
# Skip if node is a pause node, entry point target, or async entry
|
||||
# (pause/resume architecture and async entry points make reachable)
|
||||
if (
|
||||
node.id in self.pause_nodes
|
||||
or node.id in self.entry_points.values()
|
||||
or node.id in async_entry_nodes
|
||||
):
|
||||
continue
|
||||
errors.append(f"Node '{node.id}' is unreachable from entry")
|
||||
|
||||
|
||||
@@ -9,12 +9,13 @@ The executor:
|
||||
5. Returns the final result
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.edge import EdgeSpec, GraphSpec
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.node import (
|
||||
FunctionNode,
|
||||
@@ -26,6 +27,8 @@ from framework.graph.node import (
|
||||
RouterNode,
|
||||
SharedMemory,
|
||||
)
|
||||
from framework.graph.output_cleaner import CleansingConfig, OutputCleaner
|
||||
from framework.graph.validator import OutputValidator
|
||||
from framework.llm.provider import LLMProvider, Tool
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
@@ -44,6 +47,52 @@ class ExecutionResult:
|
||||
paused_at: str | None = None # Node ID where execution paused for HITL
|
||||
session_state: dict[str, Any] = field(default_factory=dict) # State to resume from
|
||||
|
||||
# Execution quality metrics
|
||||
total_retries: int = 0 # Total number of retries across all nodes
|
||||
nodes_with_failures: list[str] = field(default_factory=list) # Failed but recovered
|
||||
retry_details: dict[str, int] = field(default_factory=dict) # {node_id: retry_count}
|
||||
had_partial_failures: bool = False # True if any node failed but eventually succeeded
|
||||
execution_quality: str = "clean" # "clean", "degraded", or "failed"
|
||||
|
||||
@property
|
||||
def is_clean_success(self) -> bool:
|
||||
"""True only if execution succeeded with no retries or failures."""
|
||||
return self.success and self.execution_quality == "clean"
|
||||
|
||||
@property
|
||||
def is_degraded_success(self) -> bool:
|
||||
"""True if execution succeeded but had retries or partial failures."""
|
||||
return self.success and self.execution_quality == "degraded"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParallelBranch:
|
||||
"""Tracks a single branch in parallel fan-out execution."""
|
||||
|
||||
branch_id: str
|
||||
node_id: str
|
||||
edge: EdgeSpec
|
||||
result: "NodeResult | None" = None
|
||||
status: str = "pending" # pending, running, completed, failed
|
||||
retry_count: int = 0
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParallelExecutionConfig:
|
||||
"""Configuration for parallel execution behavior."""
|
||||
|
||||
# Error handling: "fail_all" cancels all on first failure,
|
||||
# "continue_others" lets remaining branches complete,
|
||||
# "wait_all" waits for all and reports all failures
|
||||
on_branch_failure: str = "fail_all"
|
||||
|
||||
# Memory conflict handling when branches write same key
|
||||
memory_conflict_strategy: str = "last_wins" # "last_wins", "first_wins", "error"
|
||||
|
||||
# Timeout per branch in seconds
|
||||
branch_timeout_seconds: float = 300.0
|
||||
|
||||
|
||||
class GraphExecutor:
|
||||
"""
|
||||
@@ -72,6 +121,9 @@ class GraphExecutor:
|
||||
tool_executor: Callable | None = None,
|
||||
node_registry: dict[str, NodeProtocol] | None = None,
|
||||
approval_callback: Callable | None = None,
|
||||
cleansing_config: CleansingConfig | None = None,
|
||||
enable_parallel_execution: bool = True,
|
||||
parallel_config: ParallelExecutionConfig | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the executor.
|
||||
@@ -83,6 +135,9 @@ class GraphExecutor:
|
||||
tool_executor: Function to execute tools
|
||||
node_registry: Custom node implementations by ID
|
||||
approval_callback: Optional callback for human-in-the-loop approval
|
||||
cleansing_config: Optional output cleansing configuration
|
||||
enable_parallel_execution: Enable parallel fan-out execution (default True)
|
||||
parallel_config: Configuration for parallel execution behavior
|
||||
"""
|
||||
self.runtime = runtime
|
||||
self.llm = llm
|
||||
@@ -90,8 +145,43 @@ class GraphExecutor:
|
||||
self.tool_executor = tool_executor
|
||||
self.node_registry = node_registry or {}
|
||||
self.approval_callback = approval_callback
|
||||
self.validator = OutputValidator()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize output cleaner
|
||||
self.cleansing_config = cleansing_config or CleansingConfig()
|
||||
self.output_cleaner = OutputCleaner(
|
||||
config=self.cleansing_config,
|
||||
llm_provider=llm,
|
||||
)
|
||||
|
||||
# Parallel execution settings
|
||||
self.enable_parallel_execution = enable_parallel_execution
|
||||
self._parallel_config = parallel_config or ParallelExecutionConfig()
|
||||
|
||||
def _validate_tools(self, graph: GraphSpec) -> list[str]:
|
||||
"""
|
||||
Validate that all tools declared by nodes are available.
|
||||
|
||||
Returns:
|
||||
List of error messages (empty if all tools are available)
|
||||
"""
|
||||
errors = []
|
||||
available_tool_names = {t.name for t in self.tools}
|
||||
|
||||
for node in graph.nodes:
|
||||
if node.tools:
|
||||
missing = set(node.tools) - available_tool_names
|
||||
if missing:
|
||||
available = sorted(available_tool_names) if available_tool_names else "none"
|
||||
errors.append(
|
||||
f"Node '{node.name}' (id={node.id}) requires tools "
|
||||
f"{sorted(missing)} but they are not registered. "
|
||||
f"Available tools: {available}"
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
graph: GraphSpec,
|
||||
@@ -119,16 +209,37 @@ class GraphExecutor:
|
||||
error=f"Invalid graph: {errors}",
|
||||
)
|
||||
|
||||
# Validate tool availability
|
||||
tool_errors = self._validate_tools(graph)
|
||||
if tool_errors:
|
||||
self.logger.error("❌ Tool validation failed:")
|
||||
for err in tool_errors:
|
||||
self.logger.error(f" • {err}")
|
||||
return ExecutionResult(
|
||||
success=False,
|
||||
error=(
|
||||
f"Missing tools: {'; '.join(tool_errors)}. "
|
||||
"Register tools via ToolRegistry or remove tool declarations from nodes."
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize execution state
|
||||
memory = SharedMemory()
|
||||
|
||||
# Restore session state if provided
|
||||
if session_state and "memory" in session_state:
|
||||
# Restore memory from previous session
|
||||
for key, value in session_state["memory"].items():
|
||||
memory.write(key, value)
|
||||
num_keys = len(session_state["memory"])
|
||||
self.logger.info(f"📥 Restored session state with {num_keys} memory keys")
|
||||
memory_data = session_state["memory"]
|
||||
# [RESTORED] Type safety check
|
||||
if not isinstance(memory_data, dict):
|
||||
self.logger.warning(
|
||||
f"⚠️ Invalid memory data type in session state: "
|
||||
f"{type(memory_data).__name__}, expected dict"
|
||||
)
|
||||
else:
|
||||
# Restore memory from previous session
|
||||
for key, value in memory_data.items():
|
||||
memory.write(key, value)
|
||||
self.logger.info(f"📥 Restored session state with {len(memory_data)} memory keys")
|
||||
|
||||
# Write new input data to memory (each key individually)
|
||||
if input_data:
|
||||
@@ -138,6 +249,7 @@ class GraphExecutor:
|
||||
path: list[str] = []
|
||||
total_tokens = 0
|
||||
total_latency = 0
|
||||
node_retry_counts: dict[str, int] = {} # Track retries per node
|
||||
|
||||
# Determine entry point (may differ if resuming)
|
||||
current_node_id = graph.get_entry_point(session_state)
|
||||
@@ -184,6 +296,7 @@ class GraphExecutor:
|
||||
memory=memory,
|
||||
goal=goal,
|
||||
input_data=input_data or {},
|
||||
max_tokens=graph.max_tokens,
|
||||
)
|
||||
|
||||
# Log actual input data being read
|
||||
@@ -199,7 +312,7 @@ class GraphExecutor:
|
||||
self.logger.info(f" {key}: {value_str}")
|
||||
|
||||
# Get or create node implementation
|
||||
node_impl = self._get_node_implementation(node_spec)
|
||||
node_impl = self._get_node_implementation(node_spec, graph.cleanup_llm_model)
|
||||
|
||||
# Validate inputs
|
||||
validation_errors = node_impl.validate_input(ctx)
|
||||
@@ -215,9 +328,29 @@ class GraphExecutor:
|
||||
result = await node_impl.execute(ctx)
|
||||
|
||||
if result.success:
|
||||
tokens = result.tokens_used
|
||||
latency = result.latency_ms
|
||||
self.logger.info(f" ✓ Success (tokens: {tokens}, latency: {latency}ms)")
|
||||
# Validate output before accepting it
|
||||
if result.output and node_spec.output_keys:
|
||||
validation = self.validator.validate_all(
|
||||
output=result.output,
|
||||
expected_keys=node_spec.output_keys,
|
||||
check_hallucination=True,
|
||||
nullable_keys=node_spec.nullable_output_keys,
|
||||
)
|
||||
if not validation.success:
|
||||
self.logger.error(f" ✗ Output validation failed: {validation.error}")
|
||||
result = NodeResult(
|
||||
success=False,
|
||||
error=f"Output validation failed: {validation.error}",
|
||||
output={},
|
||||
tokens_used=result.tokens_used,
|
||||
latency_ms=result.latency_ms,
|
||||
)
|
||||
|
||||
if result.success:
|
||||
self.logger.info(
|
||||
f" ✓ Success (tokens: {result.tokens_used}, "
|
||||
f"latency: {result.latency_ms}ms)"
|
||||
)
|
||||
|
||||
# Generate and log human-readable summary
|
||||
summary = result.to_summary(node_spec)
|
||||
@@ -239,15 +372,71 @@ class GraphExecutor:
|
||||
|
||||
# Handle failure
|
||||
if not result.success:
|
||||
if ctx.attempt < ctx.max_attempts:
|
||||
# Retry
|
||||
ctx.attempt += 1
|
||||
# Track retries per node
|
||||
node_retry_counts[current_node_id] = (
|
||||
node_retry_counts.get(current_node_id, 0) + 1
|
||||
)
|
||||
|
||||
# [CORRECTED] Use node_spec.max_retries instead of hardcoded 3
|
||||
max_retries = getattr(node_spec, "max_retries", 3)
|
||||
|
||||
if node_retry_counts[current_node_id] < max_retries:
|
||||
# Retry - don't increment steps for retries
|
||||
steps -= 1
|
||||
|
||||
# --- EXPONENTIAL BACKOFF ---
|
||||
retry_count = node_retry_counts[current_node_id]
|
||||
# Backoff formula: 1.0 * (2^(retry - 1)) -> 1s, 2s, 4s...
|
||||
delay = 1.0 * (2 ** (retry_count - 1))
|
||||
self.logger.info(f" Using backoff: Sleeping {delay}s before retry...")
|
||||
await asyncio.sleep(delay)
|
||||
# --------------------------------------
|
||||
|
||||
self.logger.info(
|
||||
f" ↻ Retrying ({node_retry_counts[current_node_id]}/{max_retries})..."
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# Move to failure handling
|
||||
# Max retries exceeded - fail the execution
|
||||
self.logger.error(
|
||||
f" ✗ Max retries ({max_retries}) exceeded for node {current_node_id}"
|
||||
)
|
||||
self.runtime.report_problem(
|
||||
severity="critical",
|
||||
description=f"Node {current_node_id} failed: {result.error}",
|
||||
description=(
|
||||
f"Node {current_node_id} failed after "
|
||||
f"{max_retries} attempts: {result.error}"
|
||||
),
|
||||
)
|
||||
self.runtime.end_run(
|
||||
success=False,
|
||||
output_data=memory.read_all(),
|
||||
narrative=(
|
||||
f"Failed at {node_spec.name} after "
|
||||
f"{max_retries} retries: {result.error}"
|
||||
),
|
||||
)
|
||||
|
||||
# Calculate quality metrics
|
||||
total_retries_count = sum(node_retry_counts.values())
|
||||
nodes_failed = list(node_retry_counts.keys())
|
||||
|
||||
return ExecutionResult(
|
||||
success=False,
|
||||
error=(
|
||||
f"Node '{node_spec.name}' failed after "
|
||||
f"{max_retries} attempts: {result.error}"
|
||||
),
|
||||
output=memory.read_all(),
|
||||
steps_executed=steps,
|
||||
total_tokens=total_tokens,
|
||||
total_latency_ms=total_latency,
|
||||
path=path,
|
||||
total_retries=total_retries_count,
|
||||
nodes_with_failures=nodes_failed,
|
||||
retry_details=dict(node_retry_counts),
|
||||
had_partial_failures=len(nodes_failed) > 0,
|
||||
execution_quality="failed",
|
||||
)
|
||||
|
||||
# Check if we just executed a pause node - if so, save state and return
|
||||
@@ -268,6 +457,11 @@ class GraphExecutor:
|
||||
narrative=f"Paused at {node_spec.name} after {steps} steps",
|
||||
)
|
||||
|
||||
# Calculate quality metrics
|
||||
total_retries_count = sum(node_retry_counts.values())
|
||||
nodes_failed = [nid for nid, count in node_retry_counts.items() if count > 0]
|
||||
exec_quality = "degraded" if total_retries_count > 0 else "clean"
|
||||
|
||||
return ExecutionResult(
|
||||
success=True,
|
||||
output=saved_memory,
|
||||
@@ -277,6 +471,11 @@ class GraphExecutor:
|
||||
path=path,
|
||||
paused_at=node_spec.id,
|
||||
session_state=session_state_out,
|
||||
total_retries=total_retries_count,
|
||||
nodes_with_failures=nodes_failed,
|
||||
retry_details=dict(node_retry_counts),
|
||||
had_partial_failures=len(nodes_failed) > 0,
|
||||
execution_quality=exec_quality,
|
||||
)
|
||||
|
||||
# Check if this is a terminal node - if so, we're done
|
||||
@@ -290,8 +489,8 @@ class GraphExecutor:
|
||||
self.logger.info(f" → Router directing to: {result.next_node}")
|
||||
current_node_id = result.next_node
|
||||
else:
|
||||
# Follow edges
|
||||
next_node = self._follow_edges(
|
||||
# Get all traversable edges for fan-out detection
|
||||
traversable_edges = self._get_all_traversable_edges(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
current_node_id=current_node_id,
|
||||
@@ -299,12 +498,59 @@ class GraphExecutor:
|
||||
result=result,
|
||||
memory=memory,
|
||||
)
|
||||
if next_node is None:
|
||||
|
||||
if not traversable_edges:
|
||||
self.logger.info(" → No more edges, ending execution")
|
||||
break # No valid edge, end execution
|
||||
next_spec = graph.get_node(next_node)
|
||||
self.logger.info(f" → Next: {next_spec.name if next_spec else next_node}")
|
||||
current_node_id = next_node
|
||||
|
||||
# Check for fan-out (multiple traversable edges)
|
||||
if self.enable_parallel_execution and len(traversable_edges) > 1:
|
||||
# Find convergence point (fan-in node)
|
||||
targets = [e.target for e in traversable_edges]
|
||||
fan_in_node = self._find_convergence_node(graph, targets)
|
||||
|
||||
# Execute branches in parallel
|
||||
(
|
||||
_branch_results,
|
||||
branch_tokens,
|
||||
branch_latency,
|
||||
) = await self._execute_parallel_branches(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
edges=traversable_edges,
|
||||
memory=memory,
|
||||
source_result=result,
|
||||
source_node_spec=node_spec,
|
||||
path=path,
|
||||
)
|
||||
|
||||
total_tokens += branch_tokens
|
||||
total_latency += branch_latency
|
||||
|
||||
# Continue from fan-in node
|
||||
if fan_in_node:
|
||||
self.logger.info(f" ⑃ Fan-in: converging at {fan_in_node}")
|
||||
current_node_id = fan_in_node
|
||||
else:
|
||||
# No convergence point - branches are terminal
|
||||
self.logger.info(" → Parallel branches completed (no convergence)")
|
||||
break
|
||||
else:
|
||||
# Sequential: follow single edge (existing logic via _follow_edges)
|
||||
next_node = self._follow_edges(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
current_node_id=current_node_id,
|
||||
current_node_spec=node_spec,
|
||||
result=result,
|
||||
memory=memory,
|
||||
)
|
||||
if next_node is None:
|
||||
self.logger.info(" → No more edges, ending execution")
|
||||
break
|
||||
next_spec = graph.get_node(next_node)
|
||||
self.logger.info(f" → Next: {next_spec.name if next_spec else next_node}")
|
||||
current_node_id = next_node
|
||||
|
||||
# Update input_data for next node
|
||||
input_data = result.output
|
||||
@@ -318,10 +564,24 @@ class GraphExecutor:
|
||||
self.logger.info(f" Total tokens: {total_tokens}")
|
||||
self.logger.info(f" Total latency: {total_latency}ms")
|
||||
|
||||
# Calculate execution quality metrics
|
||||
total_retries_count = sum(node_retry_counts.values())
|
||||
nodes_failed = [nid for nid, count in node_retry_counts.items() if count > 0]
|
||||
exec_quality = "degraded" if total_retries_count > 0 else "clean"
|
||||
|
||||
# Update narrative to reflect execution quality
|
||||
quality_suffix = ""
|
||||
if exec_quality == "degraded":
|
||||
retries = total_retries_count
|
||||
failed = len(nodes_failed)
|
||||
quality_suffix = f" ({retries} retries across {failed} nodes)"
|
||||
|
||||
self.runtime.end_run(
|
||||
success=True,
|
||||
output_data=output,
|
||||
narrative=f"Executed {steps} steps through path: {' -> '.join(path)}",
|
||||
narrative=(
|
||||
f"Executed {steps} steps through path: {' -> '.join(path)}{quality_suffix}"
|
||||
),
|
||||
)
|
||||
|
||||
return ExecutionResult(
|
||||
@@ -331,6 +591,11 @@ class GraphExecutor:
|
||||
total_tokens=total_tokens,
|
||||
total_latency_ms=total_latency,
|
||||
path=path,
|
||||
total_retries=total_retries_count,
|
||||
nodes_with_failures=nodes_failed,
|
||||
retry_details=dict(node_retry_counts),
|
||||
had_partial_failures=len(nodes_failed) > 0,
|
||||
execution_quality=exec_quality,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -342,11 +607,21 @@ class GraphExecutor:
|
||||
success=False,
|
||||
narrative=f"Failed at step {steps}: {e}",
|
||||
)
|
||||
|
||||
# Calculate quality metrics even for exceptions
|
||||
total_retries_count = sum(node_retry_counts.values())
|
||||
nodes_failed = list(node_retry_counts.keys())
|
||||
|
||||
return ExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
steps_executed=steps,
|
||||
path=path,
|
||||
total_retries=total_retries_count,
|
||||
nodes_with_failures=nodes_failed,
|
||||
retry_details=dict(node_retry_counts),
|
||||
had_partial_failures=len(nodes_failed) > 0,
|
||||
execution_quality="failed",
|
||||
)
|
||||
|
||||
def _build_context(
|
||||
@@ -355,6 +630,7 @@ class GraphExecutor:
|
||||
memory: SharedMemory,
|
||||
goal: Goal,
|
||||
input_data: dict[str, Any],
|
||||
max_tokens: int = 4096,
|
||||
) -> NodeContext:
|
||||
"""Build execution context for a node."""
|
||||
# Filter tools to those available to this node
|
||||
@@ -378,30 +654,67 @@ class GraphExecutor:
|
||||
available_tools=available_tools,
|
||||
goal_context=goal.to_prompt_context(),
|
||||
goal=goal, # Pass Goal object for LLM-powered routers
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
def _get_node_implementation(self, node_spec: NodeSpec) -> NodeProtocol:
|
||||
# Valid node types - no ambiguous "llm" type allowed
|
||||
VALID_NODE_TYPES = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
|
||||
|
||||
def _get_node_implementation(
|
||||
self, node_spec: NodeSpec, cleanup_llm_model: str | None = None
|
||||
) -> NodeProtocol:
|
||||
"""Get or create a node implementation."""
|
||||
# Check registry first
|
||||
if node_spec.id in self.node_registry:
|
||||
return self.node_registry[node_spec.id]
|
||||
|
||||
# Validate node type
|
||||
if node_spec.node_type not in self.VALID_NODE_TYPES:
|
||||
raise RuntimeError(
|
||||
f"Invalid node type '{node_spec.node_type}' for node '{node_spec.id}'. "
|
||||
f"Must be one of: {sorted(self.VALID_NODE_TYPES)}. "
|
||||
f"Use 'llm_tool_use' for nodes that call tools, 'llm_generate' for text generation."
|
||||
)
|
||||
|
||||
# Create based on type
|
||||
if node_spec.node_type == "llm_tool_use":
|
||||
return LLMNode(tool_executor=self.tool_executor)
|
||||
if not node_spec.tools:
|
||||
raise RuntimeError(
|
||||
f"Node '{node_spec.id}' is type 'llm_tool_use' but declares no tools. "
|
||||
"Either add tools to the node or change type to 'llm_generate'."
|
||||
)
|
||||
return LLMNode(
|
||||
tool_executor=self.tool_executor,
|
||||
require_tools=True,
|
||||
cleanup_llm_model=cleanup_llm_model,
|
||||
)
|
||||
|
||||
if node_spec.node_type == "llm_generate":
|
||||
return LLMNode()
|
||||
return LLMNode(
|
||||
tool_executor=None,
|
||||
require_tools=False,
|
||||
cleanup_llm_model=cleanup_llm_model,
|
||||
)
|
||||
|
||||
if node_spec.node_type == "router":
|
||||
return RouterNode()
|
||||
|
||||
if node_spec.node_type == "function":
|
||||
# Function nodes need explicit registration
|
||||
raise RuntimeError(f"Function node '{node_spec.id}' not registered. Register with node_registry.")
|
||||
raise RuntimeError(
|
||||
f"Function node '{node_spec.id}' not registered. Register with node_registry."
|
||||
)
|
||||
|
||||
# Default to LLM node
|
||||
return LLMNode(tool_executor=self.tool_executor)
|
||||
if node_spec.node_type == "human_input":
|
||||
# Human input nodes are handled specially by HITL mechanism
|
||||
return LLMNode(
|
||||
tool_executor=None,
|
||||
require_tools=False,
|
||||
cleanup_llm_model=cleanup_llm_model,
|
||||
)
|
||||
|
||||
# Should never reach here due to validation above
|
||||
raise RuntimeError(f"Unhandled node type: {node_spec.node_type}")
|
||||
|
||||
def _follow_edges(
|
||||
self,
|
||||
@@ -427,15 +740,292 @@ class GraphExecutor:
|
||||
source_node_name=current_node_spec.name if current_node_spec else current_node_id,
|
||||
target_node_name=target_node_spec.name if target_node_spec else edge.target,
|
||||
):
|
||||
# Map inputs
|
||||
# Validate and clean output before mapping inputs
|
||||
if self.cleansing_config.enabled and target_node_spec:
|
||||
output_to_validate = result.output
|
||||
|
||||
validation = self.output_cleaner.validate_output(
|
||||
output=output_to_validate,
|
||||
source_node_id=current_node_id,
|
||||
target_node_spec=target_node_spec,
|
||||
)
|
||||
|
||||
if not validation.valid:
|
||||
self.logger.warning(f"⚠ Output validation failed: {validation.errors}")
|
||||
|
||||
# Clean the output
|
||||
cleaned_output = self.output_cleaner.clean_output(
|
||||
output=output_to_validate,
|
||||
source_node_id=current_node_id,
|
||||
target_node_spec=target_node_spec,
|
||||
validation_errors=validation.errors,
|
||||
)
|
||||
|
||||
# Update result with cleaned output
|
||||
result.output = cleaned_output
|
||||
|
||||
# Write cleaned output back to memory (skip validation for LLM output)
|
||||
for key, value in cleaned_output.items():
|
||||
memory.write(key, value, validate=False)
|
||||
|
||||
# Revalidate
|
||||
revalidation = self.output_cleaner.validate_output(
|
||||
output=cleaned_output,
|
||||
source_node_id=current_node_id,
|
||||
target_node_spec=target_node_spec,
|
||||
)
|
||||
|
||||
if revalidation.valid:
|
||||
self.logger.info("✓ Output cleaned and validated successfully")
|
||||
else:
|
||||
self.logger.error(
|
||||
f"✗ Cleaning failed, errors remain: {revalidation.errors}"
|
||||
)
|
||||
# Continue anyway if fallback_to_raw is True
|
||||
|
||||
# Map inputs (skip validation for processed LLM output)
|
||||
mapped = edge.map_inputs(result.output, memory.read_all())
|
||||
for key, value in mapped.items():
|
||||
memory.write(key, value)
|
||||
memory.write(key, value, validate=False)
|
||||
|
||||
return edge.target
|
||||
|
||||
return None
|
||||
|
||||
def _get_all_traversable_edges(
|
||||
self,
|
||||
graph: GraphSpec,
|
||||
goal: Goal,
|
||||
current_node_id: str,
|
||||
current_node_spec: Any,
|
||||
result: NodeResult,
|
||||
memory: SharedMemory,
|
||||
) -> list[EdgeSpec]:
|
||||
"""
|
||||
Get ALL edges that should be traversed (for fan-out detection).
|
||||
|
||||
Unlike _follow_edges which returns the first match, this returns
|
||||
all matching edges to enable parallel execution.
|
||||
"""
|
||||
edges = graph.get_outgoing_edges(current_node_id)
|
||||
traversable = []
|
||||
|
||||
for edge in edges:
|
||||
target_node_spec = graph.get_node(edge.target)
|
||||
if edge.should_traverse(
|
||||
source_success=result.success,
|
||||
source_output=result.output,
|
||||
memory=memory.read_all(),
|
||||
llm=self.llm,
|
||||
goal=goal,
|
||||
source_node_name=current_node_spec.name if current_node_spec else current_node_id,
|
||||
target_node_name=target_node_spec.name if target_node_spec else edge.target,
|
||||
):
|
||||
traversable.append(edge)
|
||||
|
||||
return traversable
|
||||
|
||||
def _find_convergence_node(
|
||||
self,
|
||||
graph: GraphSpec,
|
||||
parallel_targets: list[str],
|
||||
) -> str | None:
|
||||
"""
|
||||
Find the common target node where parallel branches converge (fan-in).
|
||||
|
||||
Args:
|
||||
graph: The graph specification
|
||||
parallel_targets: List of node IDs that are running in parallel
|
||||
|
||||
Returns:
|
||||
Node ID where all branches converge, or None if no convergence
|
||||
"""
|
||||
# Get all nodes that parallel branches lead to
|
||||
next_nodes: dict[str, int] = {} # node_id -> count of branches leading to it
|
||||
|
||||
for target in parallel_targets:
|
||||
outgoing = graph.get_outgoing_edges(target)
|
||||
for edge in outgoing:
|
||||
next_nodes[edge.target] = next_nodes.get(edge.target, 0) + 1
|
||||
|
||||
# Convergence node is where ALL branches lead
|
||||
for node_id, count in next_nodes.items():
|
||||
if count == len(parallel_targets):
|
||||
return node_id
|
||||
|
||||
# Fallback: return most common target if any
|
||||
if next_nodes:
|
||||
return max(next_nodes.keys(), key=lambda k: next_nodes[k])
|
||||
|
||||
return None
|
||||
|
||||
async def _execute_parallel_branches(
|
||||
self,
|
||||
graph: GraphSpec,
|
||||
goal: Goal,
|
||||
edges: list[EdgeSpec],
|
||||
memory: SharedMemory,
|
||||
source_result: NodeResult,
|
||||
source_node_spec: Any,
|
||||
path: list[str],
|
||||
) -> tuple[dict[str, NodeResult], int, int]:
|
||||
"""
|
||||
Execute multiple branches in parallel using asyncio.gather.
|
||||
|
||||
Args:
|
||||
graph: The graph specification
|
||||
goal: The execution goal
|
||||
edges: List of edges to follow in parallel
|
||||
memory: Shared memory instance
|
||||
source_result: Result from the source node
|
||||
source_node_spec: Spec of the source node
|
||||
path: Execution path list to update
|
||||
|
||||
Returns:
|
||||
Tuple of (branch_results dict, total_tokens, total_latency)
|
||||
"""
|
||||
branches: dict[str, ParallelBranch] = {}
|
||||
|
||||
# Create branches for each edge
|
||||
for edge in edges:
|
||||
branch_id = f"{edge.source}_to_{edge.target}"
|
||||
branches[branch_id] = ParallelBranch(
|
||||
branch_id=branch_id,
|
||||
node_id=edge.target,
|
||||
edge=edge,
|
||||
)
|
||||
|
||||
self.logger.info(f" ⑂ Fan-out: executing {len(branches)} branches in parallel")
|
||||
for branch in branches.values():
|
||||
target_spec = graph.get_node(branch.node_id)
|
||||
self.logger.info(f" • {target_spec.name if target_spec else branch.node_id}")
|
||||
|
||||
async def execute_single_branch(
|
||||
branch: ParallelBranch,
|
||||
) -> tuple[ParallelBranch, NodeResult | Exception]:
|
||||
"""Execute a single branch with retry logic."""
|
||||
node_spec = graph.get_node(branch.node_id)
|
||||
if node_spec is None:
|
||||
branch.status = "failed"
|
||||
branch.error = f"Node {branch.node_id} not found in graph"
|
||||
return branch, RuntimeError(branch.error)
|
||||
branch.status = "running"
|
||||
|
||||
try:
|
||||
# Validate and clean output before mapping inputs (same as _follow_edges)
|
||||
if self.cleansing_config.enabled and node_spec:
|
||||
validation = self.output_cleaner.validate_output(
|
||||
output=source_result.output,
|
||||
source_node_id=source_node_spec.id if source_node_spec else "unknown",
|
||||
target_node_spec=node_spec,
|
||||
)
|
||||
|
||||
if not validation.valid:
|
||||
self.logger.warning(
|
||||
f"⚠ Output validation failed for branch "
|
||||
f"{branch.node_id}: {validation.errors}"
|
||||
)
|
||||
cleaned_output = self.output_cleaner.clean_output(
|
||||
output=source_result.output,
|
||||
source_node_id=source_node_spec.id if source_node_spec else "unknown",
|
||||
target_node_spec=node_spec,
|
||||
validation_errors=validation.errors,
|
||||
)
|
||||
# Write cleaned output to memory
|
||||
for key, value in cleaned_output.items():
|
||||
await memory.write_async(key, value)
|
||||
|
||||
# Map inputs via edge
|
||||
mapped = branch.edge.map_inputs(source_result.output, memory.read_all())
|
||||
for key, value in mapped.items():
|
||||
await memory.write_async(key, value)
|
||||
|
||||
# Execute with retries
|
||||
last_result = None
|
||||
for attempt in range(node_spec.max_retries):
|
||||
branch.retry_count = attempt
|
||||
|
||||
# Build context for this branch
|
||||
ctx = self._build_context(node_spec, memory, goal, mapped, graph.max_tokens)
|
||||
node_impl = self._get_node_implementation(node_spec, graph.cleanup_llm_model)
|
||||
|
||||
self.logger.info(
|
||||
f" ▶ Branch {node_spec.name}: executing (attempt {attempt + 1})"
|
||||
)
|
||||
result = await node_impl.execute(ctx)
|
||||
last_result = result
|
||||
|
||||
if result.success:
|
||||
# Write outputs to shared memory using async write
|
||||
for key, value in result.output.items():
|
||||
await memory.write_async(key, value)
|
||||
|
||||
branch.result = result
|
||||
branch.status = "completed"
|
||||
self.logger.info(
|
||||
f" ✓ Branch {node_spec.name}: success "
|
||||
f"(tokens: {result.tokens_used}, latency: {result.latency_ms}ms)"
|
||||
)
|
||||
return branch, result
|
||||
|
||||
self.logger.warning(
|
||||
f" ↻ Branch {node_spec.name}: "
|
||||
f"retry {attempt + 1}/{node_spec.max_retries}"
|
||||
)
|
||||
|
||||
# All retries exhausted
|
||||
branch.status = "failed"
|
||||
branch.error = last_result.error if last_result else "Unknown error"
|
||||
branch.result = last_result
|
||||
self.logger.error(
|
||||
f" ✗ Branch {node_spec.name}: "
|
||||
f"failed after {node_spec.max_retries} attempts"
|
||||
)
|
||||
return branch, last_result
|
||||
|
||||
except Exception as e:
|
||||
branch.status = "failed"
|
||||
branch.error = str(e)
|
||||
self.logger.error(f" ✗ Branch {branch.node_id}: exception - {e}")
|
||||
return branch, e
|
||||
|
||||
# Execute all branches concurrently
|
||||
tasks = [execute_single_branch(b) for b in branches.values()]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
# Process results
|
||||
total_tokens = 0
|
||||
total_latency = 0
|
||||
branch_results: dict[str, NodeResult] = {}
|
||||
failed_branches: list[ParallelBranch] = []
|
||||
|
||||
for branch, result in results:
|
||||
path.append(branch.node_id)
|
||||
|
||||
if isinstance(result, Exception):
|
||||
failed_branches.append(branch)
|
||||
elif result is None or not result.success:
|
||||
failed_branches.append(branch)
|
||||
else:
|
||||
total_tokens += result.tokens_used
|
||||
total_latency += result.latency_ms
|
||||
branch_results[branch.branch_id] = result
|
||||
|
||||
# Handle failures based on config
|
||||
if failed_branches:
|
||||
failed_names = [graph.get_node(b.node_id).name for b in failed_branches]
|
||||
if self._parallel_config.on_branch_failure == "fail_all":
|
||||
raise RuntimeError(f"Parallel execution failed: branches {failed_names} failed")
|
||||
elif self._parallel_config.on_branch_failure == "continue_others":
|
||||
self.logger.warning(
|
||||
f"⚠ Some branches failed ({failed_names}), continuing with successful ones"
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
f" ⑃ Fan-out complete: {len(branch_results)}/{len(branches)} branches succeeded"
|
||||
)
|
||||
return branch_results, total_tokens, total_latency
|
||||
|
||||
def register_node(self, node_id: str, implementation: NodeProtocol) -> None:
|
||||
"""Register a custom node implementation."""
|
||||
self.node_registry[node_id] = implementation
|
||||
|
||||
@@ -167,7 +167,10 @@ class FlexibleGraphExecutor:
|
||||
status=ExecutionStatus.NEEDS_REPLAN,
|
||||
plan=plan,
|
||||
context=context,
|
||||
feedback=("No executable steps available but plan not complete. Check dependencies."),
|
||||
feedback=(
|
||||
"No executable steps available but plan not complete. "
|
||||
"Check dependencies."
|
||||
),
|
||||
steps_executed=steps_executed,
|
||||
total_tokens=total_tokens,
|
||||
total_latency=total_latency,
|
||||
@@ -363,7 +366,10 @@ class FlexibleGraphExecutor:
|
||||
status=ExecutionStatus.NEEDS_REPLAN,
|
||||
plan=plan,
|
||||
context=context,
|
||||
feedback=(f"Step '{step.id}' failed after {step.attempts} attempts: {judgment.feedback}"),
|
||||
feedback=(
|
||||
f"Step '{step.id}' failed after {step.attempts} attempts: "
|
||||
f"{judgment.feedback}"
|
||||
),
|
||||
steps_executed=steps_executed,
|
||||
total_tokens=total_tokens,
|
||||
total_latency=total_latency,
|
||||
@@ -459,10 +465,11 @@ class FlexibleGraphExecutor:
|
||||
args_preview = args_preview[:500] + "..."
|
||||
preview_parts.append(f"Args: {args_preview}")
|
||||
elif step.action.prompt:
|
||||
if len(step.action.prompt) > 300:
|
||||
prompt_preview = step.action.prompt[:300] + "..."
|
||||
else:
|
||||
prompt_preview = step.action.prompt
|
||||
prompt_preview = (
|
||||
step.action.prompt[:300] + "..."
|
||||
if len(step.action.prompt) > 300
|
||||
else step.action.prompt
|
||||
)
|
||||
preview_parts.append(f"Prompt: {prompt_preview}")
|
||||
|
||||
# Include step inputs resolved from context (what will be sent/used)
|
||||
|
||||
@@ -41,7 +41,9 @@ class SuccessCriterion(BaseModel):
|
||||
|
||||
id: str
|
||||
description: str = Field(description="Human-readable description of what success looks like")
|
||||
metric: str = Field(description="How to measure: 'output_contains', 'output_equals', 'llm_judge', 'custom'")
|
||||
metric: str = Field(
|
||||
description="How to measure: 'output_contains', 'output_equals', 'llm_judge', 'custom'"
|
||||
)
|
||||
target: Any = Field(description="The target value or condition")
|
||||
weight: float = Field(default=1.0, ge=0.0, le=1.0, description="Relative importance (0-1)")
|
||||
met: bool = False
|
||||
@@ -60,9 +62,15 @@ class Constraint(BaseModel):
|
||||
|
||||
id: str
|
||||
description: str
|
||||
constraint_type: str = Field(description="Type: 'hard' (must not violate) or 'soft' (prefer not to violate)")
|
||||
category: str = Field(default="general", description="Category: 'time', 'cost', 'safety', 'scope', 'quality'")
|
||||
check: str = Field(default="", description="How to check: expression, function name, or 'llm_judge'")
|
||||
constraint_type: str = Field(
|
||||
description="Type: 'hard' (must not violate) or 'soft' (prefer not to violate)"
|
||||
)
|
||||
category: str = Field(
|
||||
default="general", description="Category: 'time', 'cost', 'safety', 'scope', 'quality'"
|
||||
)
|
||||
check: str = Field(
|
||||
default="", description="How to check: expression, function name, or 'llm_judge'"
|
||||
)
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
@@ -129,7 +137,9 @@ class Goal(BaseModel):
|
||||
|
||||
# Input/output schema
|
||||
input_schema: dict[str, Any] = Field(default_factory=dict, description="Expected input format")
|
||||
output_schema: dict[str, Any] = Field(default_factory=dict, description="Expected output format")
|
||||
output_schema: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Expected output format"
|
||||
)
|
||||
|
||||
# Versioning for evolution
|
||||
version: str = "1.0.0"
|
||||
|
||||
@@ -178,7 +178,9 @@ class HITLProtocol:
|
||||
|
||||
import anthropic
|
||||
|
||||
questions_str = "\n".join([f"{i + 1}. {q.question} (id: {q.id})" for i, q in enumerate(request.questions)])
|
||||
questions_str = "\n".join(
|
||||
[f"{i + 1}. {q.question} (id: {q.id})" for i, q in enumerate(request.questions)]
|
||||
)
|
||||
|
||||
prompt = f"""Parse the user's response and extract answers for each question.
|
||||
|
||||
|
||||
@@ -218,7 +218,8 @@ class HybridJudge:
|
||||
return Judgment(
|
||||
action=JudgmentAction.ESCALATE,
|
||||
reasoning=(
|
||||
f"LLM confidence ({judgment.confidence:.2f}) below threshold ({self.llm_confidence_threshold})"
|
||||
f"LLM confidence ({judgment.confidence:.2f}) "
|
||||
f"below threshold ({self.llm_confidence_threshold})"
|
||||
),
|
||||
feedback=judgment.feedback,
|
||||
confidence=judgment.confidence,
|
||||
@@ -357,7 +358,8 @@ def create_default_judge(llm: LLMProvider | None = None) -> HybridJudge:
|
||||
id="transient_error_retry",
|
||||
description="Transient error that can be retried",
|
||||
condition=(
|
||||
"isinstance(result, dict) and result.get('error_type') in ['timeout', 'rate_limit', 'connection_error']"
|
||||
"isinstance(result, dict) and "
|
||||
"result.get('error_type') in ['timeout', 'rate_limit', 'connection_error']"
|
||||
),
|
||||
action=JudgmentAction.RETRY,
|
||||
feedback_template="Transient error: {result[error]}. Please retry.",
|
||||
|
||||
+975
-114
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,392 @@
|
||||
"""
|
||||
Output Cleaner - Framework-level I/O validation and cleaning.
|
||||
|
||||
Validates node outputs match expected schemas and uses fast LLM
|
||||
to clean malformed outputs before they flow to the next node.
|
||||
|
||||
This prevents cascading failures and dramatically improves execution success rates.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _heuristic_repair(text: str) -> dict | None:
|
||||
"""
|
||||
Attempt to repair JSON without an LLM call.
|
||||
|
||||
Handles common errors:
|
||||
- Markdown code blocks
|
||||
- Python booleans/None (True -> true)
|
||||
- Single quotes instead of double quotes
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
return None
|
||||
|
||||
# 1. Strip Markdown code blocks
|
||||
text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.MULTILINE)
|
||||
text = re.sub(r"\s*```$", "", text, flags=re.MULTILINE)
|
||||
text = text.strip()
|
||||
|
||||
# 2. Find outermost JSON-like structure (greedy match)
|
||||
match = re.search(r"(\{.*\}|\[.*\])", text, re.DOTALL)
|
||||
if match:
|
||||
candidate = match.group(1)
|
||||
|
||||
# 3. Common fixes
|
||||
# Fix Python constants
|
||||
candidate = re.sub(r"\bTrue\b", "true", candidate)
|
||||
candidate = re.sub(r"\bFalse\b", "false", candidate)
|
||||
candidate = re.sub(r"\bNone\b", "null", candidate)
|
||||
|
||||
# 4. Attempt load
|
||||
try:
|
||||
return json.loads(candidate)
|
||||
except json.JSONDecodeError:
|
||||
# 5. Advanced: Try swapping single quotes if double quotes fail
|
||||
# This is risky but effective for simple dicts
|
||||
try:
|
||||
if "'" in candidate and '"' not in candidate:
|
||||
candidate_swapped = candidate.replace("'", '"')
|
||||
return json.loads(candidate_swapped)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
class CleansingConfig:
    """Configuration for output cleansing.

    Controls whether cleansing runs at all, which fast model performs
    LLM-based repair, and whether a failed cleaning falls back to the
    raw output instead of raising.
    """

    enabled: bool = True  # Master switch for the whole cleansing pipeline
    fast_model: str = "cerebras/llama-3.3-70b"  # Fast, cheap model for cleaning
    max_retries: int = 2  # Budget for repeated cleaning attempts
    cache_successful_patterns: bool = True  # Remember patterns that cleaned OK
    fallback_to_raw: bool = True  # If cleaning fails, pass raw output
    log_cleanings: bool = True  # Log when cleansing happens
||||
|
||||
|
||||
@dataclass
class ValidationResult:
    """Result of output validation.

    ``valid`` is False only when ``errors`` is non-empty; ``warnings``
    alone never fail validation. ``cleaned_output`` is reserved for a
    repaired payload (OutputCleaner.validate_output does not populate it).
    """

    valid: bool  # True when no hard errors were found
    errors: list[str] = field(default_factory=list)  # Hard schema violations
    warnings: list[str] = field(default_factory=list)  # Suspicious but non-fatal findings
    cleaned_output: dict[str, Any] | None = None  # Optional repaired output
||||
|
||||
|
||||
class OutputCleaner:
    """
    Framework-level output validation and cleaning.

    Uses heuristics and fast LLM to clean malformed outputs
    before they flow to the next node.

    Typical flow: ``validate_output`` detects schema mismatches (including
    the "JSON parsing trap" where a value is a JSON string instead of a
    parsed object), then ``clean_output`` repairs the payload — first with
    cheap heuristics, then with a fast LLM call.
    """

    def __init__(self, config: CleansingConfig, llm_provider=None):
        """
        Initialize the output cleaner.

        Args:
            config: Cleansing configuration
            llm_provider: Optional LLM provider. When omitted and cleansing
                is enabled, a dedicated fast provider is created from the
                CEREBRAS_API_KEY environment variable; otherwise ``self.llm``
                is left as None and LLM-based cleaning is skipped.
        """
        self.config = config
        self.success_cache: dict[str, Any] = {}  # Cache successful patterns
        self.failure_count: dict[str, int] = {}  # Track edge failures
        self.cleansing_count = 0  # Track total cleanings performed

        # Initialize LLM provider for cleaning
        if llm_provider:
            self.llm = llm_provider
        elif config.enabled:
            # Create dedicated fast LLM provider for cleaning.
            # NOTE(review): LiteLLMProvider is a project-local wrapper; the
            # ImportError branch keeps the cleaner usable without it.
            try:
                import os

                from framework.llm.litellm import LiteLLMProvider

                api_key = os.environ.get("CEREBRAS_API_KEY")
                if api_key:
                    self.llm = LiteLLMProvider(
                        api_key=api_key,
                        model=config.fast_model,
                    )
                    logger.info(f"✓ Initialized OutputCleaner with {config.fast_model}")
                else:
                    logger.warning("⚠ CEREBRAS_API_KEY not found, output cleaning will be disabled")
                    self.llm = None
            except ImportError:
                logger.warning("⚠ LiteLLMProvider not available, output cleaning disabled")
                self.llm = None
        else:
            self.llm = None

    def validate_output(
        self,
        output: dict[str, Any],
        source_node_id: str,
        target_node_spec: Any,  # NodeSpec
    ) -> ValidationResult:
        """
        Validate output matches target node's expected input schema.

        Three checks are applied per required input key:
        1. the key is present;
        2. a string value that parses as JSON is flagged (nested key ->
           error, otherwise a warning for large strings);
        3. the value's type matches ``input_schema`` when one is provided.

        Returns:
            ValidationResult with errors and optionally cleaned output
        """
        errors = []
        warnings = []

        # Check 1: Required input keys present
        for key in target_node_spec.input_keys:
            if key not in output:
                errors.append(f"Missing required key: '{key}'")
                continue

            value = output[key]

            # Check 2: Detect if value is JSON string (the JSON parsing trap!)
            if isinstance(value, str):
                # Try parsing as JSON to detect the trap
                try:
                    parsed = json.loads(value)
                    if isinstance(parsed, dict):
                        if key in parsed:
                            # Key exists in parsed JSON - classic parsing failure!
                            errors.append(
                                f"Key '{key}' contains JSON string with nested '{key}' field - "
                                f"likely parsing failure from LLM node"
                            )
                        elif len(value) > 100:
                            # Large JSON string, but doesn't contain the key
                            warnings.append(
                                f"Key '{key}' contains JSON string ({len(value)} chars)"
                            )
                except json.JSONDecodeError:
                    # Not JSON, check if suspiciously large
                    if len(value) > 500:
                        warnings.append(
                            f"Key '{key}' contains large string ({len(value)} chars), "
                            f"possibly entire LLM response"
                        )

            # Check 3: Type validation (if schema provided)
            if hasattr(target_node_spec, "input_schema") and target_node_spec.input_schema:
                expected_schema = target_node_spec.input_schema.get(key)
                if expected_schema:
                    expected_type = expected_schema.get("type")
                    if expected_type and not self._type_matches(value, expected_type):
                        actual_type = type(value).__name__
                        errors.append(
                            f"Key '{key}': expected type '{expected_type}', got '{actual_type}'"
                        )

        # Warnings don't make validation fail, but errors do
        is_valid = len(errors) == 0

        if not is_valid and self.config.log_cleanings:
            logger.warning(
                f"⚠ Output validation failed for {source_node_id} → {target_node_spec.id}: "
                f"{len(errors)} error(s), {len(warnings)} warning(s)"
            )

        return ValidationResult(
            valid=is_valid,
            errors=errors,
            warnings=warnings,
        )

    def clean_output(
        self,
        output: dict[str, Any],
        source_node_id: str,
        target_node_spec: Any,  # NodeSpec
        validation_errors: list[str],
    ) -> dict[str, Any]:
        """
        Use heuristics and fast LLM to clean malformed output.

        Phase 1 expands string values that are themselves JSON (no LLM
        call); phase 2 asks the fast LLM to rewrite the payload to the
        target schema. On any failure the raw output is returned when
        ``config.fallback_to_raw`` is set, otherwise the error is raised.

        Args:
            output: Raw output from source node
            source_node_id: ID of source node
            target_node_spec: Target node spec (for schema)
            validation_errors: Errors from validation

        Returns:
            Cleaned output matching target schema
        """
        if not self.config.enabled:
            logger.warning("⚠ Output cleansing disabled in config")
            return output

        # --- PHASE 1: Fast Heuristic Repair (Avoids LLM call) ---
        # Often the output is just a string containing JSON, or has minor syntax errors
        # If output is a dictionary but malformed, we might need to serialize it first
        # to try and fix the underlying string representation if it came from raw text

        # Heuristic: Check if any value is actually a JSON string that should be promoted
        # This handles the "JSON Parsing Trap" where LLM returns {"key": "{\"nested\": ...}"}
        heuristic_fixed = False
        fixed_output = output.copy()

        for key, value in output.items():
            if isinstance(value, str):
                repaired = _heuristic_repair(value)
                if repaired and isinstance(repaired, dict | list):
                    # Check if this repaired structure looks like what we want
                    # e.g. if the key is 'data' and the string contained valid JSON
                    fixed_output[key] = repaired
                    heuristic_fixed = True

        # If we fixed something, re-validate manually to see if it's enough
        if heuristic_fixed:
            logger.info("⚡ Heuristic repair applied (nested JSON expansion)")
            return fixed_output

        # --- PHASE 2: LLM-based Repair ---
        if not self.llm:
            logger.warning("⚠ No LLM provider available for cleansing")
            return output

        # Build schema description for target node
        schema_desc = self._build_schema_description(target_node_spec)

        # Create cleansing prompt
        prompt = f"""Clean this malformed agent output to match the expected schema.

VALIDATION ERRORS:
{chr(10).join(f"- {e}" for e in validation_errors)}

EXPECTED SCHEMA for node '{target_node_spec.id}':
{schema_desc}

RAW OUTPUT from node '{source_node_id}':
{json.dumps(output, indent=2, default=str)}

INSTRUCTIONS:
1. Extract values that match the expected schema keys
2. If a value is a JSON string, parse it and extract the correct field
3. Convert types to match the schema (string, dict, list, number, boolean)
4. Remove extra fields not in the expected schema
5. Ensure all required keys are present

Return ONLY valid JSON matching the expected schema. No explanations, no markdown."""

        try:
            if self.config.log_cleanings:
                logger.info(
                    f"🧹 Cleaning output from '{source_node_id}' using {self.config.fast_model}"
                )

            response = self.llm.complete(
                messages=[{"role": "user", "content": prompt}],
                system=(
                    "You clean malformed agent outputs. Return only valid JSON matching the schema."
                ),
                max_tokens=2048,  # Sufficient for cleaning most outputs
            )

            # Parse cleaned output
            cleaned_text = response.content.strip()

            # Apply heuristic repair to the LLM's output too (just in case)
            cleaned = _heuristic_repair(cleaned_text)

            if not cleaned:
                # Fallback to standard load if heuristic returns None (unlikely for LLM output)
                cleaned = json.loads(cleaned_text)

            if isinstance(cleaned, dict):
                self.cleansing_count += 1
                if self.config.log_cleanings:
                    logger.info(
                        f"✓ Output cleaned successfully (total cleanings: {self.cleansing_count})"
                    )
                return cleaned
            else:
                logger.warning(f"⚠ Cleaned output is not a dict: {type(cleaned)}")
                if self.config.fallback_to_raw:
                    return output
                else:
                    raise ValueError(f"Cleaning produced {type(cleaned)}, expected dict")

        except json.JSONDecodeError as e:
            logger.error(f"✗ Failed to parse cleaned JSON: {e}")
            if self.config.fallback_to_raw:
                logger.info("↩ Falling back to raw output")
                return output
            else:
                raise

        except Exception as e:
            logger.error(f"✗ Output cleaning failed: {e}")
            if self.config.fallback_to_raw:
                logger.info("↩ Falling back to raw output")
                return output
            else:
                raise

    def _build_schema_description(self, node_spec: Any) -> str:
        """Build human-readable schema description from NodeSpec.

        Produces a JSON-like pseudo-schema (one line per input key, with
        optional ``// description`` comments) for use in the LLM prompt.
        """
        lines = ["{"]

        for key in node_spec.input_keys:
            # Get type hint and description if available
            if hasattr(node_spec, "input_schema") and node_spec.input_schema:
                schema = node_spec.input_schema.get(key, {})
                type_hint = schema.get("type", "any")
                description = schema.get("description", "")
                required = schema.get("required", True)

                line = f'  "{key}": {type_hint}'
                if description:
                    line += f"  // {description}"
                if required:
                    line += " (required)"
                lines.append(line + ",")
            else:
                # No schema, just show the key
                lines.append(f'  "{key}": any  // (required)')

        lines.append("}")
        return "\n".join(lines)

    def _type_matches(self, value: Any, expected_type: str) -> bool:
        """Check if value matches expected type.

        Accepts common aliases ("str"/"string", "object"/"dict", ...);
        unknown type names are permissive and always match.
        """
        type_map = {
            "string": str,
            "str": str,
            "int": int,
            "integer": int,
            "float": float,
            "number": (int, float),
            "bool": bool,
            "boolean": bool,
            "dict": dict,
            "object": dict,
            "list": list,
            "array": list,
            "any": object,  # Matches everything
        }

        expected_class = type_map.get(expected_type.lower())
        if expected_class:
            return isinstance(value, expected_class)

        # Unknown type, allow it
        return True

    def get_stats(self) -> dict[str, Any]:
        """Get cleansing statistics.

        Returns counts of LLM cleanings performed, per-edge failures,
        and the size of the success-pattern cache.
        """
        return {
            "total_cleanings": self.cleansing_count,
            "failure_count": dict(self.failure_count),
            "cache_size": len(self.success_cache),
        }
|
||||
@@ -38,6 +38,23 @@ class StepStatus(str, Enum):
|
||||
SKIPPED = "skipped"
|
||||
REJECTED = "rejected" # Human rejected execution
|
||||
|
||||
def is_terminal(self) -> bool:
|
||||
"""Check if this status represents a terminal (finished) state.
|
||||
|
||||
Terminal states are states where the step will not execute further,
|
||||
either because it completed successfully or failed/was skipped.
|
||||
"""
|
||||
return self in (
|
||||
StepStatus.COMPLETED,
|
||||
StepStatus.FAILED,
|
||||
StepStatus.SKIPPED,
|
||||
StepStatus.REJECTED,
|
||||
)
|
||||
|
||||
def is_successful(self) -> bool:
|
||||
"""Check if this status represents successful completion."""
|
||||
return self == StepStatus.COMPLETED
|
||||
|
||||
|
||||
class ApprovalDecision(str, Enum):
|
||||
"""Human decision on a step requiring approval."""
|
||||
@@ -131,14 +148,22 @@ class PlanStep(BaseModel):
|
||||
default_factory=dict,
|
||||
description="Input data for this step (can reference previous step outputs)",
|
||||
)
|
||||
expected_outputs: list[str] = Field(default_factory=list, description="Keys this step should produce")
|
||||
expected_outputs: list[str] = Field(
|
||||
default_factory=list, description="Keys this step should produce"
|
||||
)
|
||||
|
||||
# Dependencies
|
||||
dependencies: list[str] = Field(default_factory=list, description="IDs of steps that must complete before this one")
|
||||
dependencies: list[str] = Field(
|
||||
default_factory=list, description="IDs of steps that must complete before this one"
|
||||
)
|
||||
|
||||
# Human-in-the-loop (HITL)
|
||||
requires_approval: bool = Field(default=False, description="If True, requires human approval before execution")
|
||||
approval_message: str | None = Field(default=None, description="Message to show human when requesting approval")
|
||||
requires_approval: bool = Field(
|
||||
default=False, description="If True, requires human approval before execution"
|
||||
)
|
||||
approval_message: str | None = Field(
|
||||
default=None, description="Message to show human when requesting approval"
|
||||
)
|
||||
|
||||
# Execution state
|
||||
status: StepStatus = StepStatus.PENDING
|
||||
@@ -153,11 +178,23 @@ class PlanStep(BaseModel):
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
def is_ready(self, completed_step_ids: set[str]) -> bool:
|
||||
"""Check if this step is ready to execute (all dependencies met)."""
|
||||
def is_ready(self, terminal_step_ids: set[str]) -> bool:
|
||||
"""Check if this step is ready to execute (all dependencies finished).
|
||||
|
||||
A step is ready when:
|
||||
1. Its status is PENDING (not yet started)
|
||||
2. All its dependencies are in a terminal state (completed, failed, skipped, or rejected)
|
||||
|
||||
Note: This allows dependent steps to become "ready" even if their dependencies
|
||||
failed. The executor should check if any dependencies failed and handle
|
||||
accordingly (e.g., skip the step or mark it as blocked).
|
||||
|
||||
Args:
|
||||
terminal_step_ids: Set of step IDs that are in a terminal state
|
||||
"""
|
||||
if self.status != StepStatus.PENDING:
|
||||
return False
|
||||
return all(dep in completed_step_ids for dep in self.dependencies)
|
||||
return all(dep in terminal_step_ids for dep in self.dependencies)
|
||||
|
||||
|
||||
class Judgment(BaseModel):
|
||||
@@ -319,18 +356,46 @@ class Plan(BaseModel):
|
||||
return None
|
||||
|
||||
def get_ready_steps(self) -> list[PlanStep]:
|
||||
"""Get all steps that are ready to execute."""
|
||||
completed_ids = {s.id for s in self.steps if s.status == StepStatus.COMPLETED}
|
||||
return [s for s in self.steps if s.is_ready(completed_ids)]
|
||||
"""Get all steps that are ready to execute.
|
||||
|
||||
A step is ready when all its dependencies are in terminal states
|
||||
(completed, failed, skipped, or rejected).
|
||||
"""
|
||||
terminal_ids = {s.id for s in self.steps if s.status.is_terminal()}
|
||||
return [s for s in self.steps if s.is_ready(terminal_ids)]
|
||||
|
||||
def get_completed_steps(self) -> list[PlanStep]:
|
||||
"""Get all completed steps."""
|
||||
return [s for s in self.steps if s.status == StepStatus.COMPLETED]
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Check if all steps are completed."""
|
||||
"""Check if all steps are in terminal states (finished executing).
|
||||
|
||||
Returns True when all steps have reached a terminal state, regardless
|
||||
of whether they succeeded or failed. Use has_failed_steps() to check
|
||||
if any steps failed.
|
||||
"""
|
||||
return all(s.status.is_terminal() for s in self.steps)
|
||||
|
||||
def is_successful(self) -> bool:
|
||||
"""Check if all steps completed successfully."""
|
||||
return all(s.status == StepStatus.COMPLETED for s in self.steps)
|
||||
|
||||
def has_failed_steps(self) -> bool:
|
||||
"""Check if any steps failed, were skipped, or were rejected."""
|
||||
return any(
|
||||
s.status in (StepStatus.FAILED, StepStatus.SKIPPED, StepStatus.REJECTED)
|
||||
for s in self.steps
|
||||
)
|
||||
|
||||
def get_failed_steps(self) -> list[PlanStep]:
|
||||
"""Get all steps that failed, were skipped, or were rejected."""
|
||||
return [
|
||||
s
|
||||
for s in self.steps
|
||||
if s.status in (StepStatus.FAILED, StepStatus.SKIPPED, StepStatus.REJECTED)
|
||||
]
|
||||
|
||||
def to_feedback_context(self) -> dict[str, Any]:
|
||||
"""Create context for replanning."""
|
||||
return {
|
||||
|
||||
@@ -0,0 +1,252 @@
|
||||
import ast
|
||||
import operator
|
||||
from typing import Any
|
||||
|
||||
# Safe operators whitelist
|
||||
SAFE_OPERATORS = {
|
||||
ast.Add: operator.add,
|
||||
ast.Sub: operator.sub,
|
||||
ast.Mult: operator.mul,
|
||||
ast.Div: operator.truediv,
|
||||
ast.FloorDiv: operator.floordiv,
|
||||
ast.Mod: operator.mod,
|
||||
ast.Pow: operator.pow,
|
||||
ast.LShift: operator.lshift,
|
||||
ast.RShift: operator.rshift,
|
||||
ast.BitOr: operator.or_,
|
||||
ast.BitXor: operator.xor,
|
||||
ast.BitAnd: operator.and_,
|
||||
ast.Eq: operator.eq,
|
||||
ast.NotEq: operator.ne,
|
||||
ast.Lt: operator.lt,
|
||||
ast.LtE: operator.le,
|
||||
ast.Gt: operator.gt,
|
||||
ast.GtE: operator.ge,
|
||||
ast.Is: operator.is_,
|
||||
ast.IsNot: operator.is_not,
|
||||
ast.In: lambda x, y: x in y,
|
||||
ast.NotIn: lambda x, y: x not in y,
|
||||
ast.USub: operator.neg,
|
||||
ast.UAdd: operator.pos,
|
||||
ast.Not: operator.not_,
|
||||
ast.Invert: operator.inv,
|
||||
}
|
||||
|
||||
# Safe functions whitelist
|
||||
SAFE_FUNCTIONS = {
|
||||
"len": len,
|
||||
"int": int,
|
||||
"float": float,
|
||||
"str": str,
|
||||
"bool": bool,
|
||||
"list": list,
|
||||
"dict": dict,
|
||||
"tuple": tuple,
|
||||
"set": set,
|
||||
"min": min,
|
||||
"max": max,
|
||||
"sum": sum,
|
||||
"abs": abs,
|
||||
"round": round,
|
||||
"all": all,
|
||||
"any": any,
|
||||
}
|
||||
|
||||
|
||||
class SafeEvalVisitor(ast.NodeVisitor):
|
||||
def __init__(self, context: dict[str, Any]):
|
||||
self.context = context
|
||||
|
||||
def visit(self, node: ast.AST) -> Any:
|
||||
# Override visit to prevent default behavior and ensure only explicitly allowed nodes work
|
||||
method = "visit_" + node.__class__.__name__
|
||||
visitor = getattr(self, method, self.generic_visit)
|
||||
return visitor(node)
|
||||
|
||||
def generic_visit(self, node: ast.AST):
|
||||
raise ValueError(f"Use of {node.__class__.__name__} is not allowed")
|
||||
|
||||
def visit_Expression(self, node: ast.Expression) -> Any:
|
||||
return self.visit(node.body)
|
||||
|
||||
def visit_Expr(self, node: ast.Expr) -> Any:
|
||||
return self.visit(node.value)
|
||||
|
||||
def visit_Constant(self, node: ast.Constant) -> Any:
|
||||
return node.value
|
||||
|
||||
# --- Data Structures ---
|
||||
def visit_List(self, node: ast.List) -> list:
|
||||
return [self.visit(elt) for elt in node.elts]
|
||||
|
||||
def visit_Tuple(self, node: ast.Tuple) -> tuple:
|
||||
return tuple(self.visit(elt) for elt in node.elts)
|
||||
|
||||
def visit_Dict(self, node: ast.Dict) -> dict:
|
||||
return {
|
||||
self.visit(k): self.visit(v)
|
||||
for k, v in zip(node.keys, node.values, strict=False)
|
||||
if k is not None
|
||||
}
|
||||
|
||||
# --- Operations ---
|
||||
def visit_BinOp(self, node: ast.BinOp) -> Any:
|
||||
op_func = SAFE_OPERATORS.get(type(node.op))
|
||||
if op_func is None:
|
||||
raise ValueError(f"Operator {type(node.op).__name__} is not allowed")
|
||||
return op_func(self.visit(node.left), self.visit(node.right))
|
||||
|
||||
def visit_UnaryOp(self, node: ast.UnaryOp) -> Any:
|
||||
op_func = SAFE_OPERATORS.get(type(node.op))
|
||||
if op_func is None:
|
||||
raise ValueError(f"Operator {type(node.op).__name__} is not allowed")
|
||||
return op_func(self.visit(node.operand))
|
||||
|
||||
def visit_Compare(self, node: ast.Compare) -> Any:
|
||||
left = self.visit(node.left)
|
||||
for op, comparator in zip(node.ops, node.comparators, strict=False):
|
||||
op_func = SAFE_OPERATORS.get(type(op))
|
||||
if op_func is None:
|
||||
raise ValueError(f"Operator {type(op).__name__} is not allowed")
|
||||
right = self.visit(comparator)
|
||||
if not op_func(left, right):
|
||||
return False
|
||||
left = right # Chain comparisons
|
||||
return True
|
||||
|
||||
def visit_BoolOp(self, node: ast.BoolOp) -> Any:
|
||||
values = [self.visit(v) for v in node.values]
|
||||
if isinstance(node.op, ast.And):
|
||||
return all(values)
|
||||
elif isinstance(node.op, ast.Or):
|
||||
return any(values)
|
||||
raise ValueError(f"Boolean operator {type(node.op).__name__} is not allowed")
|
||||
|
||||
def visit_IfExp(self, node: ast.IfExp) -> Any:
|
||||
# Ternary: true_val if test else false_val
|
||||
if self.visit(node.test):
|
||||
return self.visit(node.body)
|
||||
else:
|
||||
return self.visit(node.orelse)
|
||||
|
||||
# --- Variables and Attributes ---
|
||||
def visit_Name(self, node: ast.Name) -> Any:
|
||||
if isinstance(node.ctx, ast.Load):
|
||||
if node.id in self.context:
|
||||
return self.context[node.id]
|
||||
raise NameError(f"Name '{node.id}' is not defined")
|
||||
raise ValueError("Only reading variables is allowed")
|
||||
|
||||
def visit_Subscript(self, node: ast.Subscript) -> Any:
|
||||
# value[slice]
|
||||
val = self.visit(node.value)
|
||||
idx = self.visit(node.slice)
|
||||
return val[idx]
|
||||
|
||||
def visit_Attribute(self, node: ast.Attribute) -> Any:
|
||||
# value.attr
|
||||
# STIRCT CHECK: No access to private attributes (starting with _)
|
||||
if node.attr.startswith("_"):
|
||||
raise ValueError(f"Access to private attribute '{node.attr}' is not allowed")
|
||||
|
||||
val = self.visit(node.value)
|
||||
|
||||
# Safe attribute access: only allow if it's in the dict (if val is dict)
|
||||
# or it's a safe property of a basic type?
|
||||
# Actually, for flexibility, people often use dot access for dicts in these expressions.
|
||||
# But standard Python dict doesn't support dot access.
|
||||
# If val is a dict, Attribute access usually fails in Python unless wrapped.
|
||||
# If the user context provides objects, we might want to allow attribute access.
|
||||
# BUT we must be careful not to allow access to dangerous things like __class__ etc.
|
||||
# The check starts_with("_") covers __class__, __init__, etc.
|
||||
|
||||
try:
|
||||
return getattr(val, node.attr)
|
||||
except AttributeError:
|
||||
# Fallback: maybe it's a dict and they want dot access?
|
||||
# (Only if we want to support that sugar, usually not standard python)
|
||||
# Let's stick to standard python behavior + strict private check.
|
||||
pass
|
||||
|
||||
raise AttributeError(f"Object has no attribute '{node.attr}'")
|
||||
|
||||
def visit_Call(self, node: ast.Call) -> Any:
|
||||
# Only allow calling whitelisted functions
|
||||
func = self.visit(node.func)
|
||||
|
||||
# Check if the function object itself is in our whitelist values
|
||||
# This is tricky because `func` is the actual function object,
|
||||
# but we also want to verify it came from a safe place.
|
||||
# Easier: Check if node.func is a Name and that name is in SAFE_FUNCTIONS.
|
||||
|
||||
is_safe = False
|
||||
if isinstance(node.func, ast.Name):
|
||||
if node.func.id in SAFE_FUNCTIONS:
|
||||
is_safe = True
|
||||
|
||||
# Also allow methods on objects if they are safe?
|
||||
# E.g. "somestring".lower() or list.append() (if we allowed mutation, but we don't for now)
|
||||
# For now, restrict to SAFE_FUNCTIONS whitelist for global calls and deny method calls
|
||||
# unless we explicitly add safe methods.
|
||||
# Allowing method calls on strings/lists (split, join, get) is commonly needed.
|
||||
|
||||
if isinstance(node.func, ast.Attribute):
|
||||
# Method call.
|
||||
# Allow basic safe methods?
|
||||
# For security, start strict. Only helper functions.
|
||||
# Re-visiting: User might want 'output.get("key")'.
|
||||
method_name = node.func.attr
|
||||
if method_name in [
|
||||
"get",
|
||||
"keys",
|
||||
"values",
|
||||
"items",
|
||||
"lower",
|
||||
"upper",
|
||||
"strip",
|
||||
"split",
|
||||
]:
|
||||
is_safe = True
|
||||
|
||||
if not is_safe and func not in SAFE_FUNCTIONS.values():
|
||||
raise ValueError("Call to function/method is not allowed")
|
||||
|
||||
args = [self.visit(arg) for arg in node.args]
|
||||
keywords = {kw.arg: self.visit(kw.value) for kw in node.keywords}
|
||||
|
||||
return func(*args, **keywords)
|
||||
|
||||
    def visit_Index(self, node: ast.Index) -> Any:
        # Python < 3.9 wraps subscript indices in ast.Index; newer versions
        # pass the value node directly, so this is only hit on old runtimes.
        # Simply unwrap and evaluate the inner value.
        return self.visit(node.value)
|
||||
|
||||
|
||||
def safe_eval(expr: str, context: dict[str, Any] | None = None) -> Any:
    """
    Safely evaluate a python expression string.

    Args:
        expr: The expression string to evaluate.
        context: Dictionary of variables available in the expression.

    Returns:
        The result of the evaluation.

    Raises:
        ValueError: If unsafe operations or syntax are detected.
        SyntaxError: If the expression is invalid Python.
    """
    # Build the evaluation namespace: a copy of the caller's variables
    # (never mutated in place) with the safe builtins layered on top, so
    # SAFE_FUNCTIONS always win over same-named context entries.
    namespace: dict[str, Any] = dict(context or {})
    namespace.update(SAFE_FUNCTIONS)

    try:
        tree = ast.parse(expr, mode="eval")
    except SyntaxError as e:
        raise SyntaxError(f"Invalid syntax in expression: {e}") from e

    return SafeEvalVisitor(namespace).visit(tree)
|
||||
@@ -0,0 +1,242 @@
|
||||
"""
|
||||
Test OutputCleaner with real Cerebras LLM.
|
||||
|
||||
Demonstrates how OutputCleaner fixes the JSON parsing trap using llama-3.3-70b.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from framework.graph.node import NodeSpec
|
||||
from framework.graph.output_cleaner import CleansingConfig, OutputCleaner
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
|
||||
|
||||
def test_cleaning_with_cerebras():
    """Test that cleaning fixes malformed output using Cerebras llama-3.3-70b.

    Live integration test: requires CEREBRAS_API_KEY in the environment and
    makes real LLM calls (skips itself when the key is absent). Exercises two
    failure modes:
      1. The "JSON parsing trap" — an entire JSON document stuffed into a
         single output key as a string.
      2. Mixed problems — a JSON string where a dict is expected, plus a
         numeric field delivered as a string.
    """
    print("\n" + "=" * 80)
    print("LIVE TEST: Cleaning with Cerebras llama-3.3-70b")
    print("=" * 80)

    # Get API key; skip (not fail) when the live credential is absent.
    api_key = os.environ.get("CEREBRAS_API_KEY")
    if not api_key:
        print("\n⚠ Skipping: CEREBRAS_API_KEY not found in environment")
        return

    # Initialize LLM
    llm = LiteLLMProvider(
        api_key=api_key,
        model="cerebras/llama-3.3-70b",
    )

    # Initialize cleaner with Cerebras as the fast cleaning model;
    # log_cleanings=True so each repair is visible in the console output.
    cleaner = OutputCleaner(
        config=CleansingConfig(
            enabled=True,
            fast_model="cerebras/llama-3.3-70b",
            log_cleanings=True,
        ),
        llm_provider=llm,
    )

    # Scenario 1: JSON parsing trap (entire response in one key)
    print("\n--- Scenario 1: JSON Parsing Trap ---")
    malformed_output = {
        "recommendation": (
            '{\n  "approval_decision": "APPROVED",\n  "risk_score": 3.5,\n '
            '"reason": "Standard terms, low risk"\n}'
        ),
    }

    target_spec = NodeSpec(
        id="generate-recommendation",
        name="Generate Recommendation",
        description="Test",
        input_keys=["recommendation"],
        output_keys=["result"],
        input_schema={
            "recommendation": {
                "type": "dict",
                "required": True,
                "description": "Recommendation with approval_decision and risk_score",
            },
        },
    )

    # Validate: should flag the JSON-string-where-dict-expected mismatch.
    validation = cleaner.validate_output(
        output=malformed_output,
        source_node_id="analyze-contract",
        target_node_spec=target_spec,
    )

    print("\nMalformed output:")
    print(json.dumps(malformed_output, indent=2))
    print(f"\nValidation errors: {validation.errors}")

    # Clean the output via the live LLM.
    print("\n🧹 Cleaning with Cerebras llama-3.3-70b...")
    cleaned = cleaner.clean_output(
        output=malformed_output,
        source_node_id="analyze-contract",
        target_node_spec=target_spec,
        validation_errors=validation.errors,
    )

    print("\n✓ Cleaned output:")
    print(json.dumps(cleaned, indent=2))

    assert isinstance(cleaned, dict), "Should return dict"
    assert "approval_decision" in str(cleaned) or isinstance(cleaned.get("recommendation"), dict), (
        "Should have recommendation structure"
    )

    # Scenario 2: Multiple keys with JSON string
    print("\n\n--- Scenario 2: Multiple Keys, JSON String ---")
    malformed_output2 = {
        "analysis": (
            '{"high_risk_clauses": ["unlimited liability"], '
            '"compliance_issues": [], "category": "high-risk"}'
        ),
        "risk_score": "7.5",  # String instead of number
    }

    target_spec2 = NodeSpec(
        id="next-node",
        name="Next Node",
        description="Test",
        input_keys=["analysis", "risk_score"],
        output_keys=["result"],
        input_schema={
            "analysis": {"type": "dict", "required": True},
            "risk_score": {"type": "number", "required": True},
        },
    )

    validation2 = cleaner.validate_output(
        output=malformed_output2,
        source_node_id="analyze",
        target_node_spec=target_spec2,
    )

    print("\nMalformed output:")
    print(json.dumps(malformed_output2, indent=2))
    print(f"\nValidation errors: {validation2.errors}")

    # Only clean when validation actually failed (cleaned2 is bound inside).
    if not validation2.valid:
        print("\n🧹 Cleaning with Cerebras llama-3.3-70b...")
        cleaned2 = cleaner.clean_output(
            output=malformed_output2,
            source_node_id="analyze",
            target_node_spec=target_spec2,
            validation_errors=validation2.errors,
        )

        print("\n✓ Cleaned output:")
        print(json.dumps(cleaned2, indent=2))

        assert isinstance(cleaned2, dict), "Should return dict"
        assert isinstance(cleaned2.get("analysis"), dict), "analysis should be dict"
        assert isinstance(cleaned2.get("risk_score"), (int, float)), "risk_score should be number"

    # Stats: confirm the cleaner tracked its own work.
    stats = cleaner.get_stats()
    print("\n\nCleaner Statistics:")
    print(f"  Total cleanings: {stats['total_cleanings']}")
    print(f"  Cache size: {stats['cache_size']}")

    print("\n" + "=" * 80)
    print("✓ LIVE TEST PASSED")
    print("=" * 80)
|
||||
|
||||
|
||||
def test_validation_only():
    """Test validation without LLM (no cleaning).

    Runs entirely offline: llm_provider=None, so only the rule-based
    validation paths are exercised (JSON-string detection and missing-key
    detection).
    """
    print("\n" + "=" * 80)
    print("TEST: Validation Only (No LLM)")
    print("=" * 80)

    # No llm_provider: validate_output must work without any LLM calls.
    cleaner = OutputCleaner(
        config=CleansingConfig(enabled=True),
        llm_provider=None,  # No LLM
    )

    # Test 1: JSON parsing trap detection — a JSON document passed as a
    # string under the very key it should have been parsed into.
    malformed = {
        "approval_decision": '{"approval_decision": "APPROVED", "risk_score": 3}',
    }

    target = NodeSpec(
        id="target",
        name="Target",
        description="Test",
        input_keys=["approval_decision"],
        output_keys=["result"],
    )

    result = cleaner.validate_output(
        output=malformed,
        source_node_id="source",
        target_node_spec=target,
    )

    print(f"\nInput: {json.dumps(malformed, indent=2)}")
    print(f"Errors: {result.errors}")
    print(f"Warnings: {result.warnings}")
    # The trap may surface as either a hard error or a warning.
    assert not result.valid or len(result.warnings) > 0, "Should detect JSON string"
    print("✓ Detected JSON parsing trap")

    # Test 2: Missing keys — field2 is required but absent.
    malformed2 = {"field1": "value"}

    target2 = NodeSpec(
        id="target",
        name="Target",
        description="Test",
        input_keys=["field1", "field2"],
        output_keys=["result"],
    )

    result2 = cleaner.validate_output(
        output=malformed2,
        source_node_id="source",
        target_node_spec=target2,
    )

    print(f"\nInput: {json.dumps(malformed2, indent=2)}")
    print(f"Errors: {result2.errors}")
    assert not result2.valid, "Should be invalid"
    assert "field2" in result2.errors[0], "Should mention missing field"
    print("✓ Detected missing keys")

    print("\n✓ Validation tests passed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the offline validation tests first, then the
    # live Cerebras test (which skips itself when CEREBRAS_API_KEY is unset).
    print("\n" + "=" * 80)
    print("OUTPUT CLEANER LIVE TEST SUITE (with Cerebras)")
    print("=" * 80)

    try:
        # Test validation (no LLM needed)
        test_validation_only()

        # Test cleaning with Cerebras
        test_cleaning_with_cerebras()

        print("\n" + "=" * 80)
        print("ALL TESTS PASSED ✓")
        print("=" * 80)
        print("\nOutputCleaner is working with Cerebras llama-3.3-70b!")
        print("- Fast cleaning (~200-500ms per operation)")
        print("- Fixes JSON parsing trap")
        print("- Converts types to match schema")
        print("- Low cost (~$0.001 per cleaning)")

    except Exception as e:
        # Print the failure plus a full traceback, then re-raise so the
        # process exits non-zero for CI.
        print(f"\n✗ TEST FAILED: {e}")
        import traceback

        traceback.print_exc()
        raise
|
||||
@@ -0,0 +1,312 @@
|
||||
"""Output validation for agent nodes.
|
||||
|
||||
Validates node outputs against schemas and expected keys to prevent
|
||||
garbage from propagating through the graph.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ValidationResult:
    """Result of validating an output."""

    # True when the output passed every check that was run.
    success: bool
    # One human-readable message per failed check; empty on success.
    errors: list[str]

    @property
    def error(self) -> str:
        """All error messages combined into a single '; '-separated string."""
        if not self.errors:
            return ""
        return "; ".join(self.errors)
|
||||
|
||||
|
||||
class OutputValidator:
    """
    Validates node outputs against schemas and expected keys.

    Used by the executor to catch bad outputs before they pollute memory.
    All public methods return a ValidationResult and never raise on bad
    output data; only programming errors propagate.
    """

    def _contains_code_indicators(self, value: str) -> bool:
        """
        Check for code patterns in a string using sampling for efficiency.

        For strings under 10KB, checks the entire content.
        For longer strings, samples at strategic positions to balance
        performance with detection accuracy.

        Args:
            value: The string to check for code indicators

        Returns:
            True if code indicators are found, False otherwise
        """
        # Substrings that strongly suggest source code or markup injection.
        code_indicators = [
            # Python
            "def ",
            "class ",
            "import ",
            "from ",
            "if __name__",
            "async def ",
            "await ",
            "try:",
            "except:",
            # JavaScript/TypeScript
            "function ",
            "const ",
            "let ",
            "=> {",
            "require(",
            "export ",
            # SQL
            "SELECT ",
            "INSERT ",
            "UPDATE ",
            "DELETE ",
            "DROP ",
            # HTML/Script injection
            "<script",
            "<?php",
            "<%",
        ]

        # For strings under 10KB, check the entire content
        if len(value) < 10000:
            return any(indicator in value for indicator in code_indicators)

        # For longer strings, sample at strategic positions
        sample_positions = [
            0,  # Start
            len(value) // 4,  # 25%
            len(value) // 2,  # 50%
            3 * len(value) // 4,  # 75%
            max(0, len(value) - 2000),  # Near end
        ]

        # NOTE: an indicator falling entirely between two samples can be
        # missed; accepted trade-off for very large strings.
        for pos in sample_positions:
            chunk = value[pos : pos + 2000]
            if any(indicator in chunk for indicator in code_indicators):
                return True

        return False

    def validate_output_keys(
        self,
        output: dict[str, Any],
        expected_keys: list[str],
        allow_empty: bool = False,
        nullable_keys: list[str] | None = None,
    ) -> ValidationResult:
        """
        Validate that all expected keys are present and non-empty.

        Args:
            output: The output dict to validate
            expected_keys: Keys that must be present
            allow_empty: If True, allow empty string values
            nullable_keys: Keys that are allowed to be None

        Returns:
            ValidationResult with success status and any errors
        """
        errors = []
        nullable_keys = nullable_keys or []

        # A non-dict output fails immediately; per-key checks make no sense.
        if not isinstance(output, dict):
            return ValidationResult(
                success=False, errors=[f"Output is not a dict, got {type(output).__name__}"]
            )

        for key in expected_keys:
            if key not in output:
                errors.append(f"Missing required output key: '{key}'")
            elif not allow_empty:
                value = output[key]
                if value is None:
                    # None is only an error for keys not declared nullable.
                    if key not in nullable_keys:
                        errors.append(f"Output key '{key}' is None")
                elif isinstance(value, str) and len(value.strip()) == 0:
                    errors.append(f"Output key '{key}' is empty string")

        return ValidationResult(success=len(errors) == 0, errors=errors)

    def validate_with_pydantic(
        self,
        output: dict[str, Any],
        model: type[BaseModel],
    ) -> tuple[ValidationResult, BaseModel | None]:
        """
        Validate output against a Pydantic model.

        Args:
            output: The output dict to validate
            model: Pydantic model class to validate against

        Returns:
            Tuple of (ValidationResult, validated_model_instance or None)
        """
        try:
            validated = model.model_validate(output)
            return ValidationResult(success=True, errors=[]), validated
        except ValidationError as e:
            # Flatten pydantic's structured errors into
            # "dotted.path: message (type: code)" strings.
            errors = []
            for error in e.errors():
                field_path = ".".join(str(loc) for loc in error["loc"])
                msg = error["msg"]
                error_type = error["type"]
                errors.append(f"{field_path}: {msg} (type: {error_type})")
            return ValidationResult(success=False, errors=errors), None

    def format_validation_feedback(
        self,
        validation_result: ValidationResult,
        model: type[BaseModel],
    ) -> str:
        """
        Format validation errors as feedback for LLM retry.

        Args:
            validation_result: The failed validation result
            model: The Pydantic model that was used for validation

        Returns:
            Formatted feedback string to include in retry prompt
        """
        # Get the model's JSON schema for reference
        schema = model.model_json_schema()

        feedback = "Your previous response had validation errors:\n\n"
        feedback += "ERRORS:\n"
        for error in validation_result.errors:
            feedback += f"  - {error}\n"

        feedback += "\nEXPECTED SCHEMA:\n"
        feedback += f"  Model: {model.__name__}\n"

        if "properties" in schema:
            feedback += "  Required fields:\n"
            required = schema.get("required", [])
            for prop_name, prop_info in schema["properties"].items():
                req_marker = " (required)" if prop_name in required else ""
                prop_type = prop_info.get("type", "any")
                feedback += f"    - {prop_name}: {prop_type}{req_marker}\n"

        feedback += "\nPlease fix the errors and respond with valid JSON matching the schema."

        return feedback

    def validate_no_hallucination(
        self,
        output: dict[str, Any],
        max_length: int = 10000,
    ) -> ValidationResult:
        """
        Check for signs of LLM hallucination in output values.

        Detects:
        - Code blocks where structured data was expected
        - Overly long values that suggest raw LLM output
        - Common hallucination patterns

        Args:
            output: The output dict to validate
            max_length: Maximum allowed length for string values

        Returns:
            ValidationResult with success status and any errors
        """
        errors = []

        for key, value in output.items():
            # Only string values are inspected; structured values pass through.
            if not isinstance(value, str):
                continue

            # Check for code patterns in the entire string, not just first 500 chars
            if self._contains_code_indicators(value):
                # Could be legitimate, but warn (does not fail validation)
                logger.warning(f"Output key '{key}' may contain code - verify this is expected")

            # Check for overly long values
            if len(value) > max_length:
                errors.append(
                    f"Output key '{key}' exceeds max length ({len(value)} > {max_length})"
                )

        return ValidationResult(success=len(errors) == 0, errors=errors)

    def validate_schema(
        self,
        output: dict[str, Any],
        schema: dict[str, Any],
    ) -> ValidationResult:
        """
        Validate output against a JSON schema.

        Args:
            output: The output dict to validate
            schema: JSON schema to validate against

        Returns:
            ValidationResult with success status and any errors
        """
        # jsonschema is an optional dependency: without it, schema validation
        # degrades to a no-op success rather than failing the pipeline.
        try:
            import jsonschema
        except ImportError:
            logger.warning("jsonschema not installed, skipping schema validation")
            return ValidationResult(success=True, errors=[])

        errors = []
        validator = jsonschema.Draft7Validator(schema)

        # Collect every violation (not just the first) for better feedback.
        for error in validator.iter_errors(output):
            path = ".".join(str(p) for p in error.path) if error.path else "root"
            errors.append(f"{path}: {error.message}")

        return ValidationResult(success=len(errors) == 0, errors=errors)

    def validate_all(
        self,
        output: dict[str, Any],
        expected_keys: list[str] | None = None,
        schema: dict[str, Any] | None = None,
        check_hallucination: bool = True,
        nullable_keys: list[str] | None = None,
    ) -> ValidationResult:
        """
        Run all applicable validations on output.

        Args:
            output: The output dict to validate
            expected_keys: Optional list of required keys
            schema: Optional JSON schema
            check_hallucination: Whether to check for hallucination patterns
            nullable_keys: Keys that are allowed to be None

        Returns:
            Combined ValidationResult
        """
        all_errors = []

        # Validate keys if provided
        if expected_keys:
            result = self.validate_output_keys(output, expected_keys, nullable_keys=nullable_keys)
            all_errors.extend(result.errors)

        # Validate schema if provided
        if schema:
            result = self.validate_schema(output, schema)
            all_errors.extend(result.errors)

        # Check for hallucination
        if check_hallucination:
            result = self.validate_no_hallucination(output)
            all_errors.extend(result.errors)

        return ValidationResult(success=len(all_errors) == 0, errors=all_errors)
|
||||
@@ -11,6 +11,7 @@ appropriate executor based on action type:
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
@@ -26,6 +27,8 @@ from framework.graph.plan import (
|
||||
from framework.llm.provider import LLMProvider, Tool
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_llm_json_response(text: str) -> tuple[Any | None, str]:
|
||||
"""
|
||||
@@ -60,15 +63,21 @@ def parse_llm_json_response(text: str) -> tuple[Any | None, str]:
|
||||
try:
|
||||
parsed = json.loads(match.strip())
|
||||
return parsed, match.strip()
|
||||
except json.JSONDecodeError:
|
||||
except json.JSONDecodeError as e:
|
||||
logger.debug(
|
||||
f"Failed to parse JSON from code block: {e}. "
|
||||
f"Content preview: {match.strip()[:100]}..."
|
||||
)
|
||||
continue
|
||||
|
||||
# No code blocks or parsing failed - try parsing the whole response
|
||||
try:
|
||||
parsed = json.loads(cleaned)
|
||||
return parsed, cleaned
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
except json.JSONDecodeError as e:
|
||||
logger.debug(
|
||||
f"Failed to parse entire response as JSON: {e}. Content preview: {cleaned[:100]}..."
|
||||
)
|
||||
|
||||
# Try to find JSON-like content (starts with { or [)
|
||||
json_start_pattern = r"(\{[\s\S]*\}|\[[\s\S]*\])"
|
||||
@@ -78,10 +87,15 @@ def parse_llm_json_response(text: str) -> tuple[Any | None, str]:
|
||||
try:
|
||||
parsed = json.loads(match)
|
||||
return parsed, match
|
||||
except json.JSONDecodeError:
|
||||
except json.JSONDecodeError as e:
|
||||
logger.debug(f"Failed to parse JSON pattern: {e}. Content preview: {match[:100]}...")
|
||||
continue
|
||||
|
||||
# Could not parse as JSON
|
||||
# Could not parse as JSON - log warning
|
||||
logger.warning(
|
||||
f"Could not parse LLM response as JSON after trying all strategies. "
|
||||
f"Response preview: {cleaned[:200]}..."
|
||||
)
|
||||
return None, cleaned
|
||||
|
||||
|
||||
@@ -292,7 +306,7 @@ class WorkerNode:
|
||||
if inputs:
|
||||
context_section = "\n\n--- Context Data ---\n"
|
||||
for key, value in inputs.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
if isinstance(value, dict | list):
|
||||
context_section += f"{key}: {json.dumps(value, indent=2)}\n"
|
||||
else:
|
||||
context_section += f"{key}: {value}\n"
|
||||
|
||||
@@ -1,7 +1,26 @@
|
||||
"""LLM provider abstraction."""
|
||||
|
||||
from framework.llm.anthropic import AnthropicProvider
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
from framework.llm.provider import LLMProvider, LLMResponse
|
||||
|
||||
__all__ = ["LLMProvider", "LLMResponse", "AnthropicProvider", "LiteLLMProvider"]
|
||||
__all__ = ["LLMProvider", "LLMResponse"]
|
||||
|
||||
try:
|
||||
from framework.llm.anthropic import AnthropicProvider # noqa: F401
|
||||
|
||||
__all__.append("AnthropicProvider")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from framework.llm.litellm import LiteLLMProvider # noqa: F401
|
||||
|
||||
__all__.append("LiteLLMProvider")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from framework.llm.mock import MockLLMProvider # noqa: F401
|
||||
|
||||
__all__.append("MockLLMProvider")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
"""Anthropic Claude LLM provider - backward compatible wrapper around LiteLLM."""
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
|
||||
|
||||
def _get_api_key_from_credential_manager() -> str | None:
|
||||
"""Get API key from CredentialManager or environment.
|
||||
def _get_api_key_from_credential_store() -> str | None:
|
||||
"""Get API key from CredentialStoreAdapter or environment.
|
||||
|
||||
Priority:
|
||||
1. CredentialManager (supports .env hot-reload)
|
||||
1. CredentialStoreAdapter (supports encrypted storage + env vars)
|
||||
2. os.environ fallback
|
||||
"""
|
||||
try:
|
||||
from aden_tools.credentials import CredentialManager
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
|
||||
creds = CredentialManager()
|
||||
creds = CredentialStoreAdapter.with_env_storage()
|
||||
if creds.is_available("anthropic"):
|
||||
return creds.get("anthropic")
|
||||
except ImportError:
|
||||
@@ -43,14 +44,16 @@ class AnthropicProvider(LLMProvider):
|
||||
Initialize the Anthropic provider.
|
||||
|
||||
Args:
|
||||
api_key: Anthropic API key. If not provided, uses CredentialManager
|
||||
api_key: Anthropic API key. If not provided, uses CredentialStoreAdapter
|
||||
or ANTHROPIC_API_KEY env var.
|
||||
model: Model to use (default: claude-haiku-4-5-20251001)
|
||||
"""
|
||||
# Delegate to LiteLLMProvider internally.
|
||||
self.api_key = api_key or _get_api_key_from_credential_manager()
|
||||
self.api_key = api_key or _get_api_key_from_credential_store()
|
||||
if not self.api_key:
|
||||
raise ValueError("Anthropic API key required. Set ANTHROPIC_API_KEY env var or pass api_key.")
|
||||
raise ValueError(
|
||||
"Anthropic API key required. Set ANTHROPIC_API_KEY env var or pass api_key."
|
||||
)
|
||||
|
||||
self.model = model
|
||||
|
||||
@@ -65,6 +68,8 @@ class AnthropicProvider(LLMProvider):
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
) -> LLMResponse:
|
||||
"""Generate a completion from Claude (via LiteLLM)."""
|
||||
return self._provider.complete(
|
||||
@@ -72,6 +77,8 @@ class AnthropicProvider(LLMProvider):
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
response_format=response_format,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
def complete_with_tools(
|
||||
@@ -79,7 +86,7 @@ class AnthropicProvider(LLMProvider):
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
tool_executor: callable,
|
||||
tool_executor: Callable[[ToolUse], ToolResult],
|
||||
max_iterations: int = 10,
|
||||
) -> LLMResponse:
|
||||
"""Run a tool-use loop until Claude produces a final response (via LiteLLM)."""
|
||||
|
||||
@@ -8,11 +8,78 @@ See: https://docs.litellm.ai/docs/providers
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
try:
|
||||
import litellm
|
||||
from litellm.exceptions import RateLimitError
|
||||
except ImportError:
|
||||
litellm = None # type: ignore[assignment]
|
||||
RateLimitError = Exception # type: ignore[assignment, misc]
|
||||
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolUse
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RATE_LIMIT_MAX_RETRIES = 10
|
||||
RATE_LIMIT_BACKOFF_BASE = 2 # seconds
|
||||
|
||||
# Directory for dumping failed requests
|
||||
FAILED_REQUESTS_DIR = Path.home() / ".hive" / "failed_requests"
|
||||
|
||||
|
||||
def _estimate_tokens(model: str, messages: list[dict]) -> tuple[int, str]:
    """Estimate token count for messages. Returns (token_count, method).

    Prefers litellm's model-aware token counter when the package is
    available; otherwise falls back to a rough 4-characters-per-token
    heuristic over the message contents.
    """
    if litellm is not None:
        try:
            # Model-aware count from litellm when available.
            return litellm.token_counter(model=model, messages=messages), "litellm"
        except Exception:
            # Any counter failure falls through to the heuristic below.
            pass

    # Fallback: rough estimate based on character count (~4 chars per token).
    char_total = sum(len(str(msg.get("content", ""))) for msg in messages)
    return char_total // 4, "estimate"
|
||||
|
||||
|
||||
def _dump_failed_request(
    model: str,
    kwargs: dict[str, Any],
    error_type: str,
    attempt: int,
) -> str:
    """Dump failed request to a file for debugging. Returns the file path.

    Writes a JSON snapshot of the request (messages, tools, limits, token
    estimate) into FAILED_REQUESTS_DIR, named by error type, model, and a
    microsecond timestamp so successive dumps never collide.
    """
    # Make sure the dump directory (~/.hive/failed_requests) exists.
    FAILED_REQUESTS_DIR.mkdir(parents=True, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    dump_file = FAILED_REQUESTS_DIR / f"{error_type}_{model.replace('/', '_')}_{stamp}.json"

    request_messages = kwargs.get("messages", [])
    payload = {
        "timestamp": datetime.now().isoformat(),
        "model": model,
        "error_type": error_type,
        "attempt": attempt,
        "estimated_tokens": _estimate_tokens(model, request_messages),
        "num_messages": len(request_messages),
        "messages": request_messages,
        "tools": kwargs.get("tools"),
        "max_tokens": kwargs.get("max_tokens"),
        "temperature": kwargs.get("temperature"),
    }

    # default=str keeps the dump best-effort: any non-serializable object
    # is stringified rather than aborting the debug write.
    with dump_file.open("w") as fh:
        json.dump(payload, fh, indent=2, default=str)

    return str(dump_file)
|
||||
|
||||
|
||||
class LiteLLMProvider(LLMProvider):
|
||||
@@ -23,6 +90,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
- OpenAI: gpt-4o, gpt-4o-mini, gpt-4-turbo, gpt-3.5-turbo
|
||||
- Anthropic: claude-3-opus, claude-3-sonnet, claude-3-haiku
|
||||
- Google: gemini-pro, gemini-1.5-pro, gemini-1.5-flash
|
||||
- DeepSeek: deepseek-chat, deepseek-coder, deepseek-reasoner
|
||||
- Mistral: mistral-large, mistral-medium, mistral-small
|
||||
- Groq: llama3-70b, mixtral-8x7b
|
||||
- Local: ollama/llama3, ollama/mistral
|
||||
@@ -38,6 +106,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
# Google Gemini
|
||||
provider = LiteLLMProvider(model="gemini/gemini-1.5-flash")
|
||||
|
||||
# DeepSeek
|
||||
provider = LiteLLMProvider(model="deepseek/deepseek-chat")
|
||||
|
||||
# Local Ollama
|
||||
provider = LiteLLMProvider(model="ollama/llama3")
|
||||
|
||||
@@ -72,12 +143,101 @@ class LiteLLMProvider(LLMProvider):
|
||||
self.api_base = api_base
|
||||
self.extra_kwargs = kwargs
|
||||
|
||||
if litellm is None:
|
||||
raise ImportError(
|
||||
"LiteLLM is not installed. Please install it with: pip install litellm"
|
||||
)
|
||||
|
||||
    def _completion_with_rate_limit_retry(self, **kwargs: Any) -> Any:
        """Call litellm.completion with retry on 429 rate limit errors and empty responses.

        Retries up to RATE_LIMIT_MAX_RETRIES times with exponential backoff
        (RATE_LIMIT_BACKOFF_BASE * 2**attempt seconds). Two failure modes are
        handled:
          * RateLimitError raised by litellm (a real HTTP 429) — re-raised
            after the final attempt;
          * a 200 response with neither content nor tool calls, which some
            providers (e.g. Gemini) return on quota exhaustion — returned
            as-is after the final attempt.
        Every failed attempt is dumped to disk via _dump_failed_request.
        """
        model = kwargs.get("model", self.model)
        for attempt in range(RATE_LIMIT_MAX_RETRIES + 1):
            try:
                response = litellm.completion(**kwargs)  # type: ignore[union-attr]

                # Some providers (e.g. Gemini) return 200 with empty content on
                # rate limit / quota exhaustion instead of a proper 429. Treat
                # empty responses the same as a rate-limit error and retry.
                content = response.choices[0].message.content if response.choices else None
                has_tool_calls = bool(response.choices and response.choices[0].message.tool_calls)
                if not content and not has_tool_calls:
                    finish_reason = (
                        response.choices[0].finish_reason if response.choices else "unknown"
                    )
                    # Dump full request to file for debugging
                    messages = kwargs.get("messages", [])
                    token_count, token_method = _estimate_tokens(model, messages)
                    dump_path = _dump_failed_request(
                        model=model,
                        kwargs=kwargs,
                        error_type="empty_response",
                        attempt=attempt,
                    )
                    logger.warning(
                        f"[retry] Empty response - {len(messages)} messages, "
                        f"~{token_count} tokens ({token_method}). "
                        f"Full request dumped to: {dump_path}"
                    )

                    # Out of retries: hand the empty response back to the
                    # caller rather than raising.
                    if attempt == RATE_LIMIT_MAX_RETRIES:
                        logger.error(
                            f"[retry] GAVE UP on {model} after {RATE_LIMIT_MAX_RETRIES + 1} "
                            f"attempts — empty response "
                            f"(finish_reason={finish_reason}, "
                            f"choices={len(response.choices) if response.choices else 0})"
                        )
                        return response
                    wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
                    logger.warning(
                        f"[retry] {model} returned empty response "
                        f"(finish_reason={finish_reason}, "
                        f"choices={len(response.choices) if response.choices else 0}) — "
                        f"likely rate limited or quota exceeded. "
                        f"Retrying in {wait}s "
                        f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
                    )
                    time.sleep(wait)
                    continue

                return response
            except RateLimitError as e:
                # Dump full request to file for debugging
                messages = kwargs.get("messages", [])
                token_count, token_method = _estimate_tokens(model, messages)
                dump_path = _dump_failed_request(
                    model=model,
                    kwargs=kwargs,
                    error_type="rate_limit",
                    attempt=attempt,
                )
                # Out of retries: propagate the original RateLimitError.
                if attempt == RATE_LIMIT_MAX_RETRIES:
                    logger.error(
                        f"[retry] GAVE UP on {model} after {RATE_LIMIT_MAX_RETRIES + 1} "
                        f"attempts — rate limit error: {e!s}. "
                        f"~{token_count} tokens ({token_method}). "
                        f"Full request dumped to: {dump_path}"
                    )
                    raise
                wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
                logger.warning(
                    f"[retry] {model} rate limited (429): {e!s}. "
                    f"~{token_count} tokens ({token_method}). "
                    f"Full request dumped to: {dump_path}. "
                    f"Retrying in {wait}s "
                    f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
                )
                time.sleep(wait)
        # unreachable, but satisfies type checker
        raise RuntimeError("Exhausted rate limit retries")
|
||||
|
||||
def complete(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
) -> LLMResponse:
|
||||
"""Generate a completion using LiteLLM."""
|
||||
# Prepare messages with system prompt
|
||||
@@ -86,6 +246,15 @@ class LiteLLMProvider(LLMProvider):
|
||||
full_messages.append({"role": "system", "content": system})
|
||||
full_messages.extend(messages)
|
||||
|
||||
# Add JSON mode via prompt engineering (works across all providers)
|
||||
if json_mode:
|
||||
json_instruction = "\n\nPlease respond with a valid JSON object."
|
||||
# Append to system message if present, otherwise add as system message
|
||||
if full_messages and full_messages[0]["role"] == "system":
|
||||
full_messages[0]["content"] += json_instruction
|
||||
else:
|
||||
full_messages.insert(0, {"role": "system", "content": json_instruction.strip()})
|
||||
|
||||
# Build kwargs
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
@@ -103,8 +272,13 @@ class LiteLLMProvider(LLMProvider):
|
||||
if tools:
|
||||
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
|
||||
|
||||
# Add response_format for structured output
|
||||
# LiteLLM passes this through to the underlying provider
|
||||
if response_format:
|
||||
kwargs["response_format"] = response_format
|
||||
|
||||
# Make the call
|
||||
response = litellm.completion(**kwargs)
|
||||
response = self._completion_with_rate_limit_retry(**kwargs)
|
||||
|
||||
# Extract content
|
||||
content = response.choices[0].message.content or ""
|
||||
@@ -128,8 +302,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
tool_executor: callable,
|
||||
tool_executor: Callable[[ToolUse], ToolResult],
|
||||
max_iterations: int = 10,
|
||||
max_tokens: int = 4096,
|
||||
) -> LLMResponse:
|
||||
"""Run a tool-use loop until the LLM produces a final response."""
|
||||
# Prepare messages with system prompt
|
||||
@@ -149,7 +324,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": current_messages,
|
||||
"max_tokens": 1024,
|
||||
"max_tokens": max_tokens,
|
||||
"tools": openai_tools,
|
||||
**self.extra_kwargs,
|
||||
}
|
||||
@@ -159,7 +334,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
|
||||
response = litellm.completion(**kwargs)
|
||||
response = self._completion_with_rate_limit_retry(**kwargs)
|
||||
|
||||
# Track tokens
|
||||
usage = response.usage
|
||||
@@ -203,11 +378,18 @@ class LiteLLMProvider(LLMProvider):
|
||||
|
||||
# Execute tools and add results.
|
||||
for tool_call in message.tool_calls:
|
||||
# Parse arguments
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
# Surface error to LLM and skip tool execution
|
||||
current_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": "Invalid JSON arguments provided to tool.",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
tool_use = ToolUse(
|
||||
id=tool_call.id,
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
"""Mock LLM Provider for testing and structural validation without real LLM calls."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
|
||||
|
||||
class MockLLMProvider(LLMProvider):
|
||||
"""
|
||||
Mock LLM provider for testing agents without making real API calls.
|
||||
|
||||
This provider generates placeholder responses based on the expected output structure,
|
||||
allowing structural validation and graph execution testing without incurring costs
|
||||
or requiring API keys.
|
||||
|
||||
Example:
|
||||
llm = MockLLMProvider()
|
||||
response = llm.complete(
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
system="Generate JSON with keys: name, age",
|
||||
json_mode=True
|
||||
)
|
||||
# Returns: {"name": "mock_value", "age": "mock_value"}
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = "mock-model"):
|
||||
"""
|
||||
Initialize the mock LLM provider.
|
||||
|
||||
Args:
|
||||
model: Model name to report in responses (default: "mock-model")
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
def _extract_output_keys(self, system: str) -> list[str]:
|
||||
"""
|
||||
Extract expected output keys from the system prompt.
|
||||
|
||||
Looks for patterns like:
|
||||
- "output_keys: [key1, key2]"
|
||||
- "keys: key1, key2"
|
||||
- "Generate JSON with keys: key1, key2"
|
||||
|
||||
Args:
|
||||
system: System prompt text
|
||||
|
||||
Returns:
|
||||
List of extracted key names
|
||||
"""
|
||||
keys = []
|
||||
|
||||
# Pattern 1: output_keys: [key1, key2]
|
||||
match = re.search(r"output_keys:\s*\[(.*?)\]", system, re.IGNORECASE)
|
||||
if match:
|
||||
keys_str = match.group(1)
|
||||
keys = [k.strip().strip("\"'") for k in keys_str.split(",")]
|
||||
return keys
|
||||
|
||||
# Pattern 2: "keys: key1, key2" or "Generate JSON with keys: key1, key2"
|
||||
match = re.search(r"(?:keys|with keys):\s*([a-zA-Z0-9_,\s]+)", system, re.IGNORECASE)
|
||||
if match:
|
||||
keys_str = match.group(1)
|
||||
keys = [k.strip() for k in keys_str.split(",") if k.strip()]
|
||||
return keys
|
||||
|
||||
# Pattern 3: Look for JSON schema in system prompt
|
||||
match = re.search(r'\{[^}]*"([a-zA-Z0-9_]+)":\s*', system)
|
||||
if match:
|
||||
# Found at least one key in a JSON-like structure
|
||||
all_matches = re.findall(r'"([a-zA-Z0-9_]+)":\s*', system)
|
||||
if all_matches:
|
||||
return list(set(all_matches))
|
||||
|
||||
return keys
|
||||
|
||||
def _generate_mock_response(
|
||||
self,
|
||||
system: str = "",
|
||||
json_mode: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a mock response based on the system prompt and mode.
|
||||
|
||||
Args:
|
||||
system: System prompt (may contain output key hints)
|
||||
json_mode: If True, generate JSON response
|
||||
|
||||
Returns:
|
||||
Mock response string
|
||||
"""
|
||||
if json_mode:
|
||||
# Try to extract expected keys from system prompt
|
||||
keys = self._extract_output_keys(system)
|
||||
|
||||
if keys:
|
||||
# Generate JSON with the expected keys
|
||||
mock_data = {key: f"mock_{key}_value" for key in keys}
|
||||
return json.dumps(mock_data, indent=2)
|
||||
else:
|
||||
# Fallback: generic mock response
|
||||
return json.dumps({"result": "mock_result_value"}, indent=2)
|
||||
else:
|
||||
# Plain text mock response
|
||||
return "This is a mock response for testing purposes."
|
||||
|
||||
def complete(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Generate a mock completion without calling a real LLM.
|
||||
|
||||
Args:
|
||||
messages: Conversation history (ignored in mock mode)
|
||||
system: System prompt (used to extract expected output keys)
|
||||
tools: Available tools (ignored in mock mode)
|
||||
max_tokens: Maximum tokens (ignored in mock mode)
|
||||
response_format: Response format (ignored in mock mode)
|
||||
json_mode: If True, generate JSON response
|
||||
|
||||
Returns:
|
||||
LLMResponse with mock content
|
||||
"""
|
||||
content = self._generate_mock_response(system=system, json_mode=json_mode)
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=self.model,
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
stop_reason="mock_complete",
|
||||
)
|
||||
|
||||
def complete_with_tools(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
tool_executor: Callable[[ToolUse], ToolResult],
|
||||
max_iterations: int = 10,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Generate a mock completion without tool use.
|
||||
|
||||
In mock mode, we skip tool execution and return a final response immediately.
|
||||
|
||||
Args:
|
||||
messages: Initial conversation (ignored in mock mode)
|
||||
system: System prompt (used to extract expected output keys)
|
||||
tools: Available tools (ignored in mock mode)
|
||||
tool_executor: Tool executor function (ignored in mock mode)
|
||||
max_iterations: Max iterations (ignored in mock mode)
|
||||
|
||||
Returns:
|
||||
LLMResponse with mock content
|
||||
"""
|
||||
# In mock mode, we don't execute tools - just return a final response
|
||||
# Try to generate JSON if the system prompt suggests structured output
|
||||
json_mode = "json" in system.lower() or "output_keys" in system.lower()
|
||||
|
||||
content = self._generate_mock_response(system=system, json_mode=json_mode)
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=self.model,
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
stop_reason="mock_complete",
|
||||
)
|
||||
@@ -1,6 +1,7 @@
|
||||
"""LLM Provider abstraction for pluggable LLM backends."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
@@ -62,6 +63,8 @@ class LLMProvider(ABC):
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Generate a completion from the LLM.
|
||||
@@ -71,6 +74,11 @@ class LLMProvider(ABC):
|
||||
system: System prompt
|
||||
tools: Available tools for the LLM to use
|
||||
max_tokens: Maximum tokens to generate
|
||||
response_format: Optional structured output format. Use:
|
||||
- {"type": "json_object"} for basic JSON mode
|
||||
- {"type": "json_schema", "json_schema": {"name": "...", "schema": {...}}}
|
||||
for strict JSON schema enforcement
|
||||
json_mode: If True, request structured JSON output from the LLM
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and metadata
|
||||
@@ -83,7 +91,7 @@ class LLMProvider(ABC):
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
tool_executor: callable,
|
||||
tool_executor: Callable[["ToolUse"], "ToolResult"],
|
||||
max_iterations: int = 10,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""MCP servers for worker-bee."""
|
||||
|
||||
from framework.mcp.agent_builder_server import mcp as agent_builder_server
|
||||
|
||||
__all__ = ["agent_builder_server"]
|
||||
# Don't auto-import servers to avoid double-import issues when running with -m
|
||||
__all__ = []
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -417,8 +417,9 @@ def cmd_list(args: argparse.Namespace) -> int:
|
||||
|
||||
directory = Path(args.directory)
|
||||
if not directory.exists():
|
||||
print(f"Directory not found: {directory}", file=sys.stderr)
|
||||
return 1
|
||||
# FIX: Handle missing directory gracefully on fresh install
|
||||
print(f"No agents found in {directory}")
|
||||
return 0
|
||||
|
||||
agents = []
|
||||
for path in directory.iterdir():
|
||||
@@ -458,7 +459,7 @@ def cmd_list(args: argparse.Namespace) -> int:
|
||||
print(f" {agent['name']}")
|
||||
print(f" Path: {agent['path']}")
|
||||
print(f" Description: {agent['description']}")
|
||||
print(f" Steps: {agent['steps']}, Tools: {agent['tools']}")
|
||||
print(f" Nodes: {agent['nodes']}, Tools: {agent['tools']}")
|
||||
print()
|
||||
|
||||
return 0
|
||||
@@ -652,8 +653,9 @@ def _format_natural_language_to_json(
|
||||
|
||||
session_info = (
|
||||
f'\n\nExisting {main_field}: "{existing_value}"\n\n'
|
||||
"The user is providing ADDITIONAL information. Append this new information "
|
||||
f"to the existing {main_field} to create an enriched, more detailed version."
|
||||
f"The user is providing ADDITIONAL information. Append this new "
|
||||
f"information to the existing {main_field} to create an enriched, "
|
||||
"more detailed version."
|
||||
)
|
||||
|
||||
prompt = f"""You are formatting user input for an agent that requires specific input fields.
|
||||
@@ -664,12 +666,7 @@ Required input fields: {", ".join(input_keys)}{session_info}
|
||||
|
||||
User input: {user_input}
|
||||
|
||||
{
|
||||
"If this is a follow-up message, APPEND the new information to the existing field "
|
||||
"value to create a more complete, detailed version. Do not create new fields."
|
||||
if session_context
|
||||
else ""
|
||||
}
|
||||
{"If this is a follow-up, APPEND new info to the existing field value." if session_context else ""}
|
||||
|
||||
Output ONLY valid JSON, no explanation:"""
|
||||
|
||||
@@ -790,7 +787,9 @@ def cmd_shell(args: argparse.Namespace) -> int:
|
||||
print("\nAgent nodes:")
|
||||
for node in info.nodes:
|
||||
inputs = f" [in: {', '.join(node['input_keys'])}]" if node.get("input_keys") else ""
|
||||
outputs = f" [out: {', '.join(node['output_keys'])}]" if node.get("output_keys") else ""
|
||||
outputs = (
|
||||
f" [out: {', '.join(node['output_keys'])}]" if node.get("output_keys") else ""
|
||||
)
|
||||
print(f" {node['id']}: {node['name']}{inputs}{outputs}")
|
||||
print(f" {node['description']}")
|
||||
print()
|
||||
@@ -933,7 +932,10 @@ def _select_agent(agents_dir: Path) -> str | None:
|
||||
"""Let user select an agent from available agents."""
|
||||
if not agents_dir.exists():
|
||||
print(f"Directory not found: {agents_dir}", file=sys.stderr)
|
||||
return None
|
||||
# fixes issue #696, creates an exports folder if it does not exist
|
||||
agents_dir.mkdir(parents=True, exist_ok=True)
|
||||
print(f"Created directory: {agents_dir}", file=sys.stderr)
|
||||
# return None
|
||||
|
||||
agents = []
|
||||
for path in agents_dir.iterdir():
|
||||
|
||||
@@ -6,6 +6,7 @@ Supports both STDIO and HTTP transports using the official MCP Python SDK.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
@@ -85,10 +86,14 @@ class MCPClient:
|
||||
"""
|
||||
# If we have a persistent loop (for STDIO), use it
|
||||
if self._loop is not None:
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
||||
return future.result()
|
||||
# Check if loop is running AND not closed
|
||||
if self._loop.is_running() and not self._loop.is_closed():
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
||||
return future.result()
|
||||
# else: fall through to the standard approach below
|
||||
# This handles the case when STDIO loop exists but is stopped/closed
|
||||
|
||||
# Otherwise, use the standard approach
|
||||
# Standard approach: handle both sync and async contexts
|
||||
try:
|
||||
# Try to get the current event loop
|
||||
asyncio.get_running_loop()
|
||||
@@ -149,10 +154,12 @@ class MCPClient:
|
||||
from mcp import StdioServerParameters
|
||||
|
||||
# Create server parameters
|
||||
# Always inherit parent environment and merge with any custom env vars
|
||||
merged_env = {**os.environ, **(self.config.env or {})}
|
||||
server_params = StdioServerParameters(
|
||||
command=self.config.command,
|
||||
args=self.config.args,
|
||||
env=self.config.env or None,
|
||||
env=merged_env,
|
||||
cwd=self.config.cwd,
|
||||
)
|
||||
|
||||
@@ -233,7 +240,9 @@ class MCPClient:
|
||||
try:
|
||||
response = self._http_client.get("/health")
|
||||
response.raise_for_status()
|
||||
logger.info(f"Connected to MCP server '{self.config.name}' via HTTP at {self.config.url}")
|
||||
logger.info(
|
||||
f"Connected to MCP server '{self.config.name}' via HTTP at {self.config.url}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}")
|
||||
# Continue anyway, server might not have health endpoint
|
||||
@@ -256,7 +265,10 @@ class MCPClient:
|
||||
)
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
logger.info(f"Discovered {len(self._tools)} tools from '{self.config.name}': {list(self._tools.keys())}")
|
||||
tool_names = list(self._tools.keys())
|
||||
logger.info(
|
||||
f"Discovered {len(self._tools)} tools from '{self.config.name}': {tool_names}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to discover tools from '{self.config.name}': {e}")
|
||||
raise
|
||||
@@ -392,19 +404,110 @@ class MCPClient:
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to call tool via HTTP: {e}") from e
|
||||
|
||||
_CLEANUP_TIMEOUT = 10
|
||||
_THREAD_JOIN_TIMEOUT = 12
|
||||
|
||||
async def _cleanup_stdio_async(self) -> None:
|
||||
"""Async cleanup for STDIO session and context managers.
|
||||
|
||||
Cleanup order is critical:
|
||||
- The session must be closed BEFORE the stdio_context because the session
|
||||
depends on the streams provided by stdio_context.
|
||||
- This mirrors the initialization order in _connect_stdio(), where
|
||||
stdio_context is entered first (providing streams), then the session is
|
||||
created with those streams and entered.
|
||||
- Do not change this ordering without carefully considering these dependencies.
|
||||
"""
|
||||
# First: close session (depends on stdio_context streams)
|
||||
try:
|
||||
if self._session:
|
||||
await self._session.__aexit__(None, None, None)
|
||||
except asyncio.CancelledError:
|
||||
logger.warning(
|
||||
"MCP session cleanup was cancelled; proceeding with best-effort shutdown"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing MCP session: {e}")
|
||||
finally:
|
||||
self._session = None
|
||||
|
||||
# Second: close stdio_context (provides the underlying streams)
|
||||
try:
|
||||
if self._stdio_context:
|
||||
await self._stdio_context.__aexit__(None, None, None)
|
||||
except asyncio.CancelledError:
|
||||
logger.warning(
|
||||
"STDIO context cleanup was cancelled; proceeding with best-effort shutdown"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing STDIO context: {e}")
|
||||
finally:
|
||||
self._stdio_context = None
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from the MCP server."""
|
||||
# Clean up persistent STDIO connection
|
||||
if self._loop is not None:
|
||||
# Stop event loop - this will cause context managers to clean up naturally
|
||||
if self._loop and self._loop.is_running():
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
cleanup_attempted = False
|
||||
|
||||
# Wait for thread to finish
|
||||
# Properly close session and context managers before stopping loop
|
||||
# Note: There's an inherent race condition between checking is_running()
|
||||
# and calling run_coroutine_threadsafe(). We handle this by catching
|
||||
# any exceptions that may occur if the loop stops between these calls.
|
||||
if self._loop.is_running():
|
||||
try:
|
||||
cleanup_future = asyncio.run_coroutine_threadsafe(
|
||||
self._cleanup_stdio_async(), self._loop
|
||||
)
|
||||
cleanup_future.result(timeout=self._CLEANUP_TIMEOUT)
|
||||
cleanup_attempted = True
|
||||
except TimeoutError:
|
||||
# Cleanup took too long - may indicate stuck resources or slow MCP server
|
||||
cleanup_attempted = True
|
||||
logger.warning(f"Async cleanup timed out after {self._CLEANUP_TIMEOUT} seconds")
|
||||
except RuntimeError as e:
|
||||
# Likely: loop stopped between is_running() check and run_coroutine_threadsafe()
|
||||
cleanup_attempted = True
|
||||
logger.debug(f"Event loop stopped during async cleanup: {e}")
|
||||
except Exception as e:
|
||||
# Cleanup was attempted but failed (e.g., error in _cleanup_stdio_async())
|
||||
cleanup_attempted = True
|
||||
logger.warning(f"Error during async cleanup: {e}")
|
||||
|
||||
# Now stop the event loop
|
||||
try:
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
except RuntimeError:
|
||||
# Loop may have already stopped
|
||||
pass
|
||||
|
||||
if not cleanup_attempted:
|
||||
# Fallback: loop exists but is not running (e.g., crashed or stopped externally).
|
||||
# At this point the loop and associated resources are in an undefined state.
|
||||
# The context managers (_session, _stdio_context) were created in the loop's
|
||||
# thread and may not be safely cleanable from here. Just log and proceed
|
||||
# with reference clearing - the OS will reclaim resources on process exit.
|
||||
logger.warning(
|
||||
"Event loop for STDIO MCP connection exists but is not running; "
|
||||
"skipping async cleanup. Resources may not be fully released."
|
||||
)
|
||||
|
||||
# Wait for thread to finish (timeout proportional to cleanup timeout)
|
||||
if self._loop_thread and self._loop_thread.is_alive():
|
||||
self._loop_thread.join(timeout=2)
|
||||
self._loop_thread.join(timeout=self._THREAD_JOIN_TIMEOUT)
|
||||
if self._loop_thread.is_alive():
|
||||
logger.warning(
|
||||
"Event loop thread for STDIO MCP connection did not terminate "
|
||||
f"within {self._THREAD_JOIN_TIMEOUT}s; thread may still be running."
|
||||
)
|
||||
|
||||
# Clear references
|
||||
# Clear remaining references
|
||||
# Note: _session and _stdio_context may already be None if _cleanup_stdio_async()
|
||||
# succeeded. This redundant assignment is intentional for safety in cases where:
|
||||
# 1. Cleanup timed out or failed
|
||||
# 2. Cleanup was skipped (loop not running)
|
||||
# 3. CancelledError interrupted cleanup
|
||||
# Setting None to None is safe and ensures clean state.
|
||||
self._session = None
|
||||
self._stdio_context = None
|
||||
self._read_stream = None
|
||||
|
||||
@@ -400,7 +400,11 @@ class AgentOrchestrator:
|
||||
return await self._llm_route(request, intent, capable)
|
||||
|
||||
# If no capable agents, check uncertain ones
|
||||
uncertain = [(name, cap) for name, cap in capabilities.items() if cap.level == CapabilityLevel.UNCERTAIN]
|
||||
uncertain = [
|
||||
(name, cap)
|
||||
for name, cap in capabilities.items()
|
||||
if cap.level == CapabilityLevel.UNCERTAIN
|
||||
]
|
||||
if uncertain:
|
||||
uncertain.sort(key=lambda x: -x[1].confidence)
|
||||
return RoutingDecision(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user