Merge remote-tracking branch 'origin/main' into conductorchicago

Resolved conflict in tools/pyproject.toml by keeping the expanded format
with the sql dependency from main.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Shivam Sharma
2026-01-29 04:19:16 +05:30
206 changed files with 23234 additions and 3400 deletions
+15
View File
@@ -0,0 +1,15 @@
{
"hooks": {
"PostToolUse": [
{
"matcher": "Edit|Write|NotebookEdit",
"hooks": [
{
"type": "command",
"command": "ruff check --fix \"$CLAUDE_FILE_PATH\" 2>/dev/null; ruff format \"$CLAUDE_FILE_PATH\" 2>/dev/null; true"
}
]
}
]
}
}
@@ -108,8 +108,10 @@ async def _interactive_shell(verbose=False):
try:
while True:
try:
topic = await asyncio.get_event_loop().run_in_executor(None, input, "Topic> ")
if topic.lower() in ['quit', 'exit', 'q']:
topic = await asyncio.get_event_loop().run_in_executor(
None, input, "Topic> "
)
if topic.lower() in ["quit", "exit", "q"]:
click.echo("Goodbye!")
break
@@ -130,7 +132,11 @@ async def _interactive_shell(verbose=False):
click.echo(f"\nReport saved to: {output['file_path']}\n")
if "final_report" in output:
click.echo("\n--- Report Preview ---\n")
preview = output["final_report"][:500] + "..." if len(output.get("final_report", "")) > 500 else output.get("final_report", "")
preview = (
output["final_report"][:500] + "..."
if len(output.get("final_report", "")) > 500
else output.get("final_report", "")
)
click.echo(preview)
click.echo("\n")
else:
@@ -142,6 +148,7 @@ async def _interactive_shell(verbose=False):
except Exception as e:
click.echo(f"Error: {e}", err=True)
import traceback
traceback.print_exc()
finally:
await agent.stop()
@@ -1,4 +1,5 @@
"""Agent graph construction for Online Research Agent."""
from framework.graph import EdgeSpec, EdgeCondition, Goal, SuccessCriterion, Constraint
from framework.graph.edge import GraphSpec
from framework.graph.executor import ExecutionResult
@@ -8,6 +9,16 @@ from framework.llm import LiteLLMProvider
from framework.runner.tool_registry import ToolRegistry
from .config import default_config, metadata
from .nodes import (
parse_query_node,
search_sources_node,
fetch_content_node,
evaluate_sources_node,
synthesize_findings_node,
write_report_node,
quality_check_node,
save_report_node,
)
# Goal definition
goal = Goal(
@@ -78,17 +89,6 @@ goal = Goal(
),
],
)
# Import nodes
from .nodes import (
parse_query_node,
search_sources_node,
fetch_content_node,
evaluate_sources_node,
synthesize_findings_node,
write_report_node,
quality_check_node,
save_report_node,
)
# Node list
nodes = [
@@ -195,13 +195,15 @@ class OnlineResearchAgent:
trigger_type = "manual"
name = ep_id.replace("-", " ").title()
specs.append(EntryPointSpec(
id=ep_id,
name=name,
entry_node=node_id,
trigger_type=trigger_type,
isolation_level="shared",
))
specs.append(
EntryPointSpec(
id=ep_id,
name=name,
entry_node=node_id,
trigger_type=trigger_type,
isolation_level="shared",
)
)
return specs
def _create_runtime(self, mock_mode=False) -> AgentRuntime:
@@ -226,7 +228,10 @@ class OnlineResearchAgent:
for server_name, server_config in mcp_servers.items():
server_config["name"] = server_name
# Resolve relative cwd paths
if "cwd" in server_config and not Path(server_config["cwd"]).is_absolute():
if (
"cwd" in server_config
and not Path(server_config["cwd"]).is_absolute()
):
server_config["cwd"] = str(agent_dir / server_config["cwd"])
tool_registry.register_mcp_server(server_config)
@@ -298,7 +303,9 @@ class OnlineResearchAgent:
"""
if self._runtime is None or not self._runtime.is_running:
raise RuntimeError("Agent runtime not started. Call start() first.")
return await self._runtime.trigger(entry_point, input_data, correlation_id, session_state=session_state)
return await self._runtime.trigger(
entry_point, input_data, correlation_id, session_state=session_state
)
async def trigger_and_wait(
self,
@@ -321,9 +328,13 @@ class OnlineResearchAgent:
"""
if self._runtime is None or not self._runtime.is_running:
raise RuntimeError("Agent runtime not started. Call start() first.")
return await self._runtime.trigger_and_wait(entry_point, input_data, timeout, session_state=session_state)
return await self._runtime.trigger_and_wait(
entry_point, input_data, timeout, session_state=session_state
)
async def run(self, context: dict, mock_mode=False, session_state=None) -> ExecutionResult:
async def run(
self, context: dict, mock_mode=False, session_state=None
) -> ExecutionResult:
"""
Run the agent (convenience method for simple single execution).
@@ -342,7 +353,9 @@ class OnlineResearchAgent:
else:
entry_point = "start"
result = await self.trigger_and_wait(entry_point, context, session_state=session_state)
result = await self.trigger_and_wait(
entry_point, context, session_state=session_state
)
return result or ExecutionResult(success=False, error="Execution timeout")
finally:
await self.stop()
@@ -404,7 +417,9 @@ class OnlineResearchAgent:
# Validate entry points
for ep_id, node_id in self.entry_points.items():
if node_id not in node_ids:
errors.append(f"Entry point '{ep_id}' references unknown node '{node_id}'")
errors.append(
f"Entry point '{ep_id}' references unknown node '{node_id}'"
)
return {
"valid": len(errors) == 0,
@@ -1,4 +1,5 @@
"""Runtime configuration."""
from dataclasses import dataclass
@@ -13,6 +14,7 @@ class RuntimeConfig:
default_config = RuntimeConfig()
# Agent metadata
@dataclass
class AgentMetadata:
@@ -1,4 +1,5 @@
"""Node definitions for Online Research Agent."""
from framework.graph import NodeSpec
# Node 1: Parse Query
@@ -10,9 +11,21 @@ parse_query_node = NodeSpec(
input_keys=["topic"],
output_keys=["search_queries", "research_focus", "key_aspects"],
output_schema={
"research_focus": {"type": "string", "required": True, "description": "Brief statement of what we're researching"},
"key_aspects": {"type": "array", "required": True, "description": "List of 3-5 key aspects to investigate"},
"search_queries": {"type": "array", "required": True, "description": "List of 3-5 search queries"},
"research_focus": {
"type": "string",
"required": True,
"description": "Brief statement of what we're researching",
},
"key_aspects": {
"type": "array",
"required": True,
"description": "List of 3-5 key aspects to investigate",
},
"search_queries": {
"type": "array",
"required": True,
"description": "List of 3-5 search queries",
},
},
system_prompt="""\
You are a research query strategist. Given a research topic, analyze it and generate search queries.
@@ -50,8 +63,16 @@ search_sources_node = NodeSpec(
input_keys=["search_queries", "research_focus"],
output_keys=["source_urls", "search_results_summary"],
output_schema={
"source_urls": {"type": "array", "required": True, "description": "List of source URLs found"},
"search_results_summary": {"type": "string", "required": True, "description": "Brief summary of what was found"},
"source_urls": {
"type": "array",
"required": True,
"description": "List of source URLs found",
},
"search_results_summary": {
"type": "string",
"required": True,
"description": "Brief summary of what was found",
},
},
system_prompt="""\
You are a research assistant executing web searches. Use the web_search tool to find sources.
@@ -80,8 +101,16 @@ fetch_content_node = NodeSpec(
input_keys=["source_urls", "research_focus"],
output_keys=["fetched_sources", "fetch_errors"],
output_schema={
"fetched_sources": {"type": "array", "required": True, "description": "List of fetched source objects with url, title, content"},
"fetch_errors": {"type": "array", "required": True, "description": "List of URLs that failed to fetch"},
"fetched_sources": {
"type": "array",
"required": True,
"description": "List of fetched source objects with url, title, content",
},
"fetch_errors": {
"type": "array",
"required": True,
"description": "List of URLs that failed to fetch",
},
},
system_prompt="""\
You are a content fetcher. Use web_scrape tool to retrieve content from URLs.
@@ -113,8 +142,16 @@ evaluate_sources_node = NodeSpec(
input_keys=["fetched_sources", "research_focus", "key_aspects"],
output_keys=["ranked_sources", "source_analysis"],
output_schema={
"ranked_sources": {"type": "array", "required": True, "description": "List of ranked sources with scores"},
"source_analysis": {"type": "string", "required": True, "description": "Overview of source quality and coverage"},
"ranked_sources": {
"type": "array",
"required": True,
"description": "List of ranked sources with scores",
},
"source_analysis": {
"type": "string",
"required": True,
"description": "Overview of source quality and coverage",
},
},
system_prompt="""\
You are a source evaluator. Assess each source for quality and relevance.
@@ -153,9 +190,21 @@ synthesize_findings_node = NodeSpec(
input_keys=["ranked_sources", "research_focus", "key_aspects"],
output_keys=["key_findings", "themes", "source_citations"],
output_schema={
"key_findings": {"type": "array", "required": True, "description": "List of key findings with sources and confidence"},
"themes": {"type": "array", "required": True, "description": "List of themes with descriptions and supporting sources"},
"source_citations": {"type": "object", "required": True, "description": "Map of facts to supporting URLs"},
"key_findings": {
"type": "array",
"required": True,
"description": "List of key findings with sources and confidence",
},
"themes": {
"type": "array",
"required": True,
"description": "List of themes with descriptions and supporting sources",
},
"source_citations": {
"type": "object",
"required": True,
"description": "Map of facts to supporting URLs",
},
},
system_prompt="""\
You are a research synthesizer. Analyze multiple sources to extract insights.
@@ -192,11 +241,25 @@ write_report_node = NodeSpec(
name="Write Report",
description="Generate a narrative report with proper citations",
node_type="llm_generate",
input_keys=["key_findings", "themes", "source_citations", "research_focus", "ranked_sources"],
input_keys=[
"key_findings",
"themes",
"source_citations",
"research_focus",
"ranked_sources",
],
output_keys=["report_content", "references"],
output_schema={
"report_content": {"type": "string", "required": True, "description": "Full markdown report text with citations"},
"references": {"type": "array", "required": True, "description": "List of reference objects with number, url, title"},
"report_content": {
"type": "string",
"required": True,
"description": "Full markdown report text with citations",
},
"references": {
"type": "array",
"required": True,
"description": "List of reference objects with number, url, title",
},
},
system_prompt="""\
You are a research report writer. Create a well-structured narrative report.
@@ -239,9 +302,21 @@ quality_check_node = NodeSpec(
input_keys=["report_content", "references", "source_citations"],
output_keys=["quality_score", "issues", "final_report"],
output_schema={
"quality_score": {"type": "number", "required": True, "description": "Quality score 0-1"},
"issues": {"type": "array", "required": True, "description": "List of issues found and fixed"},
"final_report": {"type": "string", "required": True, "description": "Corrected full report"},
"quality_score": {
"type": "number",
"required": True,
"description": "Quality score 0-1",
},
"issues": {
"type": "array",
"required": True,
"description": "List of issues found and fixed",
},
"final_report": {
"type": "string",
"required": True,
"description": "Corrected full report",
},
},
system_prompt="""\
You are a quality assurance reviewer. Check the research report for issues.
@@ -278,8 +353,16 @@ save_report_node = NodeSpec(
input_keys=["final_report", "references", "research_focus"],
output_keys=["file_path", "save_status"],
output_schema={
"file_path": {"type": "string", "required": True, "description": "Path where report was saved"},
"save_status": {"type": "string", "required": True, "description": "Status of save operation"},
"file_path": {
"type": "string",
"required": True,
"description": "Path where report was saved",
},
"save_status": {
"type": "string",
"required": True,
"description": "Status of save operation",
},
},
system_prompt="""\
You are a file manager. Save the research report to disk.
+145
View File
@@ -0,0 +1,145 @@
# Triage Issue Skill
Analyze a GitHub issue, verify claims against the codebase, and close invalid issues with a technical response.
## Trigger
User provides a GitHub issue URL or number, e.g.:
- `/triage-issue 1970`
- `/triage-issue https://github.com/adenhq/hive/issues/1970`
## Workflow
### Step 1: Fetch Issue Details
```bash
gh issue view <number> --repo adenhq/hive --json title,body,state,labels,author
```
Extract:
- Title
- Body (the claim/bug report)
- Current state
- Labels
- Author
If the issue is already closed, inform the user and stop.
### Step 2: Analyze the Claim
Read the issue body and identify:
1. **The core claim** - What is the user asserting?
2. **Technical specifics** - File paths, function names, code snippets mentioned
3. **Expected behavior** - What do they think should happen?
4. **Severity claimed** - Security issue? Bug? Feature request?
### Step 3: Investigate the Codebase
For each technical claim:
1. Find the referenced code using Grep/Glob/Read
2. Understand the actual implementation
3. Check if the claim accurately describes the behavior
4. Look for related tests, documentation, or design decisions
### Step 4: Evaluate Validity
Categorize the issue as one of:
| Category | Action |
|----------|--------|
| **Valid Bug** | Do NOT close. Inform user this is a real issue. |
| **Valid Feature Request** | Do NOT close. Suggest labeling appropriately. |
| **Misunderstanding** | Prepare technical explanation for why behavior is correct. |
| **Fundamentally Flawed** | Prepare critique explaining the technical impossibility or design rationale. |
| **Duplicate** | Find the original issue and prepare duplicate notice. |
| **Incomplete** | Prepare a request for more information. |
### Step 5: Draft Response
For issues to be closed, draft a response that:
1. **Acknowledges the concern** - Don't be dismissive
2. **Explains the actual behavior** - With code references
3. **Provides technical rationale** - Why it works this way
4. **References industry standards** - If applicable
5. **Offers alternatives** - If there's a better approach for the user
Use this template:
```markdown
## Analysis
[Brief summary of what was investigated]
## Technical Details
[Explanation with code references]
## Why This Is Working As Designed
[Rationale]
## Recommendation
[What the user should do instead, if applicable]
---
*This issue was reviewed and closed by the maintainers.*
```
### Step 6: User Review
Present the draft to the user with:
```
## Issue #<number>: <title>
**Claim:** <summary of claim>
**Finding:** <valid/invalid/misunderstanding/etc>
**Draft Response:**
<the markdown response>
---
Do you want me to post this comment and close the issue?
```
Use AskUserQuestion with options:
- "Post and close" - Post comment, close issue
- "Edit response" - Let user modify the response
- "Skip" - Don't take action
### Step 7: Execute Action
If user approves:
```bash
# Post comment
gh issue comment <number> --repo adenhq/hive --body "<response>"
# Close issue
gh issue close <number> --repo adenhq/hive --reason "not planned"
```
Report success with link to the issue.
## Important Guidelines
1. **Never close valid issues** - If there's any merit to the claim, don't close it
2. **Be respectful** - The reporter took time to file the issue
3. **Be technical** - Provide code references and evidence
4. **Be educational** - Help them understand, don't just dismiss
5. **Check twice** - Make sure you understand the code before declaring something invalid
6. **Consider edge cases** - Maybe their environment reveals a real issue
## Example Critiques
### Security Misunderstanding
> "The claim that secrets are exposed in plaintext misunderstands the encryption architecture. While `SecretStr` is used for logging protection, actual encryption is provided by Fernet (AES-128-CBC) at the storage layer. The code path is: serialize → encrypt → write. Only encrypted bytes touch disk."
### Impossible Request
> "The requested feature would require [X] which violates [fundamental constraint]. This is not a limitation of our implementation but a fundamental property of [technology/protocol]."
### Already Handled
> "This scenario is already handled by [code reference]. The reporter may be using an older version or a misconfigured environment."
+20
View File
@@ -0,0 +1,20 @@
{
"mcpServers": {
"agent-builder": {
"command": "python",
"args": ["-m", "framework.mcp.agent_builder_server"],
"cwd": "core",
"env": {
"PYTHONPATH": "../tools/src"
}
},
"tools": {
"command": "python",
"args": ["mcp_server.py", "--stdio"],
"cwd": "tools",
"env": {
"PYTHONPATH": "src"
}
}
}
}
+1
View File
@@ -0,0 +1 @@
../../.claude/skills/agent-workflow
+1
View File
@@ -0,0 +1 @@
../../.claude/skills/building-agents-construction
+1
View File
@@ -0,0 +1 @@
../../.claude/skills/building-agents-core
+1
View File
@@ -0,0 +1 @@
../../.claude/skills/building-agents-patterns
+1
View File
@@ -0,0 +1 @@
../../.claude/skills/testing-agent
+18
View File
@@ -0,0 +1,18 @@
This project uses ruff for Python linting and formatting.
Rules:
- Line length: 100 characters
- Python target: 3.11+
- Use double quotes for strings
- Sort imports with isort (ruff I rules): stdlib, third-party, first-party (framework), local
- Combine as-imports
- Use type hints on all function signatures
- Use `from __future__ import annotations` for modern type syntax
- Raise exceptions with `from` in except blocks (B904)
- No unused imports (F401), no unused variables (F841)
- Prefer list/dict/set comprehensions over map/filter (C4)
Run `make lint` to auto-fix, `make check` to verify without modifying files.
Run `make format` to apply ruff formatting.
The ruff config lives in core/pyproject.toml under [tool.ruff].
+3
View File
@@ -11,6 +11,9 @@ indent_size = 2
insert_final_newline = true
trim_trailing_whitespace = true
[*.py]
indent_size = 4
[*.md]
trim_trailing_whitespace = false
+124
View File
@@ -0,0 +1,124 @@
# Normalize line endings for all text files
* text=auto
# Source code
*.py text diff=python
*.js text
*.ts text
*.jsx text
*.tsx text
*.json text
*.yaml text
*.yml text
*.toml text
*.ini text
*.cfg text
# Shell scripts (must use LF)
*.sh text eol=lf
quickstart.sh text eol=lf
# PowerShell scripts (stored with LF; modern PowerShell handles LF fine)
*.ps1 text eol=lf
*.psm1 text eol=lf
# Windows batch files (must use CRLF)
*.bat text eol=crlf
*.cmd text eol=crlf
# Documentation
*.md text
*.txt text
*.rst text
*.tex text
# Configuration files
.gitignore text
.gitattributes text
.editorconfig text
Dockerfile text
docker-compose.yml text
requirements*.txt text
pyproject.toml text
setup.py text
setup.cfg text
MANIFEST.in text
LICENSE text
README* text
CHANGELOG* text
CONTRIBUTING* text
CODE_OF_CONDUCT* text
# Web files
*.html text
*.css text
*.scss text
*.sass text
# Data files
*.xml text
*.csv text
*.sql text
# Graphics (binary)
*.png binary
*.jpg binary
*.jpeg binary
*.gif binary
*.ico binary
*.svg binary
*.eps binary
*.bmp binary
*.tif binary
*.tiff binary
# Archives (binary)
*.zip binary
*.tar binary
*.gz binary
*.bz2 binary
*.7z binary
*.rar binary
# Python compiled (binary)
*.pyc binary
*.pyo binary
*.pyd binary
*.whl binary
*.egg binary
# System libraries (binary)
*.so binary
*.dll binary
*.dylib binary
*.lib binary
*.a binary
# Documents (binary)
*.pdf binary
*.doc binary
*.docx binary
*.ppt binary
*.pptx binary
*.xls binary
*.xlsx binary
# Fonts (binary)
*.ttf binary
*.otf binary
*.woff binary
*.woff2 binary
*.eot binary
# Audio/Video (binary)
*.mp3 binary
*.mp4 binary
*.wav binary
*.avi binary
*.mov binary
*.flv binary
# Database files (binary)
*.db binary
*.sqlite binary
*.sqlite3 binary
-1
View File
@@ -8,7 +8,6 @@
/hive/ @adenhq/maintainers
# Infrastructure
/docker-compose*.yml @adenhq/maintainers
/.github/ @adenhq/maintainers
# Documentation
@@ -0,0 +1,31 @@
name: Auto-close duplicate issues
description: Auto-closes issues that are duplicates of existing issues
on:
schedule:
- cron: "0 */6 * * *"
workflow_dispatch:
jobs:
auto-close-duplicates:
runs-on: ubuntu-latest
timeout-minutes: 10
permissions:
contents: read
issues: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Setup Bun
uses: oven-sh/setup-bun@v2
with:
bun-version: latest
- name: Auto-close duplicate issues
run: bun run scripts/auto-close-duplicates.ts
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
GITHUB_REPOSITORY_NAME: ${{ github.event.repository.name }}
STATSIG_API_KEY: ${{ secrets.STATSIG_API_KEY }}
+31 -4
View File
@@ -29,10 +29,15 @@ jobs:
pip install -e .
pip install -r requirements-dev.txt
- name: Run ruff
- name: Ruff lint
run: |
cd core
ruff check .
ruff check core/
ruff check tools/
- name: Ruff format
run: |
ruff format --check core/
ruff format --check tools/
test:
name: Test Python Framework
@@ -79,9 +84,31 @@ jobs:
- name: Validate exported agents
run: |
# Check that agent exports have valid structure
for agent_dir in exports/*/; do
if [ ! -d "exports" ]; then
echo "No exports/ directory found, skipping validation"
exit 0
fi
shopt -s nullglob
agent_dirs=(exports/*/)
shopt -u nullglob
if [ ${#agent_dirs[@]} -eq 0 ]; then
echo "No agent directories in exports/, skipping validation"
exit 0
fi
validated=0
for agent_dir in "${agent_dirs[@]}"; do
if [ -f "$agent_dir/agent.json" ]; then
echo "Validating $agent_dir"
python -c "import json; json.load(open('$agent_dir/agent.json'))"
validated=$((validated + 1))
fi
done
if [ "$validated" -eq 0 ]; then
echo "No agent.json files found in exports/, skipping validation"
else
echo "Validated $validated agent(s)"
fi
+64 -35
View File
@@ -1,4 +1,4 @@
name: Claude Issue Triage
name: Issue Triage
on:
issues:
@@ -7,9 +7,11 @@ on:
jobs:
triage:
runs-on: ubuntu-latest
timeout-minutes: 10
permissions:
contents: read
issues: write
id-token: write
steps:
- name: Checkout repository
@@ -17,52 +19,79 @@ jobs:
with:
fetch-depth: 1
- name: Run Claude Issue Triage
id: claude-triage
- name: Triage and check for duplicates
uses: anthropics/claude-code-action@v1
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}
allowed_non_write_users: '*'
allowed_non_write_users: "*"
prompt: |
REPO: ${{ github.repository }}
ISSUE NUMBER: ${{ github.event.issue.number }}
TITLE: ${{ github.event.issue.title }}
BODY: ${{ github.event.issue.body }}
AUTHOR: ${{ github.event.issue.user.login }}
Analyze this new issue and perform triage tasks.
Analyze this new issue and perform the following:
Issue: #${{ github.event.issue.number }}
Repository: ${{ github.repository }}
1. **Categorize the issue type using ONLY these labels:**
- bug: Something isn't working
- enhancement: New feature or request
- question: Further information is requested
- documentation: Improvements or additions to documentation
- good first issue: Good for newcomers (if issue is well-defined and small scope)
- help wanted: Extra attention is needed (if issue needs community input)
- backlog: Tracked for the future, but not currently planned or prioritized
## Your Tasks:
2. **Check for duplicates:**
Search for similar existing issues using:
`gh issue list --state all --search "<key terms from title/body>"`
### 1. Get issue details
Use mcp__github__get_issue to get the full details of issue #${{ github.event.issue.number }}
If a duplicate exists:
- Add the "duplicate" label
- Comment mentioning the original issue number
### 2. Check for duplicates
Search for similar existing issues using mcp__github__search_issues with relevant keywords from the issue title and body.
3. **Check for invalid issues:**
If the issue lacks sufficient information, is spam, or doesn't make sense:
- Add the "invalid" label
- Comment asking for clarification or explaining why it's invalid
Criteria for duplicates:
- Same bug or error being reported
- Same feature request (even if worded differently)
- Same question being asked
- Issues describing the same root problem
4. **Apply labels:**
Based on your analysis, add appropriate labels using:
`gh issue edit ${{ github.event.issue.number }} --add-label "label1,label2"`
If you find a duplicate:
- Add a comment using EXACTLY this format (required for auto-close to work):
"Found a possible duplicate of #<issue_number>: <brief explanation of why it's a duplicate>"
- Do NOT apply the "duplicate" label yet (the auto-close script will add it after 12 hours if no objections)
- Suggest the user react with a thumbs-down if they disagree
You may apply multiple labels if appropriate (e.g., "bug,help wanted").
### 3. Check for Low-Quality / AI Spam
Analyze the issue quality. We are receiving many low-effort, AI-generated spam issues.
Flag the issue as INVALID if it matches these criteria:
- **Vague/Generic**: Title is "Fix bug" or "Error" without specific context.
- **Hallucinated**: Refers to files or features that do not exist in this repo.
- **Template Filler**: Body contains "Insert description here" or unrelated gibberish.
- **Low Effort**: No reproduction steps, no logs, only 1-2 sentences.
5. **Add a brief comment** summarizing your triage decision to help maintainers.
If identified as spam/low-quality:
- Add the "invalid" label.
- Add a comment:
"This issue has been automatically flagged as low-quality or potentially AI-generated spam. It lacks specific details (logs, reproduction steps, file references) required for us to help. Please open a new issue following the template exactly if this is a legitimate request."
- Do NOT proceed to other steps.
### 4. Check for invalid issues (General)
If the issue is not spam but still lacks information:
- Add the "invalid" label
- Comment asking for clarification
### 5. Categorize with labels (if NOT a duplicate or spam)
Apply appropriate labels based on the issue content. Use ONLY these labels:
- bug: Something isn't working
- enhancement: New feature or request
- question: Further information is requested
- documentation: Improvements or additions to documentation
- good first issue: Good for newcomers (if issue is well-defined and small scope)
- help wanted: Extra attention is needed (if issue needs community input)
- backlog: Tracked for the future, but not currently planned or prioritized
You may apply multiple labels if appropriate (e.g., "bug" and "help wanted").
## Tools Available:
- mcp__github__get_issue: Get issue details
- mcp__github__search_issues: Search for similar issues
- mcp__github__list_issues: List recent issues if needed
- mcp__github__add_issue_comment: Add a comment
- mcp__github__update_issue: Add labels
- mcp__github__get_issue_comments: Get existing comments
Be thorough but efficient. Focus on accurate categorization and finding true duplicates.
claude_args: |
--model claude-3-5-haiku-20241022
--allowedTools "Bash(gh issue:*),Bash(gh search:*)"
--model claude-haiku-4-5-20251001
--allowedTools "mcp__github__get_issue,mcp__github__search_issues,mcp__github__list_issues,mcp__github__add_issue_comment,mcp__github__update_issue,mcp__github__get_issue_comments"
+204
View File
@@ -0,0 +1,204 @@
name: PR Check Command
on:
issue_comment:
types: [created]
jobs:
check-pr:
# Only run on PR comments that start with /check
if: github.event.issue.pull_request && startsWith(github.event.comment.body, '/check')
runs-on: ubuntu-latest
permissions:
pull-requests: write
issues: write
checks: write
statuses: write
steps:
- name: Check PR requirements
uses: actions/github-script@v7
with:
script: |
const prNumber = context.payload.issue.number;
console.log(`Triggered by /check comment on PR #${prNumber}`);
// Fetch PR data
const { data: pr } = await github.rest.pulls.get({
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: prNumber,
});
const prBody = pr.body || '';
const prTitle = pr.title || '';
const prAuthor = pr.user.login;
const headSha = pr.head.sha;
// Create a check run in progress
const { data: checkRun } = await github.rest.checks.create({
owner: context.repo.owner,
repo: context.repo.repo,
name: 'check-requirements',
head_sha: headSha,
status: 'in_progress',
started_at: new Date().toISOString(),
});
// Extract issue numbers
const issuePattern = /(?:close[sd]?|fix(?:e[sd])?|resolve[sd]?)?\s*#(\d+)/gi;
const allText = `${prTitle} ${prBody}`;
const matches = [...allText.matchAll(issuePattern)];
const issueNumbers = [...new Set(matches.map(m => parseInt(m[1], 10)))];
console.log(`PR #${prNumber}:`);
console.log(` Author: ${prAuthor}`);
console.log(` Found issue references: ${issueNumbers.length > 0 ? issueNumbers.join(', ') : 'none'}`);
if (issueNumbers.length === 0) {
const message = `## PR Closed - Requirements Not Met
This PR has been automatically closed because it doesn't meet the requirements.
**Missing:** No linked issue found.
**To fix:**
1. Create or find an existing issue for this work
2. Assign yourself to the issue
3. Re-open this PR and add \`Fixes #123\` in the description
**Why is this required?** See #472 for details.`;
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
body: message,
});
await github.rest.pulls.update({
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: prNumber,
state: 'closed',
});
// Update check run to failure
await github.rest.checks.update({
owner: context.repo.owner,
repo: context.repo.repo,
check_run_id: checkRun.id,
status: 'completed',
conclusion: 'failure',
completed_at: new Date().toISOString(),
output: {
title: 'Missing linked issue',
summary: 'PR must reference an issue (e.g., `Fixes #123`)',
},
});
core.setFailed('PR must reference an issue');
return;
}
// Check if PR author is assigned to any linked issue
let issueWithAuthorAssigned = null;
let issuesWithoutAuthor = [];
for (const issueNum of issueNumbers) {
try {
const { data: issue } = await github.rest.issues.get({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: issueNum,
});
const assigneeLogins = (issue.assignees || []).map(a => a.login);
if (assigneeLogins.includes(prAuthor)) {
issueWithAuthorAssigned = issueNum;
console.log(` Issue #${issueNum} has PR author ${prAuthor} as assignee`);
break;
} else {
issuesWithoutAuthor.push({
number: issueNum,
assignees: assigneeLogins
});
console.log(` Issue #${issueNum} assignees: ${assigneeLogins.length > 0 ? assigneeLogins.join(', ') : 'none'}`);
}
} catch (error) {
console.log(` Issue #${issueNum} not found`);
}
}
if (!issueWithAuthorAssigned) {
const issueList = issuesWithoutAuthor.map(i =>
`#${i.number} (assignees: ${i.assignees.length > 0 ? i.assignees.join(', ') : 'none'})`
).join(', ');
const message = `## PR Closed - Requirements Not Met
This PR has been automatically closed because it doesn't meet the requirements.
**PR Author:** @${prAuthor}
**Found issues:** ${issueList}
**Problem:** The PR author must be assigned to the linked issue.
**To fix:**
1. Assign yourself (@${prAuthor}) to one of the linked issues
2. Re-open this PR
**Why is this required?** See #472 for details.`;
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
body: message,
});
await github.rest.pulls.update({
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: prNumber,
state: 'closed',
});
// Update check run to failure
await github.rest.checks.update({
owner: context.repo.owner,
repo: context.repo.repo,
check_run_id: checkRun.id,
status: 'completed',
conclusion: 'failure',
completed_at: new Date().toISOString(),
output: {
title: 'PR author not assigned to issue',
summary: `PR author @${prAuthor} must be assigned to one of the linked issues: ${issueList}`,
},
});
core.setFailed('PR author must be assigned to the linked issue');
} else {
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
body: `✅ PR requirements met! Issue #${issueWithAuthorAssigned} has @${prAuthor} as assignee.`,
});
// Update check run to success
await github.rest.checks.update({
owner: context.repo.owner,
repo: context.repo.repo,
check_run_id: checkRun.id,
status: 'completed',
conclusion: 'success',
completed_at: new Date().toISOString(),
output: {
title: 'Requirements met',
summary: `Issue #${issueWithAuthorAssigned} has @${prAuthor} as assignee.`,
},
});
console.log(`PR requirements met!`);
}
@@ -0,0 +1,138 @@
# Manually-triggered sweep: applies the PR-requirements policy (PR must link
# an issue, and the PR author must be assigned to that issue) to every
# currently-open PR. Companion to the event-driven "PR Requirements Check".
name: PR Requirements Backfill

on:
  # Run on demand only; there is no automatic trigger for the backfill.
  workflow_dispatch:

jobs:
  check-all-open-prs:
    runs-on: ubuntu-latest
    # Write access is needed to comment on PRs/issues and to close PRs.
    permissions:
      pull-requests: write
      issues: write
    steps:
      - name: Check all open PRs
        uses: actions/github-script@v7
        with:
          script: |
            // Fetch open PRs. NOTE(review): only the first page (up to 100)
            // is inspected — repos with more open PRs would need pagination.
            const { data: pullRequests } = await github.rest.pulls.list({
              owner: context.repo.owner,
              repo: context.repo.repo,
              state: 'open',
              per_page: 100,
            });
            console.log(`Found ${pullRequests.length} open PRs`);
            for (const pr of pullRequests) {
              const prNumber = pr.number;
              const prBody = pr.body || '';
              const prTitle = pr.title || '';
              const prAuthor = pr.user.login;
              console.log(`\nChecking PR #${prNumber}: ${prTitle}`);
              // Extract issue numbers from body and title. The closing
              // keyword (fixes/closes/resolves) is optional, so a bare
              // "#123" anywhere in the text also counts as a linked issue.
              const issuePattern = /(?:close[sd]?|fix(?:e[sd])?|resolve[sd]?)?\s*#(\d+)/gi;
              const allText = `${prTitle} ${prBody}`;
              const matches = [...allText.matchAll(issuePattern)];
              // De-duplicate referenced issue numbers.
              const issueNumbers = [...new Set(matches.map(m => parseInt(m[1], 10)))];
              console.log(` Found issue references: ${issueNumbers.length > 0 ? issueNumbers.join(', ') : 'none'}`);
              if (issueNumbers.length === 0) {
                // No linked issue at all: explain the policy in a comment,
                // then close the PR and move on to the next one.
                console.log(` ❌ No linked issue - closing PR`);
                const message = `## PR Closed - Requirements Not Met
            This PR has been automatically closed because it doesn't meet the requirements.
            **Missing:** No linked issue found.
            **To fix:**
            1. Create or find an existing issue for this work
            2. Assign yourself to the issue
            3. Re-open this PR and add \`Fixes #123\` in the description`;
                await github.rest.issues.createComment({
                  owner: context.repo.owner,
                  repo: context.repo.repo,
                  issue_number: prNumber,
                  body: message,
                });
                await github.rest.pulls.update({
                  owner: context.repo.owner,
                  repo: context.repo.repo,
                  pull_number: prNumber,
                  state: 'closed',
                });
                continue;
              }
              // Check if any linked issue has the PR author as assignee.
              let issueWithAuthorAssigned = null;
              let issuesWithoutAuthor = [];
              for (const issueNum of issueNumbers) {
                try {
                  const { data: issue } = await github.rest.issues.get({
                    owner: context.repo.owner,
                    repo: context.repo.repo,
                    issue_number: issueNum,
                  });
                  const assigneeLogins = (issue.assignees || []).map(a => a.login);
                  if (assigneeLogins.includes(prAuthor)) {
                    // First satisfying issue wins; no need to keep looking.
                    issueWithAuthorAssigned = issueNum;
                    break;
                  } else {
                    // Remember the miss so the closing comment can list it.
                    issuesWithoutAuthor.push({
                      number: issueNum,
                      assignees: assigneeLogins
                    });
                  }
                } catch (error) {
                  // The matched "#N" may not be an issue in this repo (e.g. a
                  // PR number or inaccessible reference) — skip it silently.
                  console.log(` Issue #${issueNum} not found or inaccessible`);
                }
              }
              if (!issueWithAuthorAssigned) {
                // Build a human-readable list of the issues that were linked
                // but lack the author as assignee, for the closing comment.
                const issueList = issuesWithoutAuthor.map(i =>
                  `#${i.number} (assignees: ${i.assignees.length > 0 ? i.assignees.join(', ') : 'none'})`
                ).join(', ');
                console.log(` ❌ PR author not assigned to any linked issue - closing PR`);
                const message = `## PR Closed - Requirements Not Met
            This PR has been automatically closed because it doesn't meet the requirements.
            **PR Author:** @${prAuthor}
            **Found issues:** ${issueList}
            **Problem:** The PR author must be assigned to the linked issue.
            **To fix:**
            1. Assign yourself (@${prAuthor}) to one of the linked issues
            2. Re-open this PR`;
                await github.rest.issues.createComment({
                  owner: context.repo.owner,
                  repo: context.repo.repo,
                  issue_number: prNumber,
                  body: message,
                });
                await github.rest.pulls.update({
                  owner: context.repo.owner,
                  repo: context.repo.repo,
                  pull_number: prNumber,
                  state: 'closed',
                });
              } else {
                console.log(` ✅ PR requirements met! Issue #${issueWithAuthorAssigned} has ${prAuthor} as assignee.`);
              }
            }
            console.log('\nBackfill complete!');
+189
View File
@@ -0,0 +1,189 @@
# Event-driven gate: every PR must reference an issue and its author must be
# assigned to that issue, unless the PR qualifies as a micro-fix or a
# documentation change. Non-compliant PRs are commented on and closed.
name: PR Requirements Check

on:
  # NOTE(review): pull_request_target grants a write token on PRs from forks.
  # This is safe here only because the job never checks out or executes PR
  # code — keep it that way if this workflow is extended.
  pull_request_target:
    types: [opened, reopened, edited, synchronize]

jobs:
  check-requirements:
    runs-on: ubuntu-latest
    # Write access is needed to comment on PRs/issues and to close PRs.
    permissions:
      pull-requests: write
      issues: write
    steps:
      - name: Check PR has linked issue with assignee
        uses: actions/github-script@v7
        with:
          script: |
            const pr = context.payload.pull_request;
            const prNumber = pr.number;
            const prBody = pr.body || '';
            const prTitle = pr.title || '';
            const prLabels = (pr.labels || []).map(l => l.name);
            // Allow micro-fix and documentation PRs without a linked issue.
            // Either the label or a keyword in the title triggers the bypass.
            const isMicroFix = prLabels.includes('micro-fix') || /micro-fix/i.test(prTitle);
            const isDocumentation = prLabels.includes('documentation') || /\bdocs?\b/i.test(prTitle);
            if (isMicroFix || isDocumentation) {
              const reason = isMicroFix ? 'micro-fix' : 'documentation';
              console.log(`PR #${prNumber} is a ${reason}, skipping issue requirement.`);
              return;
            }
            // Extract issue numbers from body and title.
            // Matches: fixes #123, closes #123, resolves #123, or plain #123
            // (the closing keyword is optional in the pattern).
            const issuePattern = /(?:close[sd]?|fix(?:e[sd])?|resolve[sd]?)?\s*#(\d+)/gi;
            const allText = `${prTitle} ${prBody}`;
            const matches = [...allText.matchAll(issuePattern)];
            // De-duplicate referenced issue numbers.
            const issueNumbers = [...new Set(matches.map(m => parseInt(m[1], 10)))];
            console.log(`PR #${prNumber}:`);
            console.log(` Found issue references: ${issueNumbers.length > 0 ? issueNumbers.join(', ') : 'none'}`);
            if (issueNumbers.length === 0) {
              const message = `## PR Closed - Requirements Not Met
            This PR has been automatically closed because it doesn't meet the requirements.
            **Missing:** No linked issue found.
            **To fix:**
            1. Create or find an existing issue for this work
            2. Assign yourself to the issue
            3. Re-open this PR and add \`Fixes #123\` in the description
            **Exception:** To bypass this requirement, you can:
            - Add the \`micro-fix\` label or include \`micro-fix\` in your PR title for trivial fixes
            - Add the \`documentation\` label or include \`doc\`/\`docs\` in your PR title for documentation changes
            **Micro-fix requirements** (must meet ALL):
            | Qualifies | Disqualifies |
            |-----------|--------------|
            | < 20 lines changed | Any functional bug fix |
            | Typos & Documentation & Linting | Refactoring for "clean code" |
            | No logic/API/DB changes | New features (even tiny ones) |
            **Why is this required?** See #472 for details.`;
              // Only post the explanation once — skip if a bot already
              // commented with the same closure notice on this PR.
              const comments = await github.rest.issues.listComments({
                owner: context.repo.owner,
                repo: context.repo.repo,
                issue_number: prNumber,
              });
              const botComment = comments.data.find(
                (c) => c.user.type === 'Bot' && c.body.includes('PR Closed - Requirements Not Met')
              );
              if (!botComment) {
                await github.rest.issues.createComment({
                  owner: context.repo.owner,
                  repo: context.repo.repo,
                  issue_number: prNumber,
                  body: message,
                });
              }
              await github.rest.pulls.update({
                owner: context.repo.owner,
                repo: context.repo.repo,
                pull_number: prNumber,
                state: 'closed',
              });
              // Mark the workflow run itself as failed.
              core.setFailed('PR must reference an issue');
              return;
            }
            // Check if any linked issue has the PR author as assignee.
            const prAuthor = pr.user.login;
            let issueWithAuthorAssigned = null;
            let issuesWithoutAuthor = [];
            for (const issueNum of issueNumbers) {
              try {
                const { data: issue } = await github.rest.issues.get({
                  owner: context.repo.owner,
                  repo: context.repo.repo,
                  issue_number: issueNum,
                });
                const assigneeLogins = (issue.assignees || []).map(a => a.login);
                if (assigneeLogins.includes(prAuthor)) {
                  // First satisfying issue wins; no need to keep looking.
                  issueWithAuthorAssigned = issueNum;
                  console.log(` Issue #${issueNum} has PR author ${prAuthor} as assignee`);
                  break;
                } else {
                  // Remember the miss so the closing comment can list it.
                  issuesWithoutAuthor.push({
                    number: issueNum,
                    assignees: assigneeLogins
                  });
                  console.log(` Issue #${issueNum} assignees: ${assigneeLogins.length > 0 ? assigneeLogins.join(', ') : 'none'} (PR author: ${prAuthor})`);
                }
              } catch (error) {
                // The matched "#N" may not be an issue in this repo (e.g. a
                // PR number or inaccessible reference) — skip it silently.
                console.log(` Issue #${issueNum} not found or inaccessible`);
              }
            }
            if (!issueWithAuthorAssigned) {
              // Build a human-readable list of the issues that were linked
              // but lack the author as assignee, for the closing comment.
              const issueList = issuesWithoutAuthor.map(i =>
                `#${i.number} (assignees: ${i.assignees.length > 0 ? i.assignees.join(', ') : 'none'})`
              ).join(', ');
              const message = `## PR Closed - Requirements Not Met
            This PR has been automatically closed because it doesn't meet the requirements.
            **PR Author:** @${prAuthor}
            **Found issues:** ${issueList}
            **Problem:** The PR author must be assigned to the linked issue.
            **To fix:**
            1. Assign yourself (@${prAuthor}) to one of the linked issues
            2. Re-open this PR
            **Exception:** To bypass this requirement, you can:
            - Add the \`micro-fix\` label or include \`micro-fix\` in your PR title for trivial fixes
            - Add the \`documentation\` label or include \`doc\`/\`docs\` in your PR title for documentation changes
            **Micro-fix requirements** (must meet ALL):
            | Qualifies | Disqualifies |
            |-----------|--------------|
            | < 20 lines changed | Any functional bug fix |
            | Typos & Documentation & Linting | Refactoring for "clean code" |
            | No logic/API/DB changes | New features (even tiny ones) |
            **Why is this required?** See #472 for details.`;
              // Only post the explanation once — skip if a bot already
              // commented with the same closure notice on this PR.
              const comments = await github.rest.issues.listComments({
                owner: context.repo.owner,
                repo: context.repo.repo,
                issue_number: prNumber,
              });
              const botComment = comments.data.find(
                (c) => c.user.type === 'Bot' && c.body.includes('PR Closed - Requirements Not Met')
              );
              if (!botComment) {
                await github.rest.issues.createComment({
                  owner: context.repo.owner,
                  repo: context.repo.repo,
                  issue_number: prNumber,
                  body: message,
                });
              }
              await github.rest.pulls.update({
                owner: context.repo.owner,
                repo: context.repo.repo,
                pull_number: prNumber,
                state: 'closed',
              });
              // Mark the workflow run itself as failed.
              core.setFailed('PR author must be assigned to the linked issue');
            } else {
              console.log(`PR requirements met! Issue #${issueWithAuthorAssigned} has ${prAuthor} as assignee.`);
            }
+1 -1
View File
@@ -69,4 +69,4 @@ exports/*
.agent-builder-sessions/*
.venv
.venv
+18
View File
@@ -0,0 +1,18 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.6
hooks:
- id: ruff
name: ruff lint (core)
args: [--fix]
files: ^core/
- id: ruff
name: ruff lint (tools)
args: [--fix]
files: ^tools/
- id: ruff-format
name: ruff format (core)
files: ^core/
- id: ruff-format
name: ruff format (tools)
files: ^tools/
+7
View File
@@ -0,0 +1,7 @@
{
"recommendations": [
"charliermarsh.ruff",
"editorconfig.editorconfig",
"ms-python.python"
]
}
+2 -1
View File
@@ -25,8 +25,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Removed
- N/A
### Fixed
- N/A
- tools: Fixed web_scrape tool attempting to parse non-HTML content (PDF, JSON) as HTML (#487)
### Security
- N/A
+54 -6
View File
@@ -6,6 +6,41 @@ Thank you for your interest in contributing to the Aden Agent Framework! This do
By participating in this project, you agree to abide by our [Code of Conduct](CODE_OF_CONDUCT.md).
## Contributor License Agreement
By submitting a Pull Request, you agree that your contributions will be licensed under the Aden Agent Framework license.
## Issue Assignment Policy
To prevent duplicate work and respect contributors' time, we require issue assignment before submitting PRs.
### How to Claim an Issue
1. **Find an Issue:** Browse existing issues or create a new one
2. **Claim It:** Leave a comment (e.g., *"I'd like to work on this!"*)
3. **Wait for Assignment:** A maintainer will assign you within 24 hours
4. **Submit Your PR:** Once assigned, you're ready to contribute
> **Note:** PRs for unassigned issues may be delayed or closed if someone else was already assigned.
### The 5-Day Momentum Rule
To keep the project moving, issues with **no activity for 5 days** (no PR or status update) will be unassigned. If you need more time, just drop a quick comment!
### Exceptions (No Assignment Needed)
You may submit PRs without prior assignment for:
- **Documentation:** Fixing typos or clarifying instructions — add the `documentation` label or include `doc`/`docs` in your PR title to bypass the linked issue requirement
- **Micro-fixes:** Add the `micro-fix` label or include `micro-fix` in your PR title to bypass the linked issue requirement. Micro-fixes must meet **all** qualification criteria:
| Qualifies | Disqualifies |
|-----------|--------------|
| < 20 lines changed | Any functional bug fix |
| Typos & Documentation & Linting | Refactoring for "clean code" |
| No logic/API/DB changes | New features (even tiny ones) |
If a high-quality PR is submitted for a "stale" assigned issue (no activity for 7+ days), we may proceed with the submitted code.
## Getting Started
1. Fork the repository
@@ -29,6 +64,12 @@ python -c "import framework; import aden_tools; print('✓ Setup complete')"
./quickstart.sh
```
> **Windows Users:**
> If you are on native Windows, it is recommended to use **WSL (Windows Subsystem for Linux)**.
> Alternatively, make sure to run PowerShell or Git Bash with Python 3.11+ installed, and disable "App Execution Aliases" in Windows settings.
> **Tip:** Installing Claude Code skills is optional for running existing agents, but required if you plan to **build new agents**.
## Commit Convention
We follow [Conventional Commits](https://www.conventionalcommits.org/):
@@ -59,11 +100,12 @@ docs(readme): update installation instructions
## Pull Request Process
1. Update documentation if needed
2. Add tests for new functionality
3. Ensure all tests pass
4. Update the CHANGELOG.md if applicable
5. Request review from maintainers
1. **Get assigned to the issue first** (see [Issue Assignment Policy](#issue-assignment-policy))
2. Update documentation if needed
3. Add tests for new functionality
4. Ensure all tests pass
5. Update the CHANGELOG.md if applicable
6. Request review from maintainers
### PR Title Format
@@ -92,6 +134,12 @@ feat(component): add new feature description
## Testing
> **Note:** When testing agents in `exports/`, always set PYTHONPATH:
>
> ```bash
> PYTHONPATH=core:exports python -m agent_name test
> ```
```bash
# Run all tests for the framework
cd core && python -m pytest
@@ -107,4 +155,4 @@ PYTHONPATH=core:exports python -m agent_name test
Feel free to open an issue for questions or join our [Discord community](https://discord.com/invite/MXE49hrKDk).
Thank you for contributing!
Thank you for contributing!
+24 -65
View File
@@ -99,10 +99,13 @@ Get API keys:
./quickstart.sh
```
This installs:
This installs agent-related Claude Code skills:
- `/building-agents` - Build new goal-driven agents
- `/testing-agent` - Test agents with evaluation framework
- `/building-agents-core` - Fundamental agent concepts
- `/building-agents-construction` - Step-by-step agent building
- `/building-agents-patterns` - Best practices and design patterns
- `/testing-agent` - Test and validate agents
- `/agent-workflow` - End-to-end guided workflow
### Verify Setup
@@ -132,15 +135,22 @@ hive/ # Repository root
│ └── CODEOWNERS # Auto-assign reviewers
├── .claude/ # Claude Code Skills
│ └── skills/
│ ├── building-agents/ # Skills for building agents
├── SKILL.md # Main skill definition
── building-agents-core/
├── building-agents-patterns/
── building-agents-construction/
│   └── skills/                          # Skills for building
│       ├── building-agents-core/
│       │   ├── SKILL.md                 # Main skill definition
│       │   └── examples
│       ├── building-agents-patterns/
│       │   ├── SKILL.md
│       │   └── examples
│       ├── building-agents-construction/
│       │   ├── SKILL.md
│       │   └── examples
│       ├── testing-agent/               # Skills for testing agents
│       │   ├── SKILL.md
│       │   └── examples
│       └── agent-workflow/              # Complete workflow
│           ├── SKILL.md
│           └── examples
├── core/ # CORE FRAMEWORK PACKAGE
│ ├── framework/ # Main package code
@@ -213,7 +223,7 @@ The fastest way to build agents is using the Claude Code skills:
./quickstart.sh
# Build a new agent
claude> /building-agents
claude> /building-agents-construction
# Test the agent
claude> /testing-agent
@@ -224,7 +234,7 @@ claude> /testing-agent
1. **Define Your Goal**
```
claude> /building-agents
claude> /building-agents-construction
Enter goal: "Build an agent that processes customer support tickets"
```
@@ -511,30 +521,7 @@ chore(deps): update React to 18.2.0
## Debugging
### Frontend Debugging
**React Developer Tools:**
1. Install the [React DevTools browser extension](https://react.dev/learn/react-developer-tools)
2. Open browser DevTools → React tab
3. Inspect component tree, props, state, and hooks
**VS Code Debugging:**
1. Add Chrome debug configuration to `.vscode/launch.json`:
```json
{
"type": "chrome",
"request": "launch",
"name": "Debug Frontend",
"url": "http://localhost:3000",
"webRoot": "${workspaceFolder}/honeycomb/src"
}
```
2. Start the dev server: `npm run dev -w honeycomb`
3. Press F5 in VS Code
### Backend Debugging
@@ -594,7 +581,7 @@ pip install -e .
```bash
# Option 1: Use Claude Code skill (recommended)
claude> /building-agents
claude> /building-agents-construction
# Option 2: Create manually
# Note: exports/ is initially empty (gitignored). Create your agent directory:
@@ -709,14 +696,6 @@ kill -9 <PID>
# Or change ports in config.yaml and regenerate
```
### Node Modules Issues
```bash
# Clean everything and reinstall
npm run clean
rm -rf node_modules package-lock.json
npm install
```
### Docker Issues
@@ -728,15 +707,7 @@ docker compose build --no-cache
docker compose up
```
### TypeScript Errors After Pull
```bash
# Rebuild TypeScript
npm run build
# Or restart TS server in VS Code
# Cmd/Ctrl + Shift + P → "TypeScript: Restart TS Server"
```
### Environment Variables Not Loading
@@ -746,24 +717,12 @@ npm run generate:env
# Verify files exist
cat .env
cat honeycomb/.env
cat hive/.env
# Restart dev servers after changing env
```
### Tests Failing
```bash
# Run with verbose output
npm run test -w honeycomb -- --reporter=verbose
# Run single test file
npm run test -w honeycomb -- src/components/Button.test.tsx
# Clear test cache
npm run test -w honeycomb -- --clearCache
```
---
+122 -13
View File
@@ -9,6 +9,10 @@ Complete setup guide for building and running goal-driven agents with the Aden A
./scripts/setup-python.sh
```
> **Note for Windows Users:**
> Running the setup script on native Windows shells (PowerShell / Git Bash) may sometimes fail due to Python App Execution Aliases.
> It is **strongly recommended to use WSL (Windows Subsystem for Linux)** for a smoother setup experience.
This will:
- Check Python version (requires 3.11+)
@@ -17,6 +21,26 @@ This will:
- Fix package compatibility issues (openai + litellm)
- Verify all installations
## Alpine Linux Setup
If you are using Alpine Linux (e.g., inside a Docker container), you must install system dependencies and use a virtual environment before running the setup script:
1. Install System Dependencies:
```bash
apk update
apk add bash git python3 py3-pip nodejs npm curl build-base python3-dev linux-headers libffi-dev
```
2. Set up Virtual Environment (Required for Python 3.12+):
```
python3 -m venv venv
source venv/bin/activate
pip install --upgrade pip setuptools wheel
```
3. Run the Quickstart Script:
```
./quickstart.sh
```
## Manual Setup (Alternative)
If you prefer to set up manually or the script fails:
@@ -50,6 +74,9 @@ python -c "import aden_tools; print('✓ aden_tools OK')"
python -c "import litellm; print('✓ litellm OK')"
```
> **Windows Tip:**
> On Windows, if the verification commands fail, ensure you are running them in **WSL** or after **disabling Python App Execution Aliases** in Windows Settings → Apps → App Execution Aliases.
## Requirements
### Python Version
@@ -63,6 +90,7 @@ python -c "import litellm; print('✓ litellm OK')"
- pip (latest version)
- 2GB+ RAM
- Internet connection (for LLM API calls)
- For Windows users: WSL 2 is recommended for full compatibility.
### API Keys (Optional)
@@ -114,44 +142,125 @@ PYTHONPATH=core:exports python -m outbound_sales_agent validate
PYTHONPATH=core:exports python -m personal_assistant_agent run --input '{...}'
```
## Building New Agents
## Building New Agents and Run Flow
Use Claude Code CLI with the agent building skills:
Build and run an agent using Claude Code CLI with the agent building skills:
### 1. Install Skills (One-time)
### 1. Install Claude Skills (One-time)
```bash
./quickstart.sh
```
This installs:
This installs agent-related Claude Code skills:
- `/building-agents` - Build new agents
- `/testing-agent` - Test agents
- `/building-agents-construction` - Step-by-step build guide
- `/building-agents-core` - Fundamental concepts
- `/building-agents-patterns` - Best practices
- `/testing-agent` - Test and validate agents
- `/agent-workflow` - Complete workflow
### 2. Build an Agent
```
claude> /building-agents
claude> /building-agents-construction
```
Follow the prompts to:
1. Define your agent's goal
2. Design the workflow nodes
3. Connect edges
4. Generate the agent package
3. Connect nodes with edges
4. Generate the agent package under `exports/`
### 3. Test Your Agent
This step creates the initial agent structure required for further development.
### 3. Define Agent Logic
```
claude> /building-agents-core
```
Follow the prompts to:
1. Understand the agent architecture and file structure
2. Define the agent's goal, success criteria, and constraints
3. Learn node types (LLM, tool-use, router, function)
4. Discover and validate available tools before use
This step establishes the core concepts and rules needed before building an agent.
### 4. Apply Agent Patterns
```
claude> /building-agents-patterns
```
Follow the prompts to:
1. Apply best-practice agent design patterns
2. Add pause/resume flows for multi-turn interactions
3. Improve robustness with routing, fallbacks, and retries
4. Avoid common anti-patterns during agent construction
This step helps optimize agent design before final testing.
### 5. Test Your Agent
```
claude> /testing-agent
```
Follow the prompts to:
Creates comprehensive test suites for your agent.
1. Generate test guidelines for constraints and success criteria
2. Write agent tests directly under `exports/{agent}/tests/`
3. Run goal-based evaluation tests
4. Debug failing tests and iterate on agent improvements
This step verifies that the agent meets its goals before production use.
### 6. Agent Development Workflow (End-to-End)
```
claude> /agent-workflow
```
Follow the guided flow to:
1. Understand core agent concepts (optional)
2. Build the agent structure step by step
3. Apply best-practice design patterns (optional)
4. Test and validate the agent against its goals
This workflow orchestrates all agent-building skills to take you from idea → production-ready agent.
## Troubleshooting
### "externally-managed-environment" error (PEP 668)
**Cause:** Python 3.12+ on macOS/Homebrew, WSL, or some Linux distros prevents system-wide pip installs.
**Solution:** Create and use a virtual environment:
```bash
# Create virtual environment
python3 -m venv .venv
# Activate it
source .venv/bin/activate # macOS/Linux
# .venv\Scripts\activate # Windows
# Then run setup
./scripts/setup-python.sh
```
Always activate the venv before running agents:
```bash
source .venv/bin/activate
PYTHONPATH=core:exports python -m your_agent_name demo
```
### "ModuleNotFoundError: No module named 'framework'"
**Solution:** Install the core package:
@@ -188,7 +297,7 @@ pip install --upgrade "openai>=1.0.0"
**Cause:** Not running from project root or missing PYTHONPATH
**Solution:** Ensure you're in `/home/timothy/oss/hive/` and use:
**Solution:** Ensure you're in the project root directory and use:
```bash
PYTHONPATH=core:exports python -m support_ticket_agent validate
@@ -256,7 +365,7 @@ This design allows agents in `exports/` to be:
### 2. Build Agent (Claude Code)
```
claude> /building-agents
claude> /building-agents-construction
Enter goal: "Build an agent that processes customer support tickets"
```
+26
View File
@@ -0,0 +1,26 @@
# Developer convenience targets for linting, formatting, and testing.
# `lint`/`format` modify files in place; `check` is read-only and CI-safe.
.PHONY: lint format check test install-hooks help

help: ## Show this help
	# Scan this Makefile for "target: ## description" lines and print them
	# as an aligned, color-highlighted menu.
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | \
		awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-15s\033[0m %s\n", $$1, $$2}'

lint: ## Run ruff linter (with auto-fix)
	# Each recipe line runs in its own shell, so both `cd`s start from the
	# Makefile's directory.
	cd core && ruff check --fix .
	cd tools && ruff check --fix .

format: ## Run ruff formatter
	cd core && ruff format .
	cd tools && ruff format .

check: ## Run all checks without modifying files (CI-safe)
	cd core && ruff check .
	cd tools && ruff check .
	cd core && ruff format --check .
	cd tools && ruff format --check .

test: ## Run all tests
	# NOTE(review): only core/tests is exercised here — confirm whether
	# tools/ has a test suite that should be included.
	cd core && python -m pytest tests/ -v

install-hooks: ## Install pre-commit hooks
	pip install pre-commit
	pre-commit install
+29 -16
View File
@@ -4,12 +4,12 @@
<p align="center">
<a href="README.md">English</a> |
<a href="README.zh-CN.md">简体中文</a> |
<a href="README.es.md">Español</a> |
<a href="README.pt.md">Português</a> |
<a href="README.ja.md">日本語</a> |
<a href="README.ru.md">Русский</a> |
<a href="README.ko.md">한국어</a>
<a href="docs/i18n/zh-CN.md">简体中文</a> |
<a href="docs/i18n/es.md">Español</a> |
<a href="docs/i18n/pt.md">Português</a> |
<a href="docs/i18n/ja.md">日本語</a> |
<a href="docs/i18n/ru.md">Русский</a> |
<a href="docs/i18n/ko.md">한국어</a>
</p>
[![Apache 2.0 License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/adenhq/hive/blob/main/LICENSE)
@@ -91,7 +91,7 @@ This installs:
./quickstart.sh
# Build an agent using Claude Code
claude> /building-agents
claude> /building-agents-construction
# Test your agent
claude> /testing-agent
@@ -102,6 +102,15 @@ PYTHONPATH=core:exports python -m your_agent_name run --input '{...}'
**[📖 Complete Setup Guide](ENVIRONMENT_SETUP.md)** - Detailed instructions for agent development
### Cursor IDE Support
Skills are also available in Cursor. To enable:
1. Open Command Palette (`Cmd+Shift+P` / `Ctrl+Shift+P`)
2. Run `MCP: Enable` to enable MCP servers
3. Restart Cursor to load the MCP servers from `.cursor/mcp.json`
4. Type `/` in Agent chat and search for skills (e.g., `/building-agents-construction`)
## Features
- **Goal-Driven Development** - Define objectives in natural language; the coding agent generates the agent graph and connection code to achieve them
@@ -226,6 +235,7 @@ hive/
├── docs/ # Documentation and guides
├── scripts/ # Build and utility scripts
├── .claude/ # Claude Code skills for building agents
├── .cursor/ # Cursor IDE skills (symlinks to .claude/skills)
├── ENVIRONMENT_SETUP.md # Python setup guide for agent development
├── DEVELOPER.md # Developer guide
├── CONTRIBUTING.md # Contribution guidelines
@@ -248,7 +258,7 @@ For building and running goal-driven agents with the framework:
# - All dependencies
# Build new agents using Claude Code skills
claude> /building-agents
claude> /building-agents-construction
# Test agents
claude> /testing-agent
@@ -264,7 +274,7 @@ See [ENVIRONMENT_SETUP.md](ENVIRONMENT_SETUP.md) for complete setup instructions
- **[Developer Guide](DEVELOPER.md)** - Comprehensive guide for developers
- [Getting Started](docs/getting-started.md) - Quick setup instructions
- [Configuration Guide](docs/configuration.md) - All configuration options
- [Architecture Overview](docs/architecture.md) - System design and structure
- [Architecture Overview](docs/architecture/README.md) - System design and structure
## Roadmap
@@ -300,11 +310,14 @@ We use [Discord](https://discord.com/invite/MXE49hrKDk) for support, feature req
We welcome contributions! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
1. Fork the repository
2. Create your feature branch (`git checkout -b feature/amazing-feature`)
3. Commit your changes (`git commit -m 'Add amazing feature'`)
4. Push to the branch (`git push origin feature/amazing-feature`)
5. Open a Pull Request
**Important:** Please get assigned to an issue before submitting a PR. Comment on an issue to claim it, and a maintainer will assign you within 24 hours. This helps prevent duplicate work.
1. Find or create an issue and get assigned
2. Fork the repository
3. Create your feature branch (`git checkout -b feature/amazing-feature`)
4. Commit your changes (`git commit -m 'Add amazing feature'`)
5. Push to the branch (`git push origin feature/amazing-feature`)
6. Open a Pull Request
## Join Our Team
@@ -328,7 +341,7 @@ No. Aden is built from the ground up with no dependencies on LangChain, CrewAI,
**Q: What LLM providers does Aden support?**
Aden supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, Mistral, Groq, and many more. Simply set the appropriate API key environment variable and specify the model name.
Aden supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, DeepSeek, Mistral, Groq, and many more. Simply set the appropriate API key environment variable and specify the model name.
**Q: Can I use Aden with local AI models like Ollama?**
@@ -348,7 +361,7 @@ Aden collects telemetry data for monitoring and observability purposes, includin
**Q: What deployment options does Aden support?**
Aden supports Docker Compose deployment out of the box, with both production and development configurations. Self-hosted deployments work on any infrastructure supporting Docker. Cloud deployment options and Kubernetes-ready configurations are on the roadmap.
Aden supports self-hosted deployments via Python packages. See the [Environment Setup Guide](ENVIRONMENT_SETUP.md) for installation instructions. Cloud deployment options and Kubernetes-ready configurations are on the roadmap.
**Q: Can Aden handle complex, production-scale use cases?**
+1 -1
View File
@@ -1,4 +1,4 @@
Product Roadmap
# Product Roadmap
Aden Agent Framework aims to help developers build outcome oriented, self-adaptive agents. Please find our roadmap here
+1 -1
View File
@@ -145,7 +145,7 @@ python -m framework test-debug <agent_path> <test_name>
python -m framework test-list <goal_id>
```
For detailed testing workflows, see the [testing-agent skill](.claude/skills/testing-agent/SKILL.md).
For detailed testing workflows, see the [testing-agent skill](../.claude/skills/testing-agent/SKILL.md).
### Analyzing Agent Behavior with Builder
+123
View File
@@ -0,0 +1,123 @@
"""
Minimal Manual Agent Example
----------------------------
This example demonstrates how to build and run an agent programmatically
without using the Claude Code CLI or external LLM APIs.
It uses 'function' nodes to define logic in pure Python, making it perfect
for understanding the core runtime loop:
Setup -> Graph definition -> Execution -> Result
Run with:
PYTHONPATH=core python core/examples/manual_agent.py
"""
import asyncio
from framework.graph import EdgeCondition, EdgeSpec, Goal, GraphSpec, NodeSpec
from framework.graph.executor import GraphExecutor
from framework.runtime.core import Runtime
# 1. Define Node Logic (Pure Python Functions)
def greet(name: str) -> str:
    """Build the canonical greeting string for *name*."""
    return "Hello, {}!".format(name)
def uppercase(greeting: str) -> str:
    """Return *greeting* with every character upper-cased."""
    shouted = greeting.upper()
    return shouted
async def main():
    """Assemble the two-node greeting agent and run it once end to end."""
    print("🚀 Setting up Manual Agent...")

    # Every agent declares a goal: what "done" looks like and how to tell.
    goal = Goal(
        id="greet-user",
        name="Greet User",
        description="Generate a friendly uppercase greeting",
        success_criteria=[
            {
                "id": "greeting_generated",
                "description": "Greeting produced",
                "metric": "custom",
                "target": "any",
            }
        ],
    )

    # Each node wraps one pure-Python function registered on the executor.
    greeter_node = NodeSpec(
        id="greeter",
        name="Greeter",
        description="Generates a simple greeting",
        node_type="function",
        function="greet",  # logical function name (implementation registered below)
        input_keys=["name"],
        output_keys=["greeting"],
    )
    upper_node = NodeSpec(
        id="uppercaser",
        name="Uppercaser",
        description="Converts greeting to uppercase",
        node_type="function",
        function="uppercase",
        input_keys=["greeting"],
        output_keys=["final_greeting"],
    )

    # The single edge runs the uppercaser only after the greeter succeeds.
    success_edge = EdgeSpec(
        id="greet-to-upper",
        source="greeter",
        target="uppercaser",
        condition=EdgeCondition.ON_SUCCESS,
    )

    # The graph is the blueprint tying goal, nodes, and edges together.
    graph = GraphSpec(
        id="greeting-agent",
        goal_id="greet-user",
        entry_node="greeter",
        terminal_nodes=["uppercaser"],
        nodes=[greeter_node, upper_node],
        edges=[success_edge],
    )

    # Runtime persists state/memory; the executor walks the graph.
    from pathlib import Path

    runtime = Runtime(storage_path=Path("./agent_logs"))
    executor = GraphExecutor(runtime=runtime)

    # Bind the string names referenced by the NodeSpecs to real callables.
    executor.register_function("greeter", greet)
    executor.register_function("uppercaser", uppercase)

    print("▶ Executing agent with input: name='Alice'...")
    result = await executor.execute(graph=graph, goal=goal, input_data={"name": "Alice"})

    # Report the outcome of the run.
    if result.success:
        print("\n✅ Success!")
        print(f"Path taken: {' -> '.join(result.path)}")
        print(f"Final output: {result.output.get('final_greeting')}")
    else:
        print(f"\n❌ Failed: {result.error}")
if __name__ == "__main__":
    # Optional: Enable logging to see internal decision flow
    # logging.basicConfig(level=logging.INFO)
    asyncio.run(main())  # drive the async entry point to completion
+9 -14
View File
@@ -37,9 +37,9 @@ async def example_1_programmatic_registration():
print(f"\nAvailable tools: {list(tools.keys())}")
# Run the agent with MCP tools available
result = await runner.run({
"objective": "Search for 'Claude AI' and summarize the top 3 results"
})
result = await runner.run(
{"objective": "Search for 'Claude AI' and summarize the top 3 results"}
)
print(f"\nAgent result: {result}")
@@ -78,10 +78,8 @@ async def example_3_config_file():
# Copy example config (in practice, you'd place this in your agent folder)
import shutil
shutil.copy(
"examples/mcp_servers.json",
test_agent_path / "mcp_servers.json"
)
shutil.copy("examples/mcp_servers.json", test_agent_path / "mcp_servers.json")
# Load agent - MCP servers will be auto-discovered
runner = AgentRunner.load(test_agent_path)
@@ -110,18 +108,14 @@ async def example_4_custom_agent_with_mcp_tools():
builder.set_goal(
goal_id="web-researcher",
name="Web Research Agent",
description="Search the web and summarize findings"
description="Search the web and summarize findings",
)
# Add success criteria
builder.add_success_criterion(
"search-results",
"Successfully retrieve at least 3 web search results"
)
builder.add_success_criterion(
"summary",
"Provide a clear, concise summary of the findings"
"search-results", "Successfully retrieve at least 3 web search results"
)
builder.add_success_criterion("summary", "Provide a clear, concise summary of the findings")
# Add nodes that will use MCP tools
builder.add_node(
@@ -192,6 +186,7 @@ async def main():
except Exception as e:
print(f"\nError running example: {e}")
import traceback
traceback.print_exc()
+9 -9
View File
@@ -22,22 +22,22 @@ The framework includes a Goal-Based Testing system (Goal → Agent → Eval):
See `framework.testing` for details.
"""
from framework.schemas.decision import Decision, Option, Outcome, DecisionEvaluation
from framework.schemas.run import Run, RunSummary, Problem
from framework.runtime.core import Runtime
from framework.builder.query import BuilderQuery
from framework.llm import LLMProvider, AnthropicProvider
from framework.runner import AgentRunner, AgentOrchestrator
from framework.llm import AnthropicProvider, LLMProvider
from framework.runner import AgentOrchestrator, AgentRunner
from framework.runtime.core import Runtime
from framework.schemas.decision import Decision, DecisionEvaluation, Option, Outcome
from framework.schemas.run import Problem, Run, RunSummary
# Testing framework
from framework.testing import (
ApprovalStatus,
DebugTool,
ErrorCategory,
Test,
TestResult,
TestSuiteResult,
TestStorage,
ApprovalStatus,
ErrorCategory,
DebugTool,
TestSuiteResult,
)
__all__ = [
+3 -3
View File
@@ -2,12 +2,12 @@
from framework.builder.query import BuilderQuery
from framework.builder.workflow import (
GraphBuilder,
BuildSession,
BuildPhase,
ValidationResult,
BuildSession,
GraphBuilder,
TestCase,
TestResult,
ValidationResult,
)
__all__ = [
+46 -42
View File
@@ -8,12 +8,12 @@ This is designed around the questions I need to answer:
4. What should we change? (suggestions)
"""
from typing import Any
from collections import defaultdict
from pathlib import Path
from typing import Any
from framework.schemas.decision import Decision
from framework.schemas.run import Run, RunSummary, RunStatus
from framework.schemas.run import Run, RunStatus, RunSummary
from framework.storage.backend import FileStorage
@@ -196,10 +196,7 @@ class BuilderQuery:
break
# Extract problems
problems = [
f"[{p.severity}] {p.description}"
for p in run.problems
]
problems = [f"[{p.severity}] {p.description}" for p in run.problems]
# Generate suggestions based on the failure
suggestions = self._generate_suggestions(run, failed_decisions)
@@ -253,11 +250,7 @@ class BuilderQuery:
error = decision.outcome.error or "Unknown error"
failure_counts[error] += 1
common_failures = sorted(
failure_counts.items(),
key=lambda x: x[1],
reverse=True
)[:5]
common_failures = sorted(failure_counts.items(), key=lambda x: x[1], reverse=True)[:5]
# Find problematic nodes
node_stats: dict[str, dict[str, int]] = defaultdict(lambda: {"total": 0, "failed": 0})
@@ -328,34 +321,45 @@ class BuilderQuery:
# Suggestion: Fix problematic nodes
for node_id, failure_rate in patterns.problematic_nodes:
suggestions.append({
"type": "node_improvement",
"target": node_id,
"reason": f"Node has {failure_rate:.1%} failure rate",
"recommendation": f"Review and improve node '{node_id}' - high failure rate suggests prompt or tool issues",
"priority": "high" if failure_rate > 0.3 else "medium",
})
suggestions.append(
{
"type": "node_improvement",
"target": node_id,
"reason": f"Node has {failure_rate:.1%} failure rate",
"recommendation": (
f"Review and improve node '{node_id}' - "
"high failure rate suggests prompt or tool issues"
),
"priority": "high" if failure_rate > 0.3 else "medium",
}
)
# Suggestion: Address common failures
for failure, count in patterns.common_failures:
if count >= 2:
suggestions.append({
"type": "error_handling",
"target": failure,
"reason": f"Error occurred {count} times",
"recommendation": f"Add handling for: {failure}",
"priority": "high" if count >= 5 else "medium",
})
suggestions.append(
{
"type": "error_handling",
"target": failure,
"reason": f"Error occurred {count} times",
"recommendation": f"Add handling for: {failure}",
"priority": "high" if count >= 5 else "medium",
}
)
# Suggestion: Overall success rate
if patterns.success_rate < 0.8:
suggestions.append({
"type": "architecture",
"target": goal_id,
"reason": f"Goal success rate is only {patterns.success_rate:.1%}",
"recommendation": "Consider restructuring the agent graph or improving goal definition",
"priority": "high",
})
suggestions.append(
{
"type": "architecture",
"target": goal_id,
"reason": f"Goal success rate is only {patterns.success_rate:.1%}",
"recommendation": (
"Consider restructuring the agent graph or improving goal definition"
),
"priority": "high",
}
)
return suggestions
@@ -408,21 +412,22 @@ class BuilderQuery:
alternatives = [o for o in decision.options if o.id != decision.chosen_option_id]
if alternatives:
alt_desc = alternatives[0].description
chosen_desc = chosen.description if chosen else "unknown"
suggestions.append(
f"Consider alternative: '{alt_desc}' instead of '{chosen.description if chosen else 'unknown'}'"
f"Consider alternative: '{alt_desc}' instead of '{chosen_desc}'"
)
# Check for missing context
if not decision.input_context:
suggestions.append(
f"Decision '{decision.intent}' had no input context - ensure relevant data is passed"
f"Decision '{decision.intent}' had no input context - "
"ensure relevant data is passed"
)
# Check for constraint issues
if decision.active_constraints:
suggestions.append(
f"Review constraints: {', '.join(decision.active_constraints)} - may be too restrictive"
)
constraints = ", ".join(decision.active_constraints)
suggestions.append(f"Review constraints: {constraints} - may be too restrictive")
# Check for reported problems with suggestions
for problem in run.problems:
@@ -471,15 +476,14 @@ class BuilderQuery:
# Decision count difference
if len(run1.decisions) != len(run2.decisions):
differences.append(
f"Decision count: {len(run1.decisions)} vs {len(run2.decisions)}"
)
differences.append(f"Decision count: {len(run1.decisions)} vs {len(run2.decisions)}")
# Find first divergence point
for i, (d1, d2) in enumerate(zip(run1.decisions, run2.decisions)):
for i, (d1, d2) in enumerate(zip(run1.decisions, run2.decisions, strict=False)):
if d1.chosen_option_id != d2.chosen_option_id:
differences.append(
f"Diverged at decision {i}: chose '{d1.chosen_option_id}' vs '{d2.chosen_option_id}'"
f"Diverged at decision {i}: "
f"chose '{d1.chosen_option_id}' vs '{d2.chosen_option_id}'"
)
break
+92 -69
View File
@@ -13,32 +13,35 @@ Each step requires validation and human approval before proceeding.
You cannot skip steps or bypass validation.
"""
from collections.abc import Callable
from datetime import datetime
from enum import Enum
from pathlib import Path
from datetime import datetime
from typing import Any, Callable
from typing import Any
from pydantic import BaseModel, Field
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.goal import Goal
from framework.graph.node import NodeSpec
from framework.graph.edge import EdgeSpec, EdgeCondition, GraphSpec
class BuildPhase(str, Enum):
"""Current phase of the build process."""
INIT = "init" # Just started
GOAL_DRAFT = "goal_draft" # Drafting goal
INIT = "init" # Just started
GOAL_DRAFT = "goal_draft" # Drafting goal
GOAL_APPROVED = "goal_approved" # Goal approved
ADDING_NODES = "adding_nodes" # Adding nodes
ADDING_EDGES = "adding_edges" # Adding edges
TESTING = "testing" # Running tests
APPROVED = "approved" # Fully approved
EXPORTED = "exported" # Exported to file
ADDING_NODES = "adding_nodes" # Adding nodes
ADDING_EDGES = "adding_edges" # Adding edges
TESTING = "testing" # Running tests
APPROVED = "approved" # Fully approved
EXPORTED = "exported" # Exported to file
class ValidationResult(BaseModel):
"""Result of a validation check."""
valid: bool
errors: list[str] = Field(default_factory=list)
warnings: list[str] = Field(default_factory=list)
@@ -47,6 +50,7 @@ class ValidationResult(BaseModel):
class TestCase(BaseModel):
"""A test case for validating agent behavior."""
id: str
description: str
input: dict[str, Any]
@@ -56,6 +60,7 @@ class TestCase(BaseModel):
class TestResult(BaseModel):
"""Result of running a test case."""
test_id: str
passed: bool
actual_output: Any = None
@@ -69,6 +74,7 @@ class BuildSession(BaseModel):
Saved after each approved step so you can resume later.
"""
id: str
name: str
phase: BuildPhase = BuildPhase.INIT
@@ -457,11 +463,14 @@ class GraphBuilder:
# Run the test
import asyncio
result = asyncio.run(executor.execute(
graph=graph,
goal=self.session.goal,
input_data=test.input,
))
result = asyncio.run(
executor.execute(
graph=graph,
goal=self.session.goal,
input_data=test.input,
)
)
# Check result
passed = result.success
@@ -515,12 +524,14 @@ class GraphBuilder:
if not self._pending_validation.valid:
return False
self.session.approvals.append({
"phase": self.session.phase.value,
"comment": comment,
"timestamp": datetime.now().isoformat(),
"validation": self._pending_validation.model_dump(),
})
self.session.approvals.append(
{
"phase": self.session.phase.value,
"comment": comment,
"timestamp": datetime.now().isoformat(),
"validation": self._pending_validation.model_dump(),
}
)
# Advance phase if appropriate
if self.session.phase == BuildPhase.GOAL_DRAFT:
@@ -554,11 +565,13 @@ class GraphBuilder:
return False
self.session.phase = BuildPhase.APPROVED
self.session.approvals.append({
"phase": "final",
"comment": comment,
"timestamp": datetime.now().isoformat(),
})
self.session.approvals.append(
{
"phase": "final",
"comment": comment,
"timestamp": datetime.now().isoformat(),
}
)
self._save_session()
return True
@@ -630,69 +643,75 @@ class GraphBuilder:
"""Generate Python code for the graph."""
lines = [
'"""',
f'Generated agent: {self.session.name}',
f'Generated at: {datetime.now().isoformat()}',
f"Generated agent: {self.session.name}",
f"Generated at: {datetime.now().isoformat()}",
'"""',
'',
'from framework.graph import (',
' Goal, SuccessCriterion, Constraint,',
' NodeSpec, EdgeSpec, EdgeCondition,',
')',
'from framework.graph.edge import GraphSpec',
'from framework.graph.goal import GoalStatus',
'',
'',
'# Goal',
"",
"from framework.graph import (",
" Goal, SuccessCriterion, Constraint,",
" NodeSpec, EdgeSpec, EdgeCondition,",
")",
"from framework.graph.edge import GraphSpec",
"from framework.graph.goal import GoalStatus",
"",
"",
"# Goal",
]
if self.session.goal:
goal_json = self.session.goal.model_dump_json(indent=4)
lines.append('GOAL = Goal.model_validate_json(\'\'\'')
lines.append("GOAL = Goal.model_validate_json('''")
lines.append(goal_json)
lines.append("''')")
else:
lines.append('GOAL = None')
lines.append("GOAL = None")
lines.extend([
'',
'',
'# Nodes',
'NODES = [',
])
lines.extend(
[
"",
"",
"# Nodes",
"NODES = [",
]
)
for node in self.session.nodes:
node_json = node.model_dump_json(indent=4)
lines.append(' NodeSpec.model_validate_json(\'\'\'')
lines.append(" NodeSpec.model_validate_json('''")
lines.append(node_json)
lines.append(" '''),")
lines.extend([
']',
'',
'',
'# Edges',
'EDGES = [',
])
lines.extend(
[
"]",
"",
"",
"# Edges",
"EDGES = [",
]
)
for edge in self.session.edges:
edge_json = edge.model_dump_json(indent=4)
lines.append(' EdgeSpec.model_validate_json(\'\'\'')
lines.append(" EdgeSpec.model_validate_json('''")
lines.append(edge_json)
lines.append(" '''),")
lines.extend([
']',
'',
'',
'# Graph',
])
lines.extend(
[
"]",
"",
"",
"# Graph",
]
)
graph_json = graph.model_dump_json(indent=4)
lines.append('GRAPH = GraphSpec.model_validate_json(\'\'\'')
lines.append("GRAPH = GraphSpec.model_validate_json('''")
lines.append(graph_json)
lines.append("''')")
return '\n'.join(lines)
return "\n".join(lines)
# =========================================================================
# SESSION MANAGEMENT
@@ -743,7 +762,9 @@ class GraphBuilder:
"tests": len(self.session.test_cases),
"tests_passed": sum(1 for t in self.session.test_results if t.passed),
"approvals": len(self.session.approvals),
"pending_validation": self._pending_validation.model_dump() if self._pending_validation else None,
"pending_validation": self._pending_validation.model_dump()
if self._pending_validation
else None,
}
def show(self) -> str:
@@ -755,11 +776,13 @@ class GraphBuilder:
]
if self.session.goal:
lines.extend([
f"Goal: {self.session.goal.name}",
f" {self.session.goal.description}",
"",
])
lines.extend(
[
f"Goal: {self.session.goal.name}",
f" {self.session.goal.description}",
"",
]
)
if self.session.nodes:
lines.append("Nodes:")
+3 -3
View File
@@ -21,9 +21,7 @@ import sys
def main():
parser = argparse.ArgumentParser(
description="Goal Agent - Build and run goal-driven agents"
)
parser = argparse.ArgumentParser(description="Goal Agent - Build and run goal-driven agents")
parser.add_argument(
"--model",
default="claude-haiku-4-5-20251001",
@@ -34,10 +32,12 @@ def main():
# Register runner commands (run, info, validate, list, dispatch, shell)
from framework.runner.cli import register_commands
register_commands(subparsers)
# Register testing commands (test-run, test-debug, test-list, test-stats)
from framework.testing.cli import register_testing_commands
register_testing_commands(subparsers)
args = parser.parse_args()
+122
View File
@@ -0,0 +1,122 @@
"""
Credential Store - Production-ready credential management for Hive.
This module provides secure credential storage with:
- Key-vault structure: Credentials as objects with multiple keys
- Template-based usage: {{cred.key}} patterns for injection
- Bipartisan model: Store stores values, tools define usage
- Provider system: Extensible lifecycle management (refresh, validate)
- Multiple backends: Encrypted files, env vars, HashiCorp Vault
Quick Start:
from core.framework.credentials import CredentialStore, CredentialObject
# Create store with encrypted storage
store = CredentialStore.with_encrypted_storage("/var/hive/credentials")
# Get a credential
api_key = store.get("brave_search")
# Resolve templates in headers
headers = store.resolve_headers({
"Authorization": "Bearer {{github_oauth.access_token}}"
})
# Save a new credential
store.save_credential(CredentialObject(
id="my_api",
keys={"api_key": CredentialKey(name="api_key", value=SecretStr("xxx"))}
))
For OAuth2 support:
from core.framework.credentials.oauth2 import BaseOAuth2Provider, OAuth2Config
For Aden server sync:
from core.framework.credentials.aden import (
AdenCredentialClient,
AdenClientConfig,
AdenSyncProvider,
)
For Vault integration:
from core.framework.credentials.vault import HashiCorpVaultStorage
"""
from .models import (
CredentialDecryptionError,
CredentialError,
CredentialKey,
CredentialKeyNotFoundError,
CredentialNotFoundError,
CredentialObject,
CredentialRefreshError,
CredentialType,
CredentialUsageSpec,
CredentialValidationError,
)
from .provider import (
BearerTokenProvider,
CredentialProvider,
StaticProvider,
)
from .storage import (
CompositeStorage,
CredentialStorage,
EncryptedFileStorage,
EnvVarStorage,
InMemoryStorage,
)
from .store import CredentialStore
from .template import TemplateResolver
# Aden sync components (lazy import to avoid httpx dependency when not needed)
# Usage: from core.framework.credentials.aden import AdenSyncProvider
#    Or: from core.framework.credentials import AdenSyncProvider
try:
    from .aden import (
        AdenCachedStorage,
        AdenClientConfig,
        AdenCredentialClient,
        AdenSyncProvider,
    )

    _ADEN_AVAILABLE = True
except ImportError:
    # httpx (or another aden dependency) is not installed; the rest of the
    # credential package keeps working without the sync components.
    _ADEN_AVAILABLE = False

__all__ = [
    # Main store
    "CredentialStore",
    # Models
    "CredentialObject",
    "CredentialKey",
    "CredentialType",
    "CredentialUsageSpec",
    # Providers
    "CredentialProvider",
    "StaticProvider",
    "BearerTokenProvider",
    # Storage backends
    "CredentialStorage",
    "EncryptedFileStorage",
    "EnvVarStorage",
    "InMemoryStorage",
    "CompositeStorage",
    # Template resolution
    "TemplateResolver",
    # Exceptions
    "CredentialError",
    "CredentialNotFoundError",
    "CredentialKeyNotFoundError",
    "CredentialRefreshError",
    "CredentialValidationError",
    "CredentialDecryptionError",
]

# Only advertise the Aden names when their import actually succeeded;
# listing them unconditionally makes `from ... import *` raise
# AttributeError whenever httpx is absent.
if _ADEN_AVAILABLE:
    __all__ += [
        "AdenSyncProvider",
        "AdenCredentialClient",
        "AdenClientConfig",
        "AdenCachedStorage",
    ]

# Track Aden availability for runtime checks
ADEN_AVAILABLE = _ADEN_AVAILABLE
@@ -0,0 +1,76 @@
"""
Aden Credential Sync.
Components for synchronizing credentials with the Aden authentication server.
The Aden server handles OAuth2 authorization flows and maintains refresh tokens.
These components fetch and cache access tokens locally while delegating
lifecycle management to Aden.
Components:
- AdenCredentialClient: HTTP client for Aden API
- AdenSyncProvider: CredentialProvider that syncs with Aden
- AdenCachedStorage: Storage with local cache + Aden fallback
Quick Start:
from core.framework.credentials import CredentialStore
from core.framework.credentials.storage import EncryptedFileStorage
from core.framework.credentials.aden import (
AdenCredentialClient,
AdenClientConfig,
AdenSyncProvider,
)
# Configure (API key loaded from ADEN_API_KEY env var)
client = AdenCredentialClient(AdenClientConfig(
base_url=os.environ["ADEN_API_URL"],
))
provider = AdenSyncProvider(client=client)
store = CredentialStore(
storage=EncryptedFileStorage(),
providers=[provider],
auto_refresh=True,
)
# Initial sync
provider.sync_all(store)
# Use normally
token = store.get_key("hubspot", "access_token")
See docs/aden-credential-sync.md for detailed documentation.
"""
from .client import (
AdenAuthenticationError,
AdenClientConfig,
AdenClientError,
AdenCredentialClient,
AdenCredentialResponse,
AdenIntegrationInfo,
AdenNotFoundError,
AdenRateLimitError,
AdenRefreshError,
)
from .provider import AdenSyncProvider
from .storage import AdenCachedStorage
# Names re-exported as the public API of the aden credential subpackage.
__all__ = [
    # Client
    "AdenCredentialClient",
    "AdenClientConfig",
    "AdenCredentialResponse",
    "AdenIntegrationInfo",
    # Client errors
    "AdenClientError",
    "AdenAuthenticationError",
    "AdenNotFoundError",
    "AdenRateLimitError",
    "AdenRefreshError",
    # Provider
    "AdenSyncProvider",
    # Storage
    "AdenCachedStorage",
]
+467
View File
@@ -0,0 +1,467 @@
"""
Aden Credential Client.
HTTP client for communicating with the Aden authentication server.
The Aden server handles OAuth2 authorization flows and token management.
This client fetches tokens and delegates refresh operations to Aden.
Usage:
# API key loaded from ADEN_API_KEY environment variable by default
client = AdenCredentialClient(AdenClientConfig(
base_url="https://hive.adenhq.com",
))
# Or explicitly provide the API key
client = AdenCredentialClient(AdenClientConfig(
base_url="https://hive.adenhq.com",
api_key="your-api-key",
))
# Fetch a credential
response = client.get_credential("hubspot")
if response:
print(f"Token expires at: {response.expires_at}")
# Request a refresh
refreshed = client.request_refresh("hubspot")
"""
from __future__ import annotations
import logging
import os
import time
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
import httpx
logger = logging.getLogger(__name__)
class AdenClientError(Exception):
    """Base exception for all errors raised by the Aden client."""
class AdenAuthenticationError(AdenClientError):
    """Raised when the agent API key is invalid or has been revoked."""
class AdenNotFoundError(AdenClientError):
    """Raised when the requested integration does not exist on the server."""
class AdenRefreshError(AdenClientError):
    """Raised when a token refresh fails on the Aden server."""

    def __init__(
        self,
        message: str,
        requires_reauthorization: bool = False,
        reauthorization_url: str | None = None,
    ):
        # The message lives on the base class; the re-auth hints let callers
        # send the user back through the OAuth authorization flow.
        super().__init__(message)
        self.reauthorization_url = reauthorization_url
        self.requires_reauthorization = requires_reauthorization
class AdenRateLimitError(AdenClientError):
    """Raised when the Aden server rate-limits this client."""

    def __init__(self, message: str, retry_after: int = 60):
        super().__init__(message)
        # Seconds the caller should wait before retrying.
        self.retry_after = retry_after
@dataclass
class AdenClientConfig:
    """Configuration for Aden API client."""

    base_url: str
    """Base URL of the Aden server (e.g., 'https://hive.adenhq.com')."""

    api_key: str | None = None
    """Agent's API key for authenticating with Aden.
    If not provided, loaded from ADEN_API_KEY environment variable."""

    tenant_id: str | None = None
    """Optional tenant ID for multi-tenant deployments."""

    timeout: float = 30.0
    """Request timeout in seconds."""

    retry_attempts: int = 3
    """Number of retry attempts for transient failures."""

    retry_delay: float = 1.0
    """Base delay between retries in seconds (exponential backoff)."""

    def __post_init__(self) -> None:
        """Fall back to the ADEN_API_KEY environment variable for the key."""
        if self.api_key is None:
            self.api_key = os.environ.get("ADEN_API_KEY")
        if self.api_key:
            return
        raise ValueError(
            "Aden API key not provided. Either pass api_key to AdenClientConfig "
            "or set the ADEN_API_KEY environment variable."
        )
@dataclass
class AdenCredentialResponse:
    """Response from Aden server containing credential data."""

    integration_id: str
    """Unique identifier for the integration (e.g., 'hubspot')."""

    integration_type: str
    """Type of integration (e.g., 'hubspot', 'github', 'slack')."""

    access_token: str
    """The access token for API calls."""

    token_type: str = "Bearer"
    """Token type (usually 'Bearer')."""

    expires_at: datetime | None = None
    """When the access token expires (UTC)."""

    scopes: list[str] = field(default_factory=list)
    """OAuth2 scopes granted to this token."""

    metadata: dict[str, Any] = field(default_factory=dict)
    """Additional integration-specific metadata."""

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> AdenCredentialResponse:
        """Create from API response dictionary."""
        raw_expiry = data.get("expires_at")
        # The server emits ISO-8601 with a literal 'Z' suffix; normalise it to
        # an explicit offset so fromisoformat accepts it on older Pythons too.
        parsed_expiry = (
            datetime.fromisoformat(raw_expiry.replace("Z", "+00:00")) if raw_expiry else None
        )
        return cls(
            integration_id=data["integration_id"],
            integration_type=data["integration_type"],
            access_token=data["access_token"],
            token_type=data.get("token_type", "Bearer"),
            expires_at=parsed_expiry,
            scopes=data.get("scopes", []),
            metadata=data.get("metadata", {}),
        )
@dataclass
class AdenIntegrationInfo:
    """Information about an available integration."""

    integration_id: str
    integration_type: str
    status: str  # "active", "requires_reauth", "expired"
    expires_at: datetime | None = None

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> AdenIntegrationInfo:
        """Create from API response dictionary."""
        raw_expiry = data.get("expires_at")
        parsed_expiry = (
            datetime.fromisoformat(raw_expiry.replace("Z", "+00:00")) if raw_expiry else None
        )
        return cls(
            integration_id=data["integration_id"],
            # Payloads without a 'provider' field fall back to the id itself.
            integration_type=data.get("provider", data["integration_id"]),
            status=data.get("status", "unknown"),
            expires_at=parsed_expiry,
        )
class AdenCredentialClient:
    """
    HTTP client for Aden credential server.

    Handles communication with the Aden authentication server,
    including fetching credentials, requesting refreshes, and
    reporting usage statistics.

    The client automatically handles:
    - Retries with exponential backoff for transient failures
    - Proper error classification (auth, not found, rate limit, etc.)
    - Request headers for authentication and tenant isolation

    Usage:
        # API key loaded from ADEN_API_KEY environment variable
        config = AdenClientConfig(
            base_url="https://hive.adenhq.com",
        )
        client = AdenCredentialClient(config)

        # Fetch a credential
        cred = client.get_credential("hubspot")
        if cred:
            headers = {"Authorization": f"Bearer {cred.access_token}"}

        # List all integrations
        integrations = client.list_integrations()
        for info in integrations:
            print(f"{info.integration_id}: {info.status}")

        # Clean up
        client.close()
    """

    def __init__(self, config: AdenClientConfig):
        """
        Initialize the Aden client.

        Args:
            config: Client configuration including base URL and API key.
        """
        self.config = config
        # Created lazily on first request so construction never does I/O.
        self._client: httpx.Client | None = None

    def _get_client(self) -> httpx.Client:
        """Get or create the HTTP client (lazy, reused across requests)."""
        if self._client is None:
            headers = {
                "Authorization": f"Bearer {self.config.api_key}",
                "Content-Type": "application/json",
                "User-Agent": "hive-credential-store/1.0",
            }
            if self.config.tenant_id:
                headers["X-Tenant-ID"] = self.config.tenant_id
            self._client = httpx.Client(
                base_url=self.config.base_url,
                timeout=self.config.timeout,
                headers=headers,
            )
        return self._client

    def _request_with_retry(
        self,
        method: str,
        path: str,
        **kwargs: Any,
    ) -> httpx.Response:
        """
        Make a request with retry logic.

        Connection/timeout failures are retried with exponential backoff up to
        ``config.retry_attempts`` times; API-level errors (401/404/429/400
        refresh failures) are classified into typed exceptions and never
        retried.

        Raises:
            AdenAuthenticationError: On HTTP 401.
            AdenNotFoundError: On HTTP 404.
            AdenRateLimitError: On HTTP 429.
            AdenRefreshError: On HTTP 400 with a 'refresh_failed' body.
            AdenClientError: When all retry attempts are exhausted.
        """
        client = self._get_client()
        # Fix: removed leftover debug print statements here. They printed the
        # request details and the client headers, which leaked the
        # "Authorization: Bearer <api key>" secret to stdout on every request.
        last_error: Exception | None = None
        for attempt in range(self.config.retry_attempts):
            try:
                response = client.request(method, path, **kwargs)
                # Handle specific error codes
                if response.status_code == 401:
                    raise AdenAuthenticationError("Agent API key is invalid or revoked")
                if response.status_code == 404:
                    raise AdenNotFoundError(f"Integration not found: {path}")
                if response.status_code == 429:
                    retry_after = int(response.headers.get("Retry-After", 60))
                    raise AdenRateLimitError(
                        "Rate limited by Aden server",
                        retry_after=retry_after,
                    )
                if response.status_code == 400:
                    # NOTE(review): assumes 400 bodies are JSON — a non-JSON
                    # body would raise from response.json(); confirm server
                    # contract if that ever occurs in practice.
                    data = response.json()
                    if data.get("error") == "refresh_failed":
                        raise AdenRefreshError(
                            data.get("message", "Token refresh failed"),
                            requires_reauthorization=data.get("requires_reauthorization", False),
                            reauthorization_url=data.get("reauthorization_url"),
                        )
                # Success or other error
                response.raise_for_status()
                return response
            except (httpx.ConnectError, httpx.TimeoutException) as e:
                last_error = e
                if attempt < self.config.retry_attempts - 1:
                    # Exponential backoff: retry_delay, 2x, 4x, ...
                    delay = self.config.retry_delay * (2**attempt)
                    logger.warning(
                        f"Aden request failed (attempt {attempt + 1}), retrying in {delay}s: {e}"
                    )
                    time.sleep(delay)
                else:
                    raise AdenClientError(f"Failed to connect to Aden server: {e}") from e
            except (
                AdenAuthenticationError,
                AdenNotFoundError,
                AdenRefreshError,
                AdenRateLimitError,
            ):
                # Don't retry these errors
                raise
        # Should not reach here, but just in case
        raise AdenClientError(
            f"Request failed after {self.config.retry_attempts} attempts"
        ) from last_error

    def get_credential(self, integration_id: str) -> AdenCredentialResponse | None:
        """
        Fetch the current credential for an integration.

        The Aden server may refresh the token internally if it's expired
        before returning it.

        Args:
            integration_id: The integration identifier (e.g., 'hubspot').

        Returns:
            Credential response with access token, or None if not found.

        Raises:
            AdenAuthenticationError: If API key is invalid.
            AdenClientError: For connection failures.
        """
        try:
            response = self._request_with_retry("GET", f"/v1/credentials/{integration_id}")
            data = response.json()
            return AdenCredentialResponse.from_dict(data)
        except AdenNotFoundError:
            # "Missing" is an expected outcome for this call, not an error.
            return None

    def request_refresh(self, integration_id: str) -> AdenCredentialResponse:
        """
        Request the Aden server to refresh the token.

        Use this when the local store detects an expired or near-expiry token.
        The Aden server handles the actual OAuth2 refresh token flow.

        Args:
            integration_id: The integration identifier.

        Returns:
            Credential response with new access token.

        Raises:
            AdenRefreshError: If refresh fails (may require re-authorization).
            AdenNotFoundError: If integration not found.
            AdenAuthenticationError: If API key is invalid.
            AdenRateLimitError: If rate limited.
        """
        response = self._request_with_retry("POST", f"/v1/credentials/{integration_id}/refresh")
        data = response.json()
        return AdenCredentialResponse.from_dict(data)

    def list_integrations(self) -> list[AdenIntegrationInfo]:
        """
        List all integrations available for this agent/tenant.

        Returns:
            List of integration info objects.

        Raises:
            AdenAuthenticationError: If API key is invalid.
            AdenClientError: For connection failures.
        """
        response = self._request_with_retry("GET", "/v1/credentials")
        data = response.json()
        return [AdenIntegrationInfo.from_dict(item) for item in data.get("integrations", [])]

    def validate_token(self, integration_id: str) -> dict[str, Any]:
        """
        Check if a token is still valid without fetching it.

        Args:
            integration_id: The integration identifier.

        Returns:
            Dict with 'valid' bool and optional 'expires_at', 'reason',
            'requires_reauthorization', 'reauthorization_url'.

        Raises:
            AdenNotFoundError: If integration not found.
            AdenAuthenticationError: If API key is invalid.
        """
        response = self._request_with_retry("GET", f"/v1/credentials/{integration_id}/validate")
        return response.json()

    def report_usage(
        self,
        integration_id: str,
        operation: str,
        status: str = "success",
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """
        Report credential usage statistics to Aden.

        This is optional and used for analytics/billing.

        Args:
            integration_id: The integration identifier.
            operation: Operation name (e.g., 'api_call').
            status: Operation status ('success', 'error').
            metadata: Additional operation metadata.
        """
        try:
            # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
            # kept here to preserve the exact naive-UTC "...Z" timestamp format.
            self._request_with_retry(
                "POST",
                f"/v1/credentials/{integration_id}/usage",
                json={
                    "operation": operation,
                    "status": status,
                    "timestamp": datetime.utcnow().isoformat() + "Z",
                    "metadata": metadata or {},
                },
            )
        except Exception as e:
            # Usage reporting is best-effort, don't fail on errors
            logger.warning(f"Failed to report usage for '{integration_id}': {e}")

    def health_check(self) -> dict[str, Any]:
        """
        Check Aden server health and connectivity.

        Returns:
            Dict with 'status', 'version', 'timestamp', and optionally 'error'.
        """
        try:
            client = self._get_client()
            response = client.get("/health")
            if response.status_code == 200:
                data = response.json()
                data["latency_ms"] = response.elapsed.total_seconds() * 1000
                return data
            return {
                "status": "degraded",
                "error": f"Unexpected status code: {response.status_code}",
            }
        except Exception as e:
            return {
                "status": "unhealthy",
                "error": str(e),
            }

    def close(self) -> None:
        """Close the HTTP client and release resources."""
        if self._client:
            self._client.close()
            self._client = None

    def __enter__(self) -> AdenCredentialClient:
        """Context manager entry."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Context manager exit: always release the HTTP client."""
        self.close()
+415
View File
@@ -0,0 +1,415 @@
"""
Aden Sync Provider.
Provider that synchronizes credentials with the Aden authentication server.
The Aden server is the authoritative source for OAuth2 tokens - this provider
fetches and caches tokens locally while delegating refresh operations to Aden.
Usage:
from core.framework.credentials import CredentialStore
from core.framework.credentials.storage import EncryptedFileStorage
from core.framework.credentials.aden import (
AdenCredentialClient,
AdenClientConfig,
AdenSyncProvider,
)
# Configure client (API key loaded from ADEN_API_KEY env var)
client = AdenCredentialClient(AdenClientConfig(
base_url=os.environ["ADEN_API_URL"],
))
# Create provider
provider = AdenSyncProvider(client=client)
# Create store
store = CredentialStore(
storage=EncryptedFileStorage(),
providers=[provider],
auto_refresh=True,
)
# Initial sync from Aden
provider.sync_all(store)
# Use normally - auto-refreshes via Aden when needed
token = store.get_key("hubspot", "access_token")
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING
from pydantic import SecretStr
from ..models import CredentialKey, CredentialObject, CredentialRefreshError, CredentialType
from ..provider import CredentialProvider
from .client import (
AdenClientError,
AdenCredentialClient,
AdenCredentialResponse,
AdenRefreshError,
)
if TYPE_CHECKING:
from ..store import CredentialStore
logger = logging.getLogger(__name__)
class AdenSyncProvider(CredentialProvider):
    """
    Provider that synchronizes credentials with the Aden server.

    The Aden server handles OAuth2 authorization flows and maintains
    refresh tokens. This provider:

    - Fetches access tokens from the Aden server
    - Delegates token refresh to the Aden server
    - Caches tokens locally in the credential store
    - Optionally reports usage statistics back to Aden

    Key benefits:

    - Client secrets never leave the Aden server
    - Refresh token security (stored only on Aden)
    - Centralized audit logging
    - Multi-tenant support

    Usage:
        client = AdenCredentialClient(AdenClientConfig(
            base_url="https://hive.adenhq.com",
            api_key=os.environ["ADEN_API_KEY"],
        ))
        provider = AdenSyncProvider(client=client)
        store = CredentialStore(
            storage=EncryptedFileStorage(),
            providers=[provider],
            auto_refresh=True,
        )
    """

    def __init__(
        self,
        client: AdenCredentialClient,
        provider_id: str = "aden_sync",
        refresh_buffer_minutes: int = 5,
        report_usage: bool = False,
    ):
        """
        Initialize the Aden sync provider.

        Args:
            client: Configured Aden API client.
            provider_id: Unique identifier for this provider instance.
                Useful for multi-tenant scenarios (e.g., 'aden_tenant_123').
            refresh_buffer_minutes: Minutes before expiry to trigger refresh.
                Default is 5 minutes.
            report_usage: Whether to report usage statistics to Aden server.
        """
        self._client = client
        self._provider_id = provider_id
        # Stored as a timedelta so expiry checks reduce to simple arithmetic.
        self._refresh_buffer = timedelta(minutes=refresh_buffer_minutes)
        self._report_usage = report_usage

    @property
    def provider_id(self) -> str:
        """Unique identifier for this provider."""
        return self._provider_id

    @property
    def supported_types(self) -> list[CredentialType]:
        """Credential types this provider can manage."""
        return [CredentialType.OAUTH2, CredentialType.BEARER_TOKEN]

    def can_handle(self, credential: CredentialObject) -> bool:
        """
        Check if this provider can handle a credential.

        Returns True if:
        - Credential type is supported (OAUTH2 or BEARER_TOKEN), AND
        - Credential's provider_id matches this provider, OR
        - Credential has '_aden_managed' metadata flag
        """
        if credential.credential_type not in self.supported_types:
            return False
        # Check if credential is explicitly linked to this provider
        if credential.provider_id == self.provider_id:
            return True
        # Check for Aden-managed flag in metadata (set by
        # _update_credential_from_aden / _aden_response_to_credential below)
        aden_flag = credential.keys.get("_aden_managed")
        if aden_flag and aden_flag.value.get_secret_value() == "true":
            return True
        return False

    def refresh(self, credential: CredentialObject) -> CredentialObject:
        """
        Refresh credential by requesting new token from Aden server.

        The Aden server handles the actual OAuth2 refresh token flow.
        This method simply fetches the result.

        Args:
            credential: The credential to refresh.

        Returns:
            Updated credential with new access token.

        Raises:
            CredentialRefreshError: If refresh fails.
        """
        try:
            # Request Aden to refresh the token
            aden_response = self._client.request_refresh(credential.id)
            # Update credential with new values
            credential = self._update_credential_from_aden(credential, aden_response)
            logger.info(f"Refreshed credential '{credential.id}' via Aden server")
            # Report usage if enabled
            if self._report_usage:
                self._client.report_usage(
                    integration_id=credential.id,
                    operation="token_refresh",
                    status="success",
                )
            return credential
        except AdenRefreshError as e:
            logger.error(f"Aden refresh failed for '{credential.id}': {e}")
            # Re-auth failures carry a URL the user must visit; surface it.
            if e.requires_reauthorization:
                raise CredentialRefreshError(
                    f"Integration '{credential.id}' requires re-authorization. "
                    f"Visit: {e.reauthorization_url or 'your Aden dashboard'}"
                ) from e
            raise CredentialRefreshError(
                f"Failed to refresh credential '{credential.id}': {e}"
            ) from e
        except AdenClientError as e:
            logger.error(f"Aden client error for '{credential.id}': {e}")
            # Check if local token is still valid; if so we can keep serving
            # the cached token during a transient Aden outage.
            # NOTE(review): a token with no expires_at falls through to the
            # error below even though validate() treats it as valid — confirm
            # that asymmetry is intentional.
            access_key = credential.keys.get("access_token")
            if access_key and access_key.expires_at:
                if datetime.now(UTC) < access_key.expires_at:
                    logger.warning(f"Aden unavailable, using cached token for '{credential.id}'")
                    return credential
            raise CredentialRefreshError(
                f"Aden server unavailable and token expired for '{credential.id}'"
            ) from e

    def validate(self, credential: CredentialObject) -> bool:
        """
        Validate credential via Aden server introspection.

        Falls back to a local expiry check when the Aden server is
        unreachable.

        Args:
            credential: The credential to validate.

        Returns:
            True if credential is valid.
        """
        try:
            result = self._client.validate_token(credential.id)
            return result.get("valid", False)
        except AdenClientError:
            # Fall back to local validation
            access_key = credential.keys.get("access_token")
            if access_key is None:
                return False
            if access_key.expires_at is None:
                # No expiration - assume valid
                return True
            return datetime.now(UTC) < access_key.expires_at

    def should_refresh(self, credential: CredentialObject) -> bool:
        """
        Check if credential should be refreshed.

        Returns True if access_token is expired or within the refresh buffer.
        Tokens without an access_token key or without an expiry never
        trigger a refresh.

        Args:
            credential: The credential to check.

        Returns:
            True if credential should be refreshed.
        """
        access_key = credential.keys.get("access_token")
        if access_key is None:
            return False
        if access_key.expires_at is None:
            return False
        # Refresh if within buffer of expiration
        return datetime.now(UTC) >= (access_key.expires_at - self._refresh_buffer)

    def fetch_from_aden(self, integration_id: str) -> CredentialObject | None:
        """
        Fetch credential directly from Aden server.

        Use this for initial population or when local cache is missing.

        Args:
            integration_id: The integration identifier (e.g., 'hubspot').

        Returns:
            CredentialObject if found, None otherwise.

        Raises:
            AdenClientError: For connection failures.
        """
        aden_response = self._client.get_credential(integration_id)
        if aden_response is None:
            return None
        return self._aden_response_to_credential(aden_response)

    def sync_all(self, store: CredentialStore) -> int:
        """
        Sync all credentials from Aden server to local store.

        Fetches the list of available integrations from Aden and
        populates the local credential store with current tokens.
        Inactive integrations are skipped and per-integration failures
        are logged without aborting the sync.

        Args:
            store: The credential store to populate.

        Returns:
            Number of credentials synced.
        """
        synced = 0
        try:
            integrations = self._client.list_integrations()
            for info in integrations:
                if info.status != "active":
                    logger.warning(
                        f"Skipping integration '{info.integration_id}': status={info.status}"
                    )
                    continue
                try:
                    cred = self.fetch_from_aden(info.integration_id)
                    if cred:
                        store.save_credential(cred)
                        synced += 1
                        logger.info(f"Synced credential '{info.integration_id}' from Aden")
                except Exception as e:
                    # Best-effort: one bad integration must not block the rest.
                    logger.warning(f"Failed to sync '{info.integration_id}': {e}")
        except AdenClientError as e:
            logger.error(f"Failed to list integrations from Aden: {e}")
        return synced

    def report_credential_usage(
        self,
        credential: CredentialObject,
        operation: str,
        status: str = "success",
        metadata: dict | None = None,
    ) -> None:
        """
        Report credential usage to Aden server.

        No-op unless this provider was constructed with report_usage=True.

        Args:
            credential: The credential that was used.
            operation: Operation name (e.g., 'api_call').
            status: Operation status ('success', 'error').
            metadata: Additional metadata.
        """
        if self._report_usage:
            self._client.report_usage(
                integration_id=credential.id,
                operation=operation,
                status=status,
                metadata=metadata or {},
            )

    def _update_credential_from_aden(
        self,
        credential: CredentialObject,
        aden_response: AdenCredentialResponse,
    ) -> CredentialObject:
        """Update credential object from Aden response (mutates in place)."""
        # Update access token
        credential.keys["access_token"] = CredentialKey(
            name="access_token",
            value=SecretStr(aden_response.access_token),
            expires_at=aden_response.expires_at,
        )
        # Update scopes if present
        if aden_response.scopes:
            credential.keys["scope"] = CredentialKey(
                name="scope",
                value=SecretStr(" ".join(aden_response.scopes)),
            )
        # Mark as Aden-managed so can_handle() recognizes it later
        credential.keys["_aden_managed"] = CredentialKey(
            name="_aden_managed",
            value=SecretStr("true"),
        )
        # Store integration type
        credential.keys["_integration_type"] = CredentialKey(
            name="_integration_type",
            value=SecretStr(aden_response.integration_type),
        )
        # Update timestamps and claim ownership of the credential
        credential.last_refreshed = datetime.now(UTC)
        credential.provider_id = self.provider_id
        return credential

    def _aden_response_to_credential(
        self,
        aden_response: AdenCredentialResponse,
    ) -> CredentialObject:
        """Convert Aden response to a fresh CredentialObject."""
        keys: dict[str, CredentialKey] = {
            "access_token": CredentialKey(
                name="access_token",
                value=SecretStr(aden_response.access_token),
                expires_at=aden_response.expires_at,
            ),
            "_aden_managed": CredentialKey(
                name="_aden_managed",
                value=SecretStr("true"),
            ),
            "_integration_type": CredentialKey(
                name="_integration_type",
                value=SecretStr(aden_response.integration_type),
            ),
        }
        if aden_response.scopes:
            keys["scope"] = CredentialKey(
                name="scope",
                value=SecretStr(" ".join(aden_response.scopes)),
            )
        # All Aden-sourced credentials are modeled as OAUTH2 with auto-refresh.
        return CredentialObject(
            id=aden_response.integration_id,
            credential_type=CredentialType.OAUTH2,
            keys=keys,
            provider_id=self.provider_id,
            auto_refresh=True,
        )
+307
View File
@@ -0,0 +1,307 @@
"""
Aden Cached Storage.
Storage backend that combines local cache with Aden server fallback.
Provides offline resilience by caching credentials locally while
keeping them synchronized with the Aden server.
Usage:
from core.framework.credentials import CredentialStore
from core.framework.credentials.storage import EncryptedFileStorage
from core.framework.credentials.aden import (
AdenCredentialClient,
AdenClientConfig,
AdenSyncProvider,
AdenCachedStorage,
)
# Configure
client = AdenCredentialClient(AdenClientConfig(
base_url=os.environ["ADEN_API_URL"],
api_key=os.environ["ADEN_API_KEY"],
))
provider = AdenSyncProvider(client=client)
# Create cached storage
storage = AdenCachedStorage(
local_storage=EncryptedFileStorage(),
aden_provider=provider,
cache_ttl_seconds=300, # Re-check Aden every 5 minutes
)
# Create store
store = CredentialStore(
storage=storage,
providers=[provider],
auto_refresh=True,
)
# Credentials automatically fetched from Aden on first access
# Cached locally for 5 minutes
# Falls back to cache if Aden is unreachable
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING
from ..storage import CredentialStorage
if TYPE_CHECKING:
from ..models import CredentialObject
from .provider import AdenSyncProvider
logger = logging.getLogger(__name__)
class AdenCachedStorage(CredentialStorage):
    """
    Storage with local cache and Aden server fallback.

    This storage provides:

    - **Reads**: Try local cache first, fallback to Aden if stale/missing
    - **Writes**: Always write to local cache
    - **Offline resilience**: Uses cached credentials when Aden is unreachable

    The cache TTL determines how long to trust local credentials before
    checking with the Aden server for updates. This balances:

    - Performance (fewer network calls)
    - Freshness (tokens stay current)
    - Resilience (works during brief outages)

    Note: cache timestamps live in memory only, so a process restart makes
    every locally stored credential look "stale" until re-fetched from Aden.

    Usage:
        storage = AdenCachedStorage(
            local_storage=EncryptedFileStorage(),
            aden_provider=provider,
            cache_ttl_seconds=300,  # 5 minutes
        )
        store = CredentialStore(
            storage=storage,
            providers=[provider],
        )
        # First access fetches from Aden
        # Subsequent accesses use cache until TTL expires
        token = store.get_key("hubspot", "access_token")
    """

    def __init__(
        self,
        local_storage: CredentialStorage,
        aden_provider: AdenSyncProvider,
        cache_ttl_seconds: int = 300,
        prefer_local: bool = True,
    ):
        """
        Initialize Aden-cached storage.

        Args:
            local_storage: Local storage backend for caching (e.g., EncryptedFileStorage).
            aden_provider: Provider for fetching from Aden server.
            cache_ttl_seconds: How long to trust local cache before checking Aden.
                Default is 300 seconds (5 minutes).
            prefer_local: If True, use local cache when available and fresh.
                If False, always check Aden first.
        """
        self._local = local_storage
        self._aden_provider = aden_provider
        self._cache_ttl = timedelta(seconds=cache_ttl_seconds)
        self._prefer_local = prefer_local
        # In-memory record of when each credential was last cached.
        self._cache_timestamps: dict[str, datetime] = {}

    def save(self, credential: CredentialObject) -> None:
        """
        Save credential to local cache and stamp it as freshly cached.

        Args:
            credential: The credential to save.
        """
        self._local.save(credential)
        self._cache_timestamps[credential.id] = datetime.now(UTC)
        logger.debug(f"Cached credential '{credential.id}'")

    def load(self, credential_id: str) -> CredentialObject | None:
        """
        Load credential from cache, with Aden fallback.

        The loading strategy depends on the `prefer_local` setting:

        If prefer_local=True (default):
            1. Check if local cache exists and is fresh (within TTL)
            2. If fresh, return cached credential
            3. If stale or missing, fetch from Aden
            4. Update local cache with Aden response
            5. If Aden fails, fall back to stale cache

        If prefer_local=False:
            1. Always try to fetch from Aden first
            2. Update local cache with response
            3. Fall back to local cache only if Aden fails

        Args:
            credential_id: The credential identifier.

        Returns:
            CredentialObject if found, None otherwise.
        """
        local_cred = self._local.load(credential_id)
        # If we prefer local and have a fresh cache, use it
        if self._prefer_local and local_cred and self._is_cache_fresh(credential_id):
            logger.debug(f"Using cached credential '{credential_id}'")
            return local_cred
        # Try to fetch from Aden
        try:
            aden_cred = self._aden_provider.fetch_from_aden(credential_id)
            if aden_cred:
                # Update local cache (also refreshes the cache timestamp)
                self.save(aden_cred)
                logger.debug(f"Fetched credential '{credential_id}' from Aden")
                return aden_cred
        except Exception as e:
            logger.warning(f"Failed to fetch '{credential_id}' from Aden: {e}")
            # Fall back to local cache if Aden fails
            if local_cred:
                logger.info(f"Using stale cached credential '{credential_id}'")
                return local_cred
        # Return local credential if it exists (may be None)
        return local_cred

    def delete(self, credential_id: str) -> bool:
        """
        Delete credential from local cache.

        Note: This does NOT delete the credential from the Aden server.
        It only removes the local cache entry.

        Args:
            credential_id: The credential identifier.

        Returns:
            True if credential existed and was deleted.
        """
        self._cache_timestamps.pop(credential_id, None)
        return self._local.delete(credential_id)

    def list_all(self) -> list[str]:
        """
        List credentials from local cache.

        Returns:
            List of credential IDs in local cache.
        """
        return self._local.list_all()

    def exists(self, credential_id: str) -> bool:
        """
        Check if credential exists in local cache.

        Args:
            credential_id: The credential identifier.

        Returns:
            True if credential exists locally.
        """
        return self._local.exists(credential_id)

    def _is_cache_fresh(self, credential_id: str) -> bool:
        """
        Check if local cache is still fresh (within TTL).

        Args:
            credential_id: The credential identifier.

        Returns:
            True if cache is fresh, False if stale or not cached.
        """
        cached_at = self._cache_timestamps.get(credential_id)
        if cached_at is None:
            return False
        return datetime.now(UTC) - cached_at < self._cache_ttl

    def invalidate_cache(self, credential_id: str) -> None:
        """
        Invalidate cache for a specific credential.

        The next load() call will fetch from Aden regardless of TTL.
        The locally stored credential itself is kept (see delete()).

        Args:
            credential_id: The credential identifier.
        """
        self._cache_timestamps.pop(credential_id, None)
        logger.debug(f"Invalidated cache for '{credential_id}'")

    def invalidate_all(self) -> None:
        """Invalidate all cache entries."""
        self._cache_timestamps.clear()
        logger.debug("Invalidated all cache entries")

    def sync_all_from_aden(self) -> int:
        """
        Sync all credentials from Aden server to local cache.

        Fetches the list of available integrations from Aden and
        updates the local cache with current tokens.

        Returns:
            Number of credentials synced.
        """
        synced = 0
        try:
            # NOTE(review): reaches into the provider's private _client;
            # consider exposing list_integrations on the provider instead.
            integrations = self._aden_provider._client.list_integrations()
            for info in integrations:
                if info.status != "active":
                    logger.warning(
                        f"Skipping integration '{info.integration_id}': status={info.status}"
                    )
                    continue
                try:
                    cred = self._aden_provider.fetch_from_aden(info.integration_id)
                    if cred:
                        self.save(cred)
                        synced += 1
                        logger.info(f"Synced credential '{info.integration_id}' from Aden")
                except Exception as e:
                    # Best-effort: one bad integration must not block the rest.
                    logger.warning(f"Failed to sync '{info.integration_id}': {e}")
        except Exception as e:
            logger.error(f"Failed to list integrations from Aden: {e}")
        return synced

    def get_cache_info(self) -> dict[str, dict]:
        """
        Get cache status information for all credentials.

        Returns:
            Dict mapping credential_id to cache info (cached_at, is_fresh, ttl_remaining).
            Credentials present locally but never stamped report as not fresh
            with zero TTL remaining.
        """
        now = datetime.now(UTC)
        info = {}
        for cred_id in self.list_all():
            cached_at = self._cache_timestamps.get(cred_id)
            if cached_at:
                ttl_remaining = (cached_at + self._cache_ttl - now).total_seconds()
                info[cred_id] = {
                    "cached_at": cached_at.isoformat(),
                    "is_fresh": ttl_remaining > 0,
                    # Clamp so callers never see a negative countdown.
                    "ttl_remaining_seconds": max(0, ttl_remaining),
                }
            else:
                info[cred_id] = {
                    "cached_at": None,
                    "is_fresh": False,
                    "ttl_remaining_seconds": 0,
                }
        return info
@@ -0,0 +1 @@
"""Tests for Aden credential sync components."""
@@ -0,0 +1,670 @@
"""
Tests for Aden credential sync components.
Tests cover:
- AdenCredentialClient: HTTP client for Aden API
- AdenSyncProvider: Provider that syncs with Aden
- AdenCachedStorage: Storage with local cache + Aden fallback
"""
from datetime import UTC, datetime, timedelta
from unittest.mock import Mock
import pytest
from pydantic import SecretStr
from framework.credentials import (
CredentialKey,
CredentialObject,
CredentialStore,
CredentialType,
InMemoryStorage,
)
from framework.credentials.aden import (
AdenCachedStorage,
AdenClientConfig,
AdenClientError,
AdenCredentialClient,
AdenCredentialResponse,
AdenIntegrationInfo,
AdenRefreshError,
AdenSyncProvider,
)
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def aden_config():
    """Create a test Aden client config pointing at a fake server."""
    return AdenClientConfig(
        base_url="https://api.test-aden.com",
        api_key="test-api-key",
        tenant_id="test-tenant",
        timeout=5.0,
        retry_attempts=2,
        retry_delay=0.1,  # short delay keeps any retry paths fast in tests
    )
@pytest.fixture
def mock_client(aden_config):
    """Create a mock Aden client (spec'd so unknown attributes raise)."""
    client = Mock(spec=AdenCredentialClient)
    # Some code paths read client.config; give the mock a real one.
    client.config = aden_config
    return client
@pytest.fixture
def aden_response():
    """Create a sample Aden credential response (valid for one hour)."""
    return AdenCredentialResponse(
        integration_id="hubspot",
        integration_type="hubspot",
        access_token="test-access-token",
        token_type="Bearer",
        expires_at=datetime.now(UTC) + timedelta(hours=1),
        scopes=["crm.objects.contacts.read", "crm.objects.contacts.write"],
        metadata={"portal_id": "12345"},
    )
@pytest.fixture
def provider(mock_client):
    """Create an AdenSyncProvider with mock client and usage reporting off."""
    return AdenSyncProvider(
        client=mock_client,
        provider_id="test_aden",
        refresh_buffer_minutes=5,
        report_usage=False,
    )
@pytest.fixture
def local_storage():
    """Create an in-memory storage for testing (no disk or encryption)."""
    return InMemoryStorage()
@pytest.fixture
def cached_storage(local_storage, provider):
    """Create an AdenCachedStorage with a 60s TTL, preferring local cache."""
    return AdenCachedStorage(
        local_storage=local_storage,
        aden_provider=provider,
        cache_ttl_seconds=60,
        prefer_local=True,
    )
# =============================================================================
# AdenCredentialResponse Tests
# =============================================================================
class TestAdenCredentialResponse:
    """Tests for AdenCredentialResponse dataclass."""

    def test_from_dict_basic(self):
        """Test creating response from dict with only required fields;
        optional fields fall back to documented defaults."""
        data = {
            "integration_id": "github",
            "integration_type": "github",
            "access_token": "ghp_xxxxx",
        }
        response = AdenCredentialResponse.from_dict(data)
        assert response.integration_id == "github"
        assert response.integration_type == "github"
        assert response.access_token == "ghp_xxxxx"
        # Defaults applied when fields are absent from the payload.
        assert response.token_type == "Bearer"
        assert response.expires_at is None
        assert response.scopes == []

    def test_from_dict_full(self):
        """Test creating response with all fields, including ISO-8601
        expiry parsing and metadata pass-through."""
        data = {
            "integration_id": "hubspot",
            "integration_type": "hubspot",
            "access_token": "token123",
            "token_type": "Bearer",
            "expires_at": "2026-01-28T15:30:00Z",
            "scopes": ["read", "write"],
            "metadata": {"key": "value"},
        }
        response = AdenCredentialResponse.from_dict(data)
        assert response.integration_id == "hubspot"
        assert response.access_token == "token123"
        assert response.expires_at is not None
        assert response.scopes == ["read", "write"]
        assert response.metadata == {"key": "value"}
class TestAdenIntegrationInfo:
    """Tests for AdenIntegrationInfo dataclass."""

    def test_from_dict(self):
        """Test creating integration info from dict, including expiry parsing."""
        data = {
            "integration_id": "slack",
            "integration_type": "slack",
            "status": "active",
            "expires_at": "2026-02-01T00:00:00Z",
        }
        info = AdenIntegrationInfo.from_dict(data)
        assert info.integration_id == "slack"
        assert info.integration_type == "slack"
        assert info.status == "active"
        assert info.expires_at is not None
# =============================================================================
# AdenSyncProvider Tests
# =============================================================================
class TestAdenSyncProvider:
    """Tests for AdenSyncProvider.

    Uses the `provider` fixture (provider_id='test_aden', 5-minute refresh
    buffer) backed by the `mock_client` fixture.
    """

    def test_provider_id(self, provider):
        """Test provider ID matches the fixture's configured value."""
        assert provider.provider_id == "test_aden"

    def test_supported_types(self, provider):
        """Test supported credential types."""
        assert CredentialType.OAUTH2 in provider.supported_types
        assert CredentialType.BEARER_TOKEN in provider.supported_types

    def test_can_handle_oauth2(self, provider):
        """Test can_handle returns True for OAUTH2 credentials with matching provider_id."""
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={},
            provider_id="test_aden",
        )
        assert provider.can_handle(cred) is True

    def test_can_handle_aden_managed(self, provider):
        """Test can_handle returns True for Aden-managed credentials
        (flagged via the '_aden_managed' key, no provider_id match needed)."""
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={
                "_aden_managed": CredentialKey(
                    name="_aden_managed",
                    value=SecretStr("true"),
                )
            },
        )
        assert provider.can_handle(cred) is True

    def test_can_handle_wrong_type(self, provider):
        """Test can_handle returns False for unsupported types (API_KEY)."""
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.API_KEY,
            keys={},
        )
        assert provider.can_handle(cred) is False

    def test_refresh_success(self, provider, mock_client, aden_response):
        """Test successful credential refresh updates token, flags, and timestamp."""
        mock_client.request_refresh.return_value = aden_response
        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("old-token"),
                )
            },
            provider_id="test_aden",
        )
        refreshed = provider.refresh(cred)
        assert refreshed.keys["access_token"].value.get_secret_value() == "test-access-token"
        assert refreshed.keys["_aden_managed"].value.get_secret_value() == "true"
        assert refreshed.last_refreshed is not None
        mock_client.request_refresh.assert_called_once_with("hubspot")

    def test_refresh_requires_reauth(self, provider, mock_client):
        """Test refresh that requires re-authorization raises
        CredentialRefreshError mentioning re-authorization."""
        mock_client.request_refresh.side_effect = AdenRefreshError(
            "Token revoked",
            requires_reauthorization=True,
            reauthorization_url="https://aden.com/reauth",
        )
        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={},
        )
        from framework.credentials import CredentialRefreshError

        with pytest.raises(CredentialRefreshError) as exc_info:
            provider.refresh(cred)
        assert "re-authorization" in str(exc_info.value).lower()

    def test_refresh_aden_unavailable_cached_valid(self, provider, mock_client):
        """Test refresh falls back to cache when Aden is unavailable and token is valid."""
        mock_client.request_refresh.side_effect = AdenClientError("Connection failed")
        # Token expires in 1 hour - still valid
        future = datetime.now(UTC) + timedelta(hours=1)
        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("cached-token"),
                    expires_at=future,
                )
            },
        )
        # Should return the cached credential instead of failing
        result = provider.refresh(cred)
        assert result.keys["access_token"].value.get_secret_value() == "cached-token"

    def test_should_refresh_expired(self, provider):
        """Test should_refresh returns True for expired token."""
        past = datetime.now(UTC) - timedelta(hours=1)
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("token"),
                    expires_at=past,
                )
            },
        )
        assert provider.should_refresh(cred) is True

    def test_should_refresh_within_buffer(self, provider):
        """Test should_refresh returns True when within buffer."""
        # Expires in 3 minutes (buffer is 5 minutes)
        soon = datetime.now(UTC) + timedelta(minutes=3)
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("token"),
                    expires_at=soon,
                )
            },
        )
        assert provider.should_refresh(cred) is True

    def test_should_refresh_still_valid(self, provider):
        """Test should_refresh returns False for a token well outside the buffer."""
        future = datetime.now(UTC) + timedelta(hours=1)
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("token"),
                    expires_at=future,
                )
            },
        )
        assert provider.should_refresh(cred) is False

    def test_fetch_from_aden(self, provider, mock_client, aden_response):
        """Test fetching credential from Aden converts response to a
        CredentialObject with auto_refresh enabled."""
        mock_client.get_credential.return_value = aden_response
        cred = provider.fetch_from_aden("hubspot")
        assert cred is not None
        assert cred.id == "hubspot"
        assert cred.keys["access_token"].value.get_secret_value() == "test-access-token"
        assert cred.auto_refresh is True

    def test_fetch_from_aden_not_found(self, provider, mock_client):
        """Test fetch returns None when not found."""
        mock_client.get_credential.return_value = None
        cred = provider.fetch_from_aden("nonexistent")
        assert cred is None

    def test_sync_all(self, provider, mock_client, aden_response):
        """Test syncing all credentials skips non-active integrations."""
        mock_client.list_integrations.return_value = [
            AdenIntegrationInfo(
                integration_id="hubspot",
                integration_type="hubspot",
                status="active",
            ),
            AdenIntegrationInfo(
                integration_id="github",
                integration_type="github",
                status="requires_reauth",  # Should be skipped
            ),
        ]
        mock_client.get_credential.return_value = aden_response
        store = CredentialStore(storage=InMemoryStorage())
        synced = provider.sync_all(store)
        assert synced == 1  # Only active one was synced
        assert store.get_credential("hubspot") is not None

    def test_validate_via_aden(self, provider, mock_client):
        """Test validation via Aden introspection."""
        mock_client.validate_token.return_value = {"valid": True}
        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={},
        )
        assert provider.validate(cred) is True

    def test_validate_fallback_to_local(self, provider, mock_client):
        """Test validation falls back to local expiry check when Aden fails."""
        mock_client.validate_token.side_effect = AdenClientError("Failed")
        future = datetime.now(UTC) + timedelta(hours=1)
        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("token"),
                    expires_at=future,
                )
            },
        )
        assert provider.validate(cred) is True
# =============================================================================
# AdenCachedStorage Tests
# =============================================================================
class TestAdenCachedStorage:
"""Tests for AdenCachedStorage."""
def test_save_updates_cache_timestamp(self, cached_storage):
"""Test save updates cache timestamp."""
cred = CredentialObject(
id="test",
credential_type=CredentialType.OAUTH2,
keys={
"access_token": CredentialKey(
name="access_token",
value=SecretStr("token"),
)
},
)
cached_storage.save(cred)
assert "test" in cached_storage._cache_timestamps
assert cached_storage.exists("test")
def test_load_from_fresh_cache(self, cached_storage, local_storage):
"""Test load returns cached credential when fresh."""
cred = CredentialObject(
id="test",
credential_type=CredentialType.OAUTH2,
keys={
"access_token": CredentialKey(
name="access_token",
value=SecretStr("cached-token"),
)
},
)
# Save to both local storage and update timestamp
local_storage.save(cred)
cached_storage._cache_timestamps["test"] = datetime.now(UTC)
loaded = cached_storage.load("test")
assert loaded is not None
assert loaded.keys["access_token"].value.get_secret_value() == "cached-token"
def test_load_from_aden_when_stale(
self, cached_storage, local_storage, provider, mock_client, aden_response
):
"""Test load fetches from Aden when cache is stale."""
# Create stale cached credential
cred = CredentialObject(
id="hubspot",
credential_type=CredentialType.OAUTH2,
keys={
"access_token": CredentialKey(
name="access_token",
value=SecretStr("stale-token"),
)
},
)
local_storage.save(cred)
# Set cache timestamp to be stale (2 minutes ago, TTL is 60 seconds)
cached_storage._cache_timestamps["hubspot"] = datetime.now(UTC) - timedelta(minutes=2)
# Mock Aden response
mock_client.get_credential.return_value = aden_response
loaded = cached_storage.load("hubspot")
assert loaded is not None
assert loaded.keys["access_token"].value.get_secret_value() == "test-access-token"
def test_load_falls_back_to_stale_when_aden_fails(
self, cached_storage, local_storage, provider, mock_client
):
"""Test load falls back to stale cache when Aden fails."""
# Create stale cached credential
cred = CredentialObject(
id="hubspot",
credential_type=CredentialType.OAUTH2,
keys={
"access_token": CredentialKey(
name="access_token",
value=SecretStr("stale-token"),
)
},
)
local_storage.save(cred)
cached_storage._cache_timestamps["hubspot"] = datetime.now(UTC) - timedelta(minutes=2)
# Aden fails
mock_client.get_credential.side_effect = AdenClientError("Connection failed")
loaded = cached_storage.load("hubspot")
assert loaded is not None
assert loaded.keys["access_token"].value.get_secret_value() == "stale-token"
def test_delete_removes_cache_timestamp(self, cached_storage, local_storage):
"""Test delete removes cache timestamp."""
cred = CredentialObject(
id="test",
credential_type=CredentialType.OAUTH2,
keys={},
)
cached_storage.save(cred)
assert "test" in cached_storage._cache_timestamps
cached_storage.delete("test")
assert "test" not in cached_storage._cache_timestamps
assert not cached_storage.exists("test")
def test_invalidate_cache(self, cached_storage, local_storage):
"""Test invalidate_cache removes timestamp."""
cred = CredentialObject(
id="test",
credential_type=CredentialType.OAUTH2,
keys={},
)
cached_storage.save(cred)
cached_storage.invalidate_cache("test")
assert "test" not in cached_storage._cache_timestamps
# Credential still exists in local storage
assert local_storage.exists("test")
def test_invalidate_all(self, cached_storage):
"""Test invalidate_all clears all timestamps."""
for i in range(3):
cached_storage._cache_timestamps[f"test_{i}"] = datetime.now(UTC)
cached_storage.invalidate_all()
assert len(cached_storage._cache_timestamps) == 0
def test_is_cache_fresh(self, cached_storage):
"""Test _is_cache_fresh logic."""
# Fresh cache
cached_storage._cache_timestamps["fresh"] = datetime.now(UTC)
assert cached_storage._is_cache_fresh("fresh") is True
# Stale cache
cached_storage._cache_timestamps["stale"] = datetime.now(UTC) - timedelta(minutes=5)
assert cached_storage._is_cache_fresh("stale") is False
# No cache
assert cached_storage._is_cache_fresh("nonexistent") is False
    def test_get_cache_info(self, cached_storage, local_storage):
        """Test get_cache_info returns status for all credentials."""
        # Add some credentials
        for name in ["fresh", "stale"]:
            cred = CredentialObject(
                id=name,
                credential_type=CredentialType.OAUTH2,
                keys={},
            )
            local_storage.save(cred)
        # Stamp one entry inside the TTL and one well past it.
        cached_storage._cache_timestamps["fresh"] = datetime.now(UTC)
        cached_storage._cache_timestamps["stale"] = datetime.now(UTC) - timedelta(minutes=5)
        info = cached_storage.get_cache_info()
        assert "fresh" in info
        assert info["fresh"]["is_fresh"] is True
        assert info["fresh"]["ttl_remaining_seconds"] > 0
        assert "stale" in info
        assert info["stale"]["is_fresh"] is False
        # Stale entries report zero remaining TTL, not a negative value.
        assert info["stale"]["ttl_remaining_seconds"] == 0
# =============================================================================
# Integration Tests
# =============================================================================
class TestAdenIntegration:
    """Integration tests for Aden sync components."""

    def test_full_workflow(self, mock_client, aden_response):
        """Test full workflow: sync, get, refresh."""
        # Setup
        mock_client.list_integrations.return_value = [
            AdenIntegrationInfo(
                integration_id="hubspot",
                integration_type="hubspot",
                status="active",
            ),
        ]
        mock_client.get_credential.return_value = aden_response
        mock_client.request_refresh.return_value = AdenCredentialResponse(
            integration_id="hubspot",
            integration_type="hubspot",
            access_token="refreshed-token",
            expires_at=datetime.now(UTC) + timedelta(hours=2),
            scopes=[],
        )
        provider = AdenSyncProvider(client=mock_client)
        storage = InMemoryStorage()
        store = CredentialStore(
            storage=storage,
            providers=[provider],
            auto_refresh=True,
        )
        # Initial sync — exactly one integration configured above, so one synced.
        synced = provider.sync_all(store)
        assert synced == 1
        # Get credential
        cred = store.get_credential("hubspot")
        assert cred is not None
        assert cred.keys["access_token"].value.get_secret_value() == "test-access-token"
        # Simulate expiration by overwriting the key with a past expires_at.
        cred.keys["access_token"] = CredentialKey(
            name="access_token",
            value=SecretStr("test-access-token"),
            expires_at=datetime.now(UTC) - timedelta(hours=1),  # Expired
        )
        storage.save(cred)
        # Refresh should be triggered — the mocked request_refresh supplies the
        # "refreshed-token" response asserted below.
        refreshed = provider.refresh(cred)
        assert refreshed.keys["access_token"].value.get_secret_value() == "refreshed-token"

    def test_cached_storage_with_store(self, mock_client, aden_response):
        """Test AdenCachedStorage with CredentialStore."""
        mock_client.get_credential.return_value = aden_response
        provider = AdenSyncProvider(client=mock_client)
        local_storage = InMemoryStorage()
        cached_storage = AdenCachedStorage(
            local_storage=local_storage,
            aden_provider=provider,
            cache_ttl_seconds=300,
        )
        # First load fetches from Aden (mock invoked exactly once).
        cred = cached_storage.load("hubspot")
        assert cred is not None
        mock_client.get_credential.assert_called_once()
        # Second load uses cache — after reset_mock, no further Aden calls.
        mock_client.get_credential.reset_mock()
        cred2 = cached_storage.load("hubspot")
        assert cred2 is not None
        mock_client.get_credential.assert_not_called()
+293
View File
@@ -0,0 +1,293 @@
"""
Core data models for the credential store.
This module defines the key-vault structure where credentials are objects
containing one or more keys (e.g., api_key, access_token, refresh_token).
"""
from __future__ import annotations
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field, SecretStr
def _utc_now() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(UTC)
class CredentialType(str, Enum):
    """Types of credentials the store can manage."""

    # Simple API key (e.g., Brave Search, OpenAI)
    API_KEY = "api_key"
    # OAuth2 with refresh token support
    OAUTH2 = "oauth2"
    # Username/password pair
    BASIC_AUTH = "basic_auth"
    # JWT or bearer token without refresh
    BEARER_TOKEN = "bearer_token"
    # User-defined credential type
    CUSTOM = "custom"
class CredentialKey(BaseModel):
    """
    A single key within a credential object.

    Example: 'api_key' within a 'brave_search' credential

    Attributes:
        name: Key name (e.g., 'api_key', 'access_token')
        value: Secret value (SecretStr prevents accidental logging)
        expires_at: Optional expiration time
        metadata: Additional key-specific metadata
    """

    name: str
    value: SecretStr
    expires_at: datetime | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)

    # Allow provider-specific extra fields to round-trip through storage.
    model_config = {"extra": "allow"}

    @property
    def is_expired(self) -> bool:
        """Check if this key has expired."""
        if self.expires_at is None:
            # No expiry recorded means the key never expires.
            return False
        # NOTE(review): assumes expires_at is timezone-aware; a naive datetime
        # here would raise TypeError on comparison with the aware now() —
        # confirm all writers store aware datetimes.
        return datetime.now(UTC) >= self.expires_at

    def get_secret_value(self) -> str:
        """Get the actual secret value (use sparingly)."""
        return self.value.get_secret_value()
class CredentialObject(BaseModel):
    """
    A credential object containing one or more keys.

    This is the key-vault structure where each credential can have
    multiple keys (e.g., access_token, refresh_token, expires_at).

    Example:
        CredentialObject(
            id="github_oauth",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(name="access_token", value=SecretStr("ghp_xxx")),
                "refresh_token": CredentialKey(name="refresh_token", value=SecretStr("ghr_xxx")),
            },
            provider_id="oauth2"
        )

    Attributes:
        id: Unique identifier (e.g., 'brave_search', 'github_oauth')
        credential_type: Type of credential (API_KEY, OAUTH2, etc.)
        keys: Dictionary of key name to CredentialKey
        provider_id: ID of provider responsible for lifecycle management
        auto_refresh: Whether to automatically refresh when expired
    """

    id: str = Field(description="Unique identifier (e.g., 'brave_search', 'github_oauth')")
    credential_type: CredentialType = CredentialType.API_KEY
    keys: dict[str, CredentialKey] = Field(default_factory=dict)

    # Lifecycle management
    provider_id: str | None = Field(
        default=None,
        description="ID of provider responsible for lifecycle (e.g., 'oauth2', 'static')",
    )
    last_refreshed: datetime | None = None
    auto_refresh: bool = False

    # Usage tracking
    last_used: datetime | None = None
    use_count: int = 0

    # Metadata
    description: str = ""
    tags: list[str] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=_utc_now)
    updated_at: datetime = Field(default_factory=_utc_now)

    model_config = {"extra": "allow"}

    def get_key(self, key_name: str) -> str | None:
        """
        Get a specific key's value.

        Args:
            key_name: Name of the key to retrieve

        Returns:
            The key's secret value, or None if not found
        """
        key = self.keys.get(key_name)
        if key is None:
            return None
        return key.get_secret_value()

    def set_key(
        self,
        key_name: str,
        value: str,
        expires_at: datetime | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """
        Set or update a key.

        Args:
            key_name: Name of the key
            value: Secret value
            expires_at: Optional expiration time
            metadata: Optional key-specific metadata
        """
        self.keys[key_name] = CredentialKey(
            name=key_name,
            value=SecretStr(value),
            expires_at=expires_at,
            metadata=metadata or {},
        )
        # Use the same clock helper as the field defaults for consistency.
        self.updated_at = _utc_now()

    def has_key(self, key_name: str) -> bool:
        """Check if a key exists."""
        return key_name in self.keys

    @property
    def needs_refresh(self) -> bool:
        """Check if any key is expired or near expiration."""
        return any(key.is_expired for key in self.keys.values())

    @property
    def is_valid(self) -> bool:
        """Check if credential has at least one non-expired key."""
        # any() over an empty key set is False, which covers the no-keys case.
        return any(not key.is_expired for key in self.keys.values())

    def record_usage(self) -> None:
        """Record that this credential was used."""
        self.last_used = _utc_now()
        self.use_count += 1

    def get_default_key(self) -> str | None:
        """
        Get the default key value.

        Priority: 'value' > 'api_key' > 'access_token' > first key

        Returns:
            The default key's value, or None if no keys exist
        """
        for key_name in ("value", "api_key", "access_token"):
            if key_name in self.keys:
                return self.get_key(key_name)
        if self.keys:
            # Fall back to whichever key was inserted first.
            return self.get_key(next(iter(self.keys)))
        return None
class CredentialUsageSpec(BaseModel):
    """
    Specification for how a tool uses credentials.

    This implements the "bipartisan" model where the credential store
    just stores values, and tools define how those values are used
    in HTTP requests (headers, query params, body).

    Example:
        CredentialUsageSpec(
            credential_id="brave_search",
            required_keys=["api_key"],
            headers={"X-Subscription-Token": "{{api_key}}"}
        )

        CredentialUsageSpec(
            credential_id="github_oauth",
            required_keys=["access_token"],
            headers={"Authorization": "Bearer {{access_token}}"}
        )

    Attributes:
        credential_id: ID of credential to use
        required_keys: Keys that must be present
        headers: Header templates with {{key}} placeholders
        query_params: Query parameter templates
        body_fields: Request body field templates
    """

    credential_id: str = Field(description="ID of credential to use (e.g., 'brave_search')")
    required_keys: list[str] = Field(default_factory=list, description="Keys that must be present")

    # Injection templates (bipartisan model): {{key}} placeholders are
    # substituted with the credential's key values at request time.
    headers: dict[str, str] = Field(
        default_factory=dict,
        description="Header templates (e.g., {'Authorization': 'Bearer {{access_token}}'})",
    )
    query_params: dict[str, str] = Field(
        default_factory=dict,
        description="Query param templates (e.g., {'api_key': '{{api_key}}'})",
    )
    body_fields: dict[str, str] = Field(
        default_factory=dict,
        description="Request body field templates",
    )

    # Metadata
    required: bool = True
    description: str = ""
    help_url: str = ""

    # Allow tool-specific extra fields.
    model_config = {"extra": "allow"}
class CredentialError(Exception):
    """Base exception for credential-related errors."""
class CredentialNotFoundError(CredentialError):
    """Raised when a referenced credential doesn't exist."""
class CredentialKeyNotFoundError(CredentialError):
    """Raised when a referenced key doesn't exist in a credential."""
class CredentialRefreshError(CredentialError):
    """Raised when credential refresh fails."""
class CredentialValidationError(CredentialError):
    """Raised when credential validation fails."""
class CredentialDecryptionError(CredentialError):
    """Raised when credential decryption fails."""
@@ -0,0 +1,91 @@
"""
OAuth2 support for the credential store.
This module provides OAuth2 credential management with:
- Token types and configuration (OAuth2Token, OAuth2Config)
- Generic OAuth2 provider (BaseOAuth2Provider)
- Token lifecycle management (TokenLifecycleManager)
Quick Start:
from core.framework.credentials import CredentialStore
from core.framework.credentials.oauth2 import BaseOAuth2Provider, OAuth2Config
# Configure OAuth2 provider
provider = BaseOAuth2Provider(OAuth2Config(
token_url="https://oauth2.example.com/token",
client_id="your-client-id",
client_secret="your-client-secret",
default_scopes=["read", "write"],
))
# Create store with OAuth2 provider
store = CredentialStore.with_encrypted_storage(
"/var/hive/credentials",
providers=[provider]
)
# Get token using client credentials
token = provider.client_credentials_grant()
# Save to store
from core.framework.credentials import CredentialObject, CredentialKey, CredentialType
from pydantic import SecretStr
    # Build the key map first so "refresh_token" is only present when one was
    # actually issued (a literal None entry would fail model validation)
    keys = {
        "access_token": CredentialKey(
            name="access_token",
            value=SecretStr(token.access_token),
            expires_at=token.expires_at,
        ),
    }
    if token.refresh_token:
        keys["refresh_token"] = CredentialKey(
            name="refresh_token",
            value=SecretStr(token.refresh_token),
        )
    store.save_credential(CredentialObject(
        id="my_api",
        credential_type=CredentialType.OAUTH2,
        keys=keys,
        provider_id="oauth2",
        auto_refresh=True,
    ))
For advanced lifecycle management:
from core.framework.credentials.oauth2 import TokenLifecycleManager
manager = TokenLifecycleManager(
provider=provider,
credential_id="my_api",
store=store,
)
# Get valid token (auto-refreshes if needed)
token = manager.sync_get_valid_token()
headers = manager.get_request_headers()
"""
from .base_provider import BaseOAuth2Provider
from .lifecycle import TokenLifecycleManager, TokenRefreshResult
from .provider import (
OAuth2Config,
OAuth2Error,
OAuth2Token,
RefreshTokenInvalidError,
TokenExpiredError,
TokenPlacement,
)
__all__ = [
# Types
"OAuth2Token",
"OAuth2Config",
"TokenPlacement",
# Provider
"BaseOAuth2Provider",
# Lifecycle
"TokenLifecycleManager",
"TokenRefreshResult",
# Errors
"OAuth2Error",
"TokenExpiredError",
"RefreshTokenInvalidError",
]
@@ -0,0 +1,486 @@
"""
Base OAuth2 provider implementation.
This module provides a generic OAuth2 provider that works with standard
OAuth2 servers. OSS users can extend this class for custom providers.
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
from urllib.parse import urlencode
from ..models import CredentialObject, CredentialRefreshError, CredentialType
from ..provider import CredentialProvider
from .provider import (
OAuth2Config,
OAuth2Error,
OAuth2Token,
TokenPlacement,
)
logger = logging.getLogger(__name__)
class BaseOAuth2Provider(CredentialProvider):
    """
    Generic OAuth2 provider implementation.

    Works with standard OAuth2 servers (RFC 6749). Override methods for
    provider-specific behavior.

    Supported grant types:
    - Client Credentials: For server-to-server authentication
    - Refresh Token: For refreshing expired access tokens
    - Authorization Code: For user-authorized access (requires callback handling)

    OSS users can extend this class for custom providers:

        class GitHubOAuth2Provider(BaseOAuth2Provider):
            def __init__(self, client_id: str, client_secret: str):
                super().__init__(OAuth2Config(
                    token_url="https://github.com/login/oauth/access_token",
                    authorization_url="https://github.com/login/oauth/authorize",
                    client_id=client_id,
                    client_secret=client_secret,
                    default_scopes=["repo", "user"],
                ))

            def exchange_code(self, code: str, redirect_uri: str, **kwargs) -> OAuth2Token:
                # GitHub returns data as form-encoded by default
                # Override to handle this
                ...

    Example usage:

        provider = BaseOAuth2Provider(OAuth2Config(
            token_url="https://oauth2.example.com/token",
            client_id="my-client-id",
            client_secret="my-client-secret",
        ))

        # Get token using client credentials
        token = provider.client_credentials_grant()

        # Refresh an expired token
        new_token = provider.refresh_access_token(old_token.refresh_token)
    """

    def __init__(self, config: OAuth2Config, provider_id: str = "oauth2"):
        """
        Initialize the OAuth2 provider.

        Args:
            config: OAuth2 configuration
            provider_id: Unique identifier for this provider instance
        """
        self.config = config
        self._provider_id = provider_id
        # Lazily-created httpx.Client; see _get_client().
        self._client: Any | None = None

    @property
    def provider_id(self) -> str:
        """Unique identifier for this provider instance."""
        return self._provider_id

    @property
    def supported_types(self) -> list[CredentialType]:
        """Credential types this provider can manage."""
        return [CredentialType.OAUTH2, CredentialType.BEARER_TOKEN]

    def _get_client(self) -> Any:
        """Get or create HTTP client."""
        if self._client is None:
            try:
                import httpx

                self._client = httpx.Client(timeout=self.config.request_timeout)
            except ImportError as e:
                raise ImportError(
                    "OAuth2 provider requires 'httpx'. Install with: pip install httpx"
                ) from e
        return self._client

    def _close_client(self) -> None:
        """Close the HTTP client."""
        if self._client is not None:
            self._client.close()
            self._client = None

    def __del__(self) -> None:
        """Cleanup HTTP client on deletion (best effort)."""
        # Guard: during interpreter shutdown, attributes/modules may already be
        # torn down; __del__ must never raise.
        try:
            self._close_client()
        except Exception:
            pass

    # --- Grant Types ---

    def get_authorization_url(
        self,
        state: str,
        redirect_uri: str,
        scopes: list[str] | None = None,
        **kwargs: Any,
    ) -> str:
        """
        Generate authorization URL for user consent (Authorization Code flow).

        Args:
            state: Anti-CSRF state parameter (should be random and verified)
            redirect_uri: Callback URL to receive the authorization code
            scopes: Requested scopes (defaults to config.default_scopes)
            **kwargs: Additional provider-specific parameters

        Returns:
            URL to redirect user for authorization

        Raises:
            ValueError: If authorization_url is not configured
        """
        if not self.config.authorization_url:
            raise ValueError("authorization_url not configured for this provider")
        params = {
            "client_id": self.config.client_id,
            "redirect_uri": redirect_uri,
            "response_type": "code",
            "state": state,
            "scope": " ".join(scopes or self.config.default_scopes),
            **kwargs,
        }
        return f"{self.config.authorization_url}?{urlencode(params)}"

    def exchange_code(
        self,
        code: str,
        redirect_uri: str,
        **kwargs: Any,
    ) -> OAuth2Token:
        """
        Exchange authorization code for tokens (Authorization Code flow).

        Args:
            code: Authorization code from callback
            redirect_uri: Same redirect_uri used in authorization request
            **kwargs: Additional provider-specific parameters

        Returns:
            OAuth2Token with access_token and optional refresh_token

        Raises:
            OAuth2Error: If token exchange fails
        """
        data = {
            "grant_type": "authorization_code",
            "client_id": self.config.client_id,
            "client_secret": self.config.client_secret,
            "code": code,
            "redirect_uri": redirect_uri,
            **self.config.extra_token_params,
            **kwargs,
        }
        return self._token_request(data)

    def client_credentials_grant(
        self,
        scopes: list[str] | None = None,
        **kwargs: Any,
    ) -> OAuth2Token:
        """
        Obtain token using client credentials (Client Credentials flow).

        This is for server-to-server authentication where no user is involved.

        Args:
            scopes: Requested scopes (defaults to config.default_scopes)
            **kwargs: Additional provider-specific parameters

        Returns:
            OAuth2Token (typically without refresh_token)

        Raises:
            OAuth2Error: If token request fails
        """
        data = {
            "grant_type": "client_credentials",
            "client_id": self.config.client_id,
            "client_secret": self.config.client_secret,
            **self.config.extra_token_params,
            **kwargs,
        }
        if scopes or self.config.default_scopes:
            data["scope"] = " ".join(scopes or self.config.default_scopes)
        return self._token_request(data)

    def refresh_access_token(
        self,
        refresh_token: str,
        scopes: list[str] | None = None,
        **kwargs: Any,
    ) -> OAuth2Token:
        """
        Refresh an expired access token (Refresh Token flow).

        Args:
            refresh_token: The refresh token
            scopes: Scopes to request (defaults to original scopes)
            **kwargs: Additional provider-specific parameters

        Returns:
            New OAuth2Token (may include new refresh_token)

        Raises:
            OAuth2Error: If refresh fails
            RefreshTokenInvalidError: If refresh token is revoked/invalid
        """
        data = {
            "grant_type": "refresh_token",
            "client_id": self.config.client_id,
            "client_secret": self.config.client_secret,
            "refresh_token": refresh_token,
            **self.config.extra_token_params,
            **kwargs,
        }
        if scopes:
            data["scope"] = " ".join(scopes)
        return self._token_request(data)

    def revoke_token(
        self,
        token: str,
        token_type_hint: str = "access_token",
    ) -> bool:
        """
        Revoke a token (RFC 7009).

        Args:
            token: The token to revoke
            token_type_hint: "access_token" or "refresh_token"

        Returns:
            True if revocation succeeded
        """
        if not self.config.revocation_url:
            logger.warning("revocation_url not configured, cannot revoke token")
            return False
        try:
            client = self._get_client()
            response = client.post(
                self.config.revocation_url,
                data={
                    "token": token,
                    "token_type_hint": token_type_hint,
                    "client_id": self.config.client_id,
                    "client_secret": self.config.client_secret,
                },
                headers={"Accept": "application/json", **self.config.extra_headers},
            )
            # RFC 7009: 200 indicates success (even if token was already invalid)
            return response.status_code == 200
        except Exception as e:
            # Best-effort operation: report failure instead of raising.
            logger.error("Token revocation failed: %s", e)
            return False

    # --- CredentialProvider Interface ---

    def refresh(self, credential: CredentialObject) -> CredentialObject:
        """
        Refresh a credential using its refresh token.

        Implements CredentialProvider.refresh().

        Args:
            credential: The credential to refresh

        Returns:
            Updated credential with new access_token

        Raises:
            CredentialRefreshError: If refresh fails
        """
        refresh_tok = credential.get_key("refresh_token")
        if not refresh_tok:
            raise CredentialRefreshError(f"Credential '{credential.id}' has no refresh_token")
        try:
            new_token = self.refresh_access_token(refresh_tok)
        except OAuth2Error as e:
            # "invalid_grant" means the refresh token itself is dead; the user
            # must re-consent — surface that distinctly.
            if e.error == "invalid_grant":
                raise CredentialRefreshError(
                    f"Refresh token for '{credential.id}' is invalid or revoked. "
                    "Re-authorization required."
                ) from e
            raise CredentialRefreshError(f"Failed to refresh '{credential.id}': {e}") from e
        # Update credential
        credential.set_key("access_token", new_token.access_token, expires_at=new_token.expires_at)
        # Update refresh token if a new one was issued (token rotation)
        if new_token.refresh_token and new_token.refresh_token != refresh_tok:
            credential.set_key("refresh_token", new_token.refresh_token)
        credential.last_refreshed = datetime.now(UTC)
        logger.info("Refreshed OAuth2 credential '%s'", credential.id)
        return credential

    def validate(self, credential: CredentialObject) -> bool:
        """
        Validate that credential has a valid (non-expired) access_token.

        Args:
            credential: The credential to validate

        Returns:
            True if credential has valid access_token
        """
        access_key = credential.keys.get("access_token")
        if access_key is None:
            return False
        return not access_key.is_expired

    def should_refresh(self, credential: CredentialObject) -> bool:
        """
        Check if credential should be refreshed.

        Returns True if access_token is expired or within 5 minutes of expiry.
        """
        access_key = credential.keys.get("access_token")
        if access_key is None:
            return False
        if access_key.expires_at is None:
            # No expiry recorded: nothing to refresh against.
            return False
        buffer = timedelta(minutes=5)
        return datetime.now(UTC) >= (access_key.expires_at - buffer)

    def revoke(self, credential: CredentialObject) -> bool:
        """
        Revoke all tokens in a credential.

        Args:
            credential: The credential to revoke

        Returns:
            True if all revocations succeeded
        """
        success = True
        # Revoke access token
        access_token = credential.get_key("access_token")
        if access_token:
            if not self.revoke_token(access_token, "access_token"):
                success = False
        # Revoke refresh token
        refresh_token = credential.get_key("refresh_token")
        if refresh_token:
            if not self.revoke_token(refresh_token, "refresh_token"):
                success = False
        return success

    # --- Token Request Helpers ---

    def _token_request(self, data: dict[str, Any]) -> OAuth2Token:
        """
        Make a token request to the OAuth2 server.

        Args:
            data: Form data for the token request

        Returns:
            OAuth2Token from the response

        Raises:
            OAuth2Error: If request fails or returns an error
        """
        client = self._get_client()
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/x-www-form-urlencoded",
            **self.config.extra_headers,
        }
        response = client.post(self.config.token_url, data=data, headers=headers)
        # Parse response
        content_type = response.headers.get("content-type", "")
        if "application/json" in content_type:
            response_data = response.json()
        else:
            # Some providers (like GitHub) may return form-encoded
            response_data = self._parse_form_response(response.text)
        # Check for error — some servers return 200 with an "error" body.
        if response.status_code != 200 or "error" in response_data:
            error = response_data.get("error", "unknown_error")
            description = response_data.get("error_description", response.text)
            raise OAuth2Error(
                error=error, description=description, status_code=response.status_code
            )
        return OAuth2Token.from_token_response(response_data)

    def _parse_form_response(self, text: str) -> dict[str, str]:
        """Parse form-encoded response (some providers use this instead of JSON)."""
        from urllib.parse import parse_qs

        parsed = parse_qs(text)
        return {k: v[0] if len(v) == 1 else v for k, v in parsed.items()}

    # --- Token Formatting for Requests ---

    def format_for_request(self, token: OAuth2Token) -> dict[str, Any]:
        """
        Format token for use in HTTP requests (bipartisan model).

        Args:
            token: The OAuth2 token

        Returns:
            Dict with 'headers', 'params', or 'data' keys as appropriate
        """
        placement = self.config.token_placement
        if placement == TokenPlacement.HEADER_BEARER:
            return {"headers": {"Authorization": f"{token.token_type} {token.access_token}"}}
        elif placement == TokenPlacement.HEADER_CUSTOM:
            header_name = self.config.custom_header_name or "X-Access-Token"
            return {"headers": {header_name: token.access_token}}
        elif placement == TokenPlacement.QUERY_PARAM:
            return {"params": {self.config.query_param_name: token.access_token}}
        elif placement == TokenPlacement.BODY_PARAM:
            return {"data": {"access_token": token.access_token}}
        return {}

    def format_credential_for_request(self, credential: CredentialObject) -> dict[str, Any]:
        """
        Format a credential for use in HTTP requests.

        Args:
            credential: The credential containing access_token

        Returns:
            Dict with 'headers', 'params', or 'data' keys as appropriate
        """
        access_token = credential.get_key("access_token")
        if not access_token:
            return {}
        # Bug fix: credential.keys maps names to CredentialKey objects, so the
        # old `credential.keys.get("token_type", "Bearer")` handed a
        # CredentialKey (not a str) to OAuth2Token whenever the key existed.
        # Read the secret value via get_key() instead.
        token_type = credential.get_key("token_type") or "Bearer"
        token = OAuth2Token(
            access_token=access_token,
            token_type=token_type,
        )
        return self.format_for_request(token)
@@ -0,0 +1,363 @@
"""
Token lifecycle management for OAuth2 credentials.
This module provides the TokenLifecycleManager which coordinates
automatic token refresh with the credential store.
"""
from __future__ import annotations
import asyncio
import logging
from collections.abc import Callable
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING
from pydantic import SecretStr
from ..models import CredentialKey, CredentialObject, CredentialType
from .base_provider import BaseOAuth2Provider
from .provider import OAuth2Token
if TYPE_CHECKING:
from ..store import CredentialStore
logger = logging.getLogger(__name__)
@dataclass
class TokenRefreshResult:
    """Result of a token refresh operation."""

    # Whether a new token was obtained.
    success: bool
    # The freshly issued token, populated when success is True.
    token: OAuth2Token | None = None
    # Human-readable failure reason, populated when success is False.
    error: str | None = None
    # True when the refresh token itself is invalid/revoked and the user
    # must go through the authorization flow again.
    needs_reauthorization: bool = False
class TokenLifecycleManager:
"""
Manages the complete lifecycle of OAuth2 tokens.
Responsibilities:
- Coordinate with CredentialStore for persistence
- Automatically refresh expired tokens
- Handle refresh failures gracefully
- Provide callbacks for monitoring
This class is useful when you need more control over token management
than the basic auto-refresh in CredentialStore provides.
Usage:
manager = TokenLifecycleManager(
provider=github_provider,
credential_id="github_oauth",
store=credential_store,
)
# Get valid token (auto-refreshes if needed)
token = await manager.get_valid_token()
# Use token
headers = provider.format_for_request(token)
Synchronous usage:
# For synchronous code, use sync_ methods
token = manager.sync_get_valid_token()
"""
def __init__(
self,
provider: BaseOAuth2Provider,
credential_id: str,
store: CredentialStore,
refresh_buffer_minutes: int = 5,
on_token_refreshed: Callable[[OAuth2Token], None] | None = None,
on_refresh_failed: Callable[[str], None] | None = None,
):
"""
Initialize the lifecycle manager.
Args:
provider: OAuth2 provider for token operations
credential_id: ID of the credential in the store
store: Credential store for persistence
refresh_buffer_minutes: Minutes before expiry to trigger refresh
on_token_refreshed: Callback when token is refreshed
on_refresh_failed: Callback when refresh fails
"""
self.provider = provider
self.credential_id = credential_id
self.store = store
self.refresh_buffer = timedelta(minutes=refresh_buffer_minutes)
self.on_token_refreshed = on_token_refreshed
self.on_refresh_failed = on_refresh_failed
# In-memory cache for performance
self._cached_token: OAuth2Token | None = None
self._cache_time: datetime | None = None
# --- Async Token Access ---
async def get_valid_token(self) -> OAuth2Token | None:
"""
Get a valid access token, refreshing if necessary.
This is the main entry point for async code.
Returns:
Valid OAuth2Token or None if unavailable
"""
# Check cache first
if self._cached_token and not self._needs_refresh(self._cached_token):
return self._cached_token
# Load from store
credential = self.store.get_credential(self.credential_id, refresh_if_needed=False)
if credential is None:
return None
# Convert to OAuth2Token
token = self._credential_to_token(credential)
if token is None:
return None
# Refresh if needed
if self._needs_refresh(token):
result = await self._async_refresh_token(credential)
if result.success and result.token:
token = result.token
elif result.needs_reauthorization:
logger.warning(f"Token for {self.credential_id} needs reauthorization")
return None
else:
# Use existing token if still technically valid
if token.is_expired:
return None
logger.warning(f"Refresh failed for {self.credential_id}, using existing token")
self._cached_token = token
self._cache_time = datetime.now(UTC)
return token
async def acquire_token_client_credentials(
self,
scopes: list[str] | None = None,
) -> OAuth2Token:
"""
Acquire a new token using client credentials flow.
For service-to-service authentication.
Args:
scopes: Scopes to request
Returns:
New OAuth2Token
"""
# Run in executor to avoid blocking
loop = asyncio.get_event_loop()
token = await loop.run_in_executor(
None, lambda: self.provider.client_credentials_grant(scopes=scopes)
)
self._save_token_to_store(token)
self._cached_token = token
return token
async def revoke(self) -> bool:
"""
Revoke tokens and clear from store.
Returns:
True if revocation succeeded
"""
credential = self.store.get_credential(self.credential_id, refresh_if_needed=False)
if credential:
self.provider.revoke(credential)
self.store.delete_credential(self.credential_id)
self._cached_token = None
return True
# --- Synchronous Token Access ---
def sync_get_valid_token(self) -> OAuth2Token | None:
"""
Synchronous version of get_valid_token().
For use in synchronous code.
"""
# Check cache
if self._cached_token and not self._needs_refresh(self._cached_token):
return self._cached_token
# Load from store
credential = self.store.get_credential(self.credential_id, refresh_if_needed=False)
if credential is None:
return None
token = self._credential_to_token(credential)
if token is None:
return None
# Refresh if needed
if self._needs_refresh(token):
result = self._sync_refresh_token(credential)
if result.success and result.token:
token = result.token
elif result.needs_reauthorization:
logger.warning(f"Token for {self.credential_id} needs reauthorization")
return None
else:
if token.is_expired:
return None
self._cached_token = token
self._cache_time = datetime.now(UTC)
return token
def sync_acquire_token_client_credentials(
self,
scopes: list[str] | None = None,
) -> OAuth2Token:
"""Synchronous version of acquire_token_client_credentials()."""
token = self.provider.client_credentials_grant(scopes=scopes)
self._save_token_to_store(token)
self._cached_token = token
return token
# --- Helper Methods ---
def _needs_refresh(self, token: OAuth2Token) -> bool:
"""Check if token needs refresh."""
if token.expires_at is None:
return False
return datetime.now(UTC) >= (token.expires_at - self.refresh_buffer)
def _credential_to_token(self, credential: CredentialObject) -> OAuth2Token | None:
"""Convert credential to OAuth2Token."""
access_token = credential.get_key("access_token")
if not access_token:
return None
expires_at = None
access_key = credential.keys.get("access_token")
if access_key:
expires_at = access_key.expires_at
return OAuth2Token(
access_token=access_token,
token_type="Bearer",
expires_at=expires_at,
refresh_token=credential.get_key("refresh_token"),
scope=credential.get_key("scope"),
)
def _save_token_to_store(self, token: OAuth2Token) -> None:
"""Save token to credential store."""
credential = CredentialObject(
id=self.credential_id,
credential_type=CredentialType.OAUTH2,
keys={
"access_token": CredentialKey(
name="access_token",
value=SecretStr(token.access_token),
expires_at=token.expires_at,
),
},
provider_id=self.provider.provider_id,
auto_refresh=True,
)
if token.refresh_token:
credential.keys["refresh_token"] = CredentialKey(
name="refresh_token",
value=SecretStr(token.refresh_token),
)
if token.scope:
credential.keys["scope"] = CredentialKey(
name="scope",
value=SecretStr(token.scope),
)
self.store.save_credential(credential)
async def _async_refresh_token(self, credential: CredentialObject) -> TokenRefreshResult:
"""Async wrapper for token refresh."""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, lambda: self._sync_refresh_token(credential))
def _sync_refresh_token(self, credential: CredentialObject) -> TokenRefreshResult:
    """Synchronously refresh token.

    Performs a refresh-token grant via the provider. On success the new
    token is persisted to the store and the on_token_refreshed callback
    (if set) is invoked. On failure the result carries the error message;
    an ``invalid_grant`` error is additionally flagged as needing full
    re-authorization because the refresh token itself is unusable.
    """
    refresh_token = credential.get_key("refresh_token")
    if not refresh_token:
        # Nothing to refresh with -- caller must run the full auth flow.
        return TokenRefreshResult(
            success=False,
            error="No refresh token available",
            needs_reauthorization=True,
        )
    try:
        new_token = self.provider.refresh_access_token(refresh_token)
        # Save to store
        self._save_token_to_store(new_token)
        # Notify callback
        if self.on_token_refreshed:
            self.on_token_refreshed(new_token)
        logger.info(f"Token refreshed for {self.credential_id}")
        return TokenRefreshResult(success=True, token=new_token)
    except Exception as e:
        error_msg = str(e)
        # Check for refresh token revocation: per RFC 6749, invalid_grant
        # means the refresh token is expired/revoked/invalid. Note the
        # on_refresh_failed callback is deliberately NOT fired here.
        if "invalid_grant" in error_msg.lower():
            return TokenRefreshResult(
                success=False,
                error=error_msg,
                needs_reauthorization=True,
            )
        # Other failures may be transient: surface via callback and log.
        if self.on_refresh_failed:
            self.on_refresh_failed(error_msg)
        logger.error(f"Token refresh failed for {self.credential_id}: {e}")
        return TokenRefreshResult(success=False, error=error_msg)
def invalidate_cache(self) -> None:
    """Drop the in-memory token cache; next access reloads from the store."""
    self._cached_token = self._cache_time = None
# --- Convenience Methods ---
def get_request_headers(self) -> dict[str, str]:
    """Return auth headers for an HTTP request.

    Returns an empty dict when no valid token is available.
    """
    token = self.sync_get_valid_token()
    if token is None:
        return {}
    formatted = self.provider.format_for_request(token)
    return formatted.get("headers", {})
def get_request_kwargs(self) -> dict:
    """Return request kwargs (headers, params, etc.) for the current token.

    Returns an empty dict when no valid token is available.
    """
    token = self.sync_get_valid_token()
    return {} if token is None else self.provider.format_for_request(token)
@@ -0,0 +1,213 @@
"""
OAuth2 types and configuration.
This module defines the core OAuth2 data structures:
- OAuth2Token: Represents an access token with metadata
- OAuth2Config: Configuration for OAuth2 endpoints
- TokenPlacement: Where to place tokens in requests
"""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from enum import Enum
from typing import Any
# str mixin lets members compare equal to their string values and
# serialize cleanly (e.g. into JSON config).
class TokenPlacement(str, Enum):
    """Where to place the access token in HTTP requests."""

    HEADER_BEARER = "header_bearer"
    """Authorization: Bearer <token> (most common)"""

    HEADER_CUSTOM = "header_custom"
    """Custom header name (e.g., X-Access-Token)"""

    QUERY_PARAM = "query_param"
    """Query parameter (e.g., ?access_token=<token>)"""

    BODY_PARAM = "body_param"
    """Form body parameter"""
@dataclass
class OAuth2Token:
"""
Represents an OAuth2 token with metadata.
Attributes:
access_token: The access token string
token_type: Token type (usually "Bearer")
expires_at: When the token expires
refresh_token: Optional refresh token
scope: Granted scopes (space-separated)
raw_response: Original token response from server
"""
access_token: str
token_type: str = "Bearer"
expires_at: datetime | None = None
refresh_token: str | None = None
scope: str | None = None
raw_response: dict[str, Any] = field(default_factory=dict)
@property
def is_expired(self) -> bool:
"""
Check if token is expired.
Uses a 5-minute buffer to account for clock skew and
request latency.
"""
if self.expires_at is None:
return False
buffer = timedelta(minutes=5)
return datetime.now(UTC) >= (self.expires_at - buffer)
@property
def can_refresh(self) -> bool:
"""Check if token can be refreshed (has refresh_token)."""
return self.refresh_token is not None and self.refresh_token.strip() != ""
@property
def expires_in_seconds(self) -> int | None:
"""Get seconds until expiration, or None if no expiration."""
if self.expires_at is None:
return None
delta = self.expires_at - datetime.now(UTC)
return max(0, int(delta.total_seconds()))
@classmethod
def from_token_response(cls, data: dict[str, Any]) -> OAuth2Token:
"""
Create OAuth2Token from an OAuth2 token endpoint response.
Args:
data: Token response JSON (access_token, token_type, expires_in, etc.)
Returns:
OAuth2Token instance
"""
expires_at = None
if "expires_in" in data:
expires_at = datetime.now(UTC) + timedelta(seconds=data["expires_in"])
return cls(
access_token=data["access_token"],
token_type=data.get("token_type", "Bearer"),
expires_at=expires_at,
refresh_token=data.get("refresh_token"),
scope=data.get("scope"),
raw_response=data,
)
@dataclass
class OAuth2Config:
    """
    Configuration for an OAuth2 provider.

    This contains all the information needed to perform OAuth2 operations
    for a specific provider (GitHub, Google, Salesforce, etc.).

    Attributes:
        token_url: URL for token endpoint (required)
        authorization_url: URL for authorization endpoint (optional, for auth code flow)
        revocation_url: URL for token revocation (optional)
        introspection_url: URL for token introspection (optional)
        client_id: OAuth2 client ID
        client_secret: OAuth2 client secret
        default_scopes: Default scopes to request
        token_placement: How to include token in requests
        custom_header_name: Header name when using HEADER_CUSTOM placement
        query_param_name: Query param name when using QUERY_PARAM placement
        extra_token_params: Additional parameters for token requests
        request_timeout: Timeout for HTTP requests in seconds

    Example:
        config = OAuth2Config(
            token_url="https://github.com/login/oauth/access_token",
            authorization_url="https://github.com/login/oauth/authorize",
            client_id="your-client-id",
            client_secret="your-client-secret",
            default_scopes=["repo", "user"],
        )
    """

    # Endpoints (only token_url is strictly required)
    token_url: str
    authorization_url: str | None = None
    revocation_url: str | None = None
    introspection_url: str | None = None

    # Client credentials
    client_id: str = ""
    client_secret: str = ""

    # Scopes
    default_scopes: list[str] = field(default_factory=list)

    # How the access token is attached to outgoing API calls
    token_placement: TokenPlacement = TokenPlacement.HEADER_BEARER
    custom_header_name: str | None = None
    query_param_name: str = "access_token"

    # Request configuration
    extra_token_params: dict[str, str] = field(default_factory=dict)
    request_timeout: float = 30.0

    # Additional headers for token requests
    extra_headers: dict[str, str] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate configuration."""
        if not self.token_url:
            raise ValueError("token_url is required")
        # HEADER_CUSTOM is meaningless without a header name to write to.
        if self.token_placement == TokenPlacement.HEADER_CUSTOM and not self.custom_header_name:
            raise ValueError("custom_header_name is required when using HEADER_CUSTOM placement")
class OAuth2Error(Exception):
    """
    OAuth2 protocol error.

    Attributes:
        error: OAuth2 error code (e.g., 'invalid_grant', 'invalid_client')
        description: Human-readable error description
        status_code: HTTP status code from the response
    """

    def __init__(
        self,
        error: str,
        description: str = "",
        status_code: int = 0,
    ):
        self.error = error
        self.description = description
        self.status_code = status_code
        # Exception message is "code: description", or just the code when
        # no description was supplied.
        message = f"{error}: {description}" if description else error
        super().__init__(message)
class TokenExpiredError(OAuth2Error):
    """Raised when a token has expired and cannot be used."""

    def __init__(self, credential_id: str):
        self.credential_id = credential_id
        super().__init__(
            error="token_expired",
            description=f"Token for '{credential_id}' has expired",
        )
class RefreshTokenInvalidError(OAuth2Error):
    """Raised when the refresh token is invalid or revoked."""

    def __init__(self, credential_id: str, reason: str = ""):
        # Append the optional reason as a ": reason" suffix.
        suffix = f": {reason}" if reason else ""
        super().__init__(
            error="invalid_grant",
            description=f"Refresh token for '{credential_id}' is invalid{suffix}",
        )
        self.credential_id = credential_id
+283
View File
@@ -0,0 +1,283 @@
"""
Provider interface for credential lifecycle management.
Providers handle credential lifecycle operations:
- Refresh: Obtain new tokens when expired
- Validate: Check if credentials are still working
- Revoke: Invalidate credentials when no longer needed
OSS users can implement custom providers by subclassing CredentialProvider.
"""
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from datetime import UTC, datetime, timedelta
from .models import CredentialObject, CredentialRefreshError, CredentialType
logger = logging.getLogger(__name__)
class CredentialProvider(ABC):
    """
    Abstract base class for credential providers.

    A provider owns the lifecycle of the credentials it supports:

    - refresh(): obtain new values when the current ones expire
    - validate(): check that a credential still works
    - should_refresh(): decide whether a refresh is due
    - revoke(): invalidate a credential (optional)

    Example custom provider:

        class MyCustomProvider(CredentialProvider):
            @property
            def provider_id(self) -> str:
                return "my_custom"

            @property
            def supported_types(self) -> list[CredentialType]:
                return [CredentialType.CUSTOM]

            def refresh(self, credential: CredentialObject) -> CredentialObject:
                new_token = my_api.refresh(credential.get_key("api_key"))
                credential.set_key("access_token", new_token)
                return credential

            def validate(self, credential: CredentialObject) -> bool:
                return my_api.validate(credential.get_key("access_token"))
    """

    @property
    @abstractmethod
    def provider_id(self) -> str:
        """Unique identifier for this provider (e.g. 'static', 'oauth2')."""
        ...

    @property
    @abstractmethod
    def supported_types(self) -> list[CredentialType]:
        """The CredentialType values this provider can manage."""
        ...

    @abstractmethod
    def refresh(self, credential: CredentialObject) -> CredentialObject:
        """
        Refresh the credential (e.g. trade a refresh_token for a new access_token).

        Implementations should use the existing credential data to obtain
        new values, write them back onto the credential, set expiration
        times, and update the last_refreshed timestamp.

        Args:
            credential: The credential to refresh

        Returns:
            The updated credential

        Raises:
            CredentialRefreshError: If refresh fails
        """
        ...

    @abstractmethod
    def validate(self, credential: CredentialObject) -> bool:
        """
        Check that a credential is still working.

        May inspect expiration times, make a test API call, or verify
        token signatures.

        Args:
            credential: The credential to validate

        Returns:
            True if the credential is valid, False otherwise
        """
        ...

    def should_refresh(self, credential: CredentialObject) -> bool:
        """
        Decide whether a credential is due for refresh.

        Default policy: refresh when any key is expired or within five
        minutes of expiring. Override for custom behavior.

        Args:
            credential: The credential to check

        Returns:
            True if the credential should be refreshed
        """
        deadline = datetime.now(UTC) + timedelta(minutes=5)
        return any(
            key.expires_at is not None and key.expires_at <= deadline
            for key in credential.keys.values()
        )

    def revoke(self, credential: CredentialObject) -> bool:
        """
        Revoke a credential (optional operation).

        Providers that cannot revoke inherit this default, which logs a
        warning and reports failure.

        Args:
            credential: The credential to revoke

        Returns:
            True if revocation succeeded, False otherwise
        """
        logger.warning(f"Provider '{self.provider_id}' does not support revocation")
        return False

    def can_handle(self, credential: CredentialObject) -> bool:
        """
        Report whether this provider supports the credential's type.

        Args:
            credential: The credential to check

        Returns:
            True if this provider can manage the credential
        """
        return credential.credential_type in self.supported_types
class StaticProvider(CredentialProvider):
    """
    Provider for static credentials that never need refresh.

    Suitable for non-expiring API keys such as Brave Search or OpenAI
    keys, and for basic-auth credentials. A static credential counts as
    valid whenever it holds at least one key with a non-blank value.
    """

    @property
    def provider_id(self) -> str:
        return "static"

    @property
    def supported_types(self) -> list[CredentialType]:
        return [CredentialType.API_KEY, CredentialType.BASIC_AUTH, CredentialType.CUSTOM]

    def refresh(self, credential: CredentialObject) -> CredentialObject:
        """No-op: static credentials never change. Returns the credential as-is."""
        logger.debug(f"Static credential '{credential.id}' does not need refresh")
        return credential

    def validate(self, credential: CredentialObject) -> bool:
        """
        True when at least one key holds a non-blank secret value.

        Without making a live API call, existence of a value is the only
        check available for a static credential.
        """
        for key in credential.keys.values():
            try:
                secret = key.get_secret_value()
                if secret and secret.strip():
                    return True
            except Exception:
                # Unreadable key -- keep checking the remaining keys.
                continue
        return False

    def should_refresh(self, credential: CredentialObject) -> bool:
        """Static credentials never need refresh."""
        return False
class BearerTokenProvider(CredentialProvider):
    """
    Provider for bearer tokens without refresh capability.

    Intended for JWTs or similar tokens that expire but carry no refresh
    token: validity is judged purely by expiration time, and refresh()
    always fails because a replacement token must be obtained out-of-band.
    """

    @property
    def provider_id(self) -> str:
        return "bearer_token"

    @property
    def supported_types(self) -> list[CredentialType]:
        return [CredentialType.BEARER_TOKEN]

    def refresh(self, credential: CredentialObject) -> CredentialObject:
        """
        Always fails: these tokens carry no refresh token.

        Raises:
            CredentialRefreshError: Always.
        """
        raise CredentialRefreshError(
            f"Bearer token '{credential.id}' cannot be refreshed. "
            "Obtain a new token and save it to the credential store."
        )

    def validate(self, credential: CredentialObject) -> bool:
        """
        True when an 'access_token' (or 'token') key exists and is not expired.
        """
        token_key = credential.keys.get("access_token") or credential.keys.get("token")
        return token_key is not None and not token_key.is_expired

    def should_refresh(self, credential: CredentialObject) -> bool:
        """
        Check if the token is expired or within the 5-minute buffer.

        Note: refresh() will still raise. A True result only signals that
        the credential needs attention from the store/operator.
        """
        deadline = datetime.now(UTC) + timedelta(minutes=5)
        candidates = (credential.keys.get("access_token"), credential.keys.get("token"))
        return any(
            key is not None and key.expires_at is not None and key.expires_at <= deadline
            for key in candidates
        )
+516
View File
@@ -0,0 +1,516 @@
"""
Storage backends for the credential store.
This module provides abstract and concrete storage implementations:
- CredentialStorage: Abstract base class
- EncryptedFileStorage: Fernet-encrypted JSON files (default for production)
- EnvVarStorage: Environment variable reading (backward compatibility)
- InMemoryStorage: For testing
"""
from __future__ import annotations
import json
import logging
import os
from abc import ABC, abstractmethod
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from pydantic import SecretStr
from .models import CredentialDecryptionError, CredentialKey, CredentialObject, CredentialType
logger = logging.getLogger(__name__)
class CredentialStorage(ABC):
    """
    Abstract storage backend for credentials.

    Implementations must provide save, load, delete, list_all, and exists methods.
    All implementations should handle serialization of SecretStr values securely.
    """

    @abstractmethod
    def save(self, credential: CredentialObject) -> None:
        """
        Save a credential to storage.

        Args:
            credential: The credential object to save
        """
        pass

    @abstractmethod
    def load(self, credential_id: str) -> CredentialObject | None:
        """
        Load a credential from storage.

        Args:
            credential_id: The ID of the credential to load

        Returns:
            CredentialObject if found, None otherwise
        """
        pass

    @abstractmethod
    def delete(self, credential_id: str) -> bool:
        """
        Delete a credential from storage.

        Args:
            credential_id: The ID of the credential to delete

        Returns:
            True if the credential existed and was deleted, False otherwise
        """
        pass

    @abstractmethod
    def list_all(self) -> list[str]:
        """
        List all credential IDs in storage.

        Returns:
            List of credential IDs
        """
        pass

    @abstractmethod
    def exists(self, credential_id: str) -> bool:
        """
        Check if a credential exists in storage.

        Args:
            credential_id: The ID to check

        Returns:
            True if credential exists, False otherwise
        """
        pass
class EncryptedFileStorage(CredentialStorage):
    """
    Encrypted file-based credential storage.

    Uses Fernet symmetric encryption (AES-128-CBC + HMAC) for at-rest encryption.
    Each credential is stored as a separate encrypted JSON file.

    Directory structure:
        {base_path}/
            credentials/
                {credential_id}.enc   # Encrypted credential JSON
            metadata/
                index.json            # Index of all credentials (unencrypted)

    The encryption key is read from the HIVE_CREDENTIAL_KEY environment variable.
    If not set, a new key is generated (and must be persisted for data recovery).

    Example:
        storage = EncryptedFileStorage("/var/hive/credentials")
        storage.save(credential)
        credential = storage.load("brave_search")
    """

    def __init__(
        self,
        base_path: str | Path,
        encryption_key: bytes | None = None,
        key_env_var: str = "HIVE_CREDENTIAL_KEY",
    ):
        """
        Initialize encrypted storage.

        Args:
            base_path: Directory for credential files
            encryption_key: 32-byte Fernet key. If None, reads from env var.
            key_env_var: Environment variable containing encryption key

        Raises:
            ImportError: If the 'cryptography' package is not installed.
        """
        # Import lazily so the rest of the package works without the
        # optional 'cryptography' dependency.
        try:
            from cryptography.fernet import Fernet
        except ImportError as e:
            raise ImportError(
                "Encrypted storage requires 'cryptography'. Install with: pip install cryptography"
            ) from e
        self.base_path = Path(base_path)
        self._ensure_dirs()
        self._key_env_var = key_env_var
        # Key resolution order: explicit argument > env var > freshly generated.
        if encryption_key:
            self._key = encryption_key
        else:
            key_str = os.environ.get(key_env_var)
            if key_str:
                self._key = key_str.encode()
            else:
                # Generate new key. NOTE: data encrypted with a generated
                # key is unrecoverable after restart unless the operator
                # persists the key as instructed below.
                self._key = Fernet.generate_key()
                logger.warning(
                    f"Generated new encryption key. To persist credentials across restarts, "
                    f"set {key_env_var}={self._key.decode()}"
                )
        self._fernet = Fernet(self._key)

    def _ensure_dirs(self) -> None:
        """Create directory structure."""
        (self.base_path / "credentials").mkdir(parents=True, exist_ok=True)
        (self.base_path / "metadata").mkdir(parents=True, exist_ok=True)

    def _cred_path(self, credential_id: str) -> Path:
        """Get the file path for a credential."""
        # Sanitize credential_id to prevent path traversal
        safe_id = credential_id.replace("/", "_").replace("\\", "_").replace("..", "_")
        return self.base_path / "credentials" / f"{safe_id}.enc"

    def save(self, credential: CredentialObject) -> None:
        """Encrypt and save credential."""
        # Serialize credential (default=str covers datetime values)
        data = self._serialize_credential(credential)
        json_bytes = json.dumps(data, default=str).encode()
        # Encrypt
        encrypted = self._fernet.encrypt(json_bytes)
        # Write to file
        cred_path = self._cred_path(credential.id)
        with open(cred_path, "wb") as f:
            f.write(encrypted)
        # Update index
        self._update_index(credential.id, "save", credential.credential_type.value)
        logger.debug(f"Saved encrypted credential '{credential.id}'")

    def load(self, credential_id: str) -> CredentialObject | None:
        """Load and decrypt credential.

        Raises:
            CredentialDecryptionError: If decryption or JSON parsing fails
                (e.g. wrong key, corrupted file).
        """
        cred_path = self._cred_path(credential_id)
        if not cred_path.exists():
            return None
        # Read encrypted data
        with open(cred_path, "rb") as f:
            encrypted = f.read()
        # Decrypt
        try:
            json_bytes = self._fernet.decrypt(encrypted)
            data = json.loads(json_bytes.decode())
        except Exception as e:
            raise CredentialDecryptionError(
                f"Failed to decrypt credential '{credential_id}': {e}"
            ) from e
        # Deserialize
        return self._deserialize_credential(data)

    def delete(self, credential_id: str) -> bool:
        """Delete a credential file."""
        cred_path = self._cred_path(credential_id)
        if cred_path.exists():
            cred_path.unlink()
            self._update_index(credential_id, "delete")
            logger.debug(f"Deleted credential '{credential_id}'")
            return True
        return False

    def list_all(self) -> list[str]:
        """List all credential IDs.

        NOTE(review): this reads the metadata index, not the credentials
        directory, so it can drift from exists() if the index and files
        ever get out of sync (e.g. manual file deletion).
        """
        index_path = self.base_path / "metadata" / "index.json"
        if not index_path.exists():
            return []
        with open(index_path) as f:
            index = json.load(f)
        return list(index.get("credentials", {}).keys())

    def exists(self, credential_id: str) -> bool:
        """Check if credential exists (by file presence, not the index)."""
        return self._cred_path(credential_id).exists()

    def _serialize_credential(self, credential: CredentialObject) -> dict[str, Any]:
        """Convert credential to JSON-serializable dict, extracting secret values."""
        data = credential.model_dump(mode="json")
        # Extract actual secret values from SecretStr
        for key_name, key_data in data.get("keys", {}).items():
            if "value" in key_data:
                # SecretStr serializes as "**********", need actual value
                actual_key = credential.keys.get(key_name)
                if actual_key:
                    key_data["value"] = actual_key.get_secret_value()
        return data

    def _deserialize_credential(self, data: dict[str, Any]) -> CredentialObject:
        """Reconstruct credential from dict, wrapping values in SecretStr."""
        # Convert plain values back to SecretStr
        for key_data in data.get("keys", {}).values():
            if "value" in key_data and isinstance(key_data["value"], str):
                key_data["value"] = SecretStr(key_data["value"])
        return CredentialObject.model_validate(data)

    def _update_index(
        self,
        credential_id: str,
        operation: str,
        credential_type: str | None = None,
    ) -> None:
        """Update the metadata index.

        The index is intentionally unencrypted: it holds only ids, types,
        and timestamps -- never secret values.
        """
        index_path = self.base_path / "metadata" / "index.json"
        if index_path.exists():
            with open(index_path) as f:
                index = json.load(f)
        else:
            index = {"credentials": {}, "version": "1.0"}
        if operation == "save":
            index["credentials"][credential_id] = {
                "updated_at": datetime.now(UTC).isoformat(),
                "type": credential_type,
            }
        elif operation == "delete":
            index["credentials"].pop(credential_id, None)
        index["last_modified"] = datetime.now(UTC).isoformat()
        # NOTE(review): write is not atomic -- a crash mid-write can leave
        # a truncated index. Consider write-to-temp + rename.
        with open(index_path, "w") as f:
            json.dump(index, f, indent=2)
class EnvVarStorage(CredentialStorage):
    """
    Environment variable-based storage for backward compatibility.

    Maps credential IDs to environment variable names, with optional
    hot-reload from a .env file via python-dotenv. This backend is
    READ-ONLY: save() and delete() always raise.

    Example:
        storage = EnvVarStorage(
            env_mapping={"brave_search": "BRAVE_SEARCH_API_KEY"},
            dotenv_path=Path(".env")
        )
        credential = storage.load("brave_search")
    """

    def __init__(
        self,
        env_mapping: dict[str, str] | None = None,
        dotenv_path: Path | None = None,
    ):
        """
        Initialize env var storage.

        Args:
            env_mapping: Map of credential_id -> env_var_name, e.g.
                {"brave_search": "BRAVE_SEARCH_API_KEY"}. Unmapped ids
                fall back to the {CREDENTIAL_ID}_API_KEY pattern.
            dotenv_path: Path to a .env file for hot-reload support.
        """
        self._env_mapping = env_mapping or {}
        self._dotenv_path = dotenv_path or Path.cwd() / ".env"

    def _get_env_var_name(self, credential_id: str) -> str:
        """Resolve the environment variable name for a credential id."""
        mapped = self._env_mapping.get(credential_id)
        if mapped is not None:
            return mapped
        # Default pattern: CREDENTIAL_ID_API_KEY
        return f"{credential_id.upper().replace('-', '_')}_API_KEY"

    def _read_env_value(self, env_var: str) -> str | None:
        """Read a value from os.environ, falling back to the .env file."""
        # Process environment always wins over the .env file.
        value = os.environ.get(env_var)
        if value:
            return value
        if not self._dotenv_path.exists():
            return None
        try:
            from dotenv import dotenv_values
        except ImportError:
            logger.debug("python-dotenv not installed, skipping .env file")
            return None
        return dotenv_values(self._dotenv_path).get(env_var)

    def save(self, credential: CredentialObject) -> None:
        """Unsupported: environment variables cannot be written at runtime."""
        raise NotImplementedError(
            "EnvVarStorage is read-only. Set environment variables "
            "externally or use EncryptedFileStorage."
        )

    def load(self, credential_id: str) -> CredentialObject | None:
        """Build an API_KEY credential from the mapped environment variable."""
        env_var = self._get_env_var_name(credential_id)
        raw_value = self._read_env_value(env_var)
        if not raw_value:
            return None
        return CredentialObject(
            id=credential_id,
            credential_type=CredentialType.API_KEY,
            keys={"api_key": CredentialKey(name="api_key", value=SecretStr(raw_value))},
            description=f"Loaded from {env_var}",
        )

    def delete(self, credential_id: str) -> bool:
        """Unsupported: environment variables cannot be removed at runtime."""
        raise NotImplementedError(
            "EnvVarStorage is read-only. Unset environment variables externally."
        )

    def list_all(self) -> list[str]:
        """List mapped credential ids whose environment values are set."""
        return [cred_id for cred_id in self._env_mapping if self.exists(cred_id)]

    def exists(self, credential_id: str) -> bool:
        """Check whether a value is available for this credential id."""
        return self._read_env_value(self._get_env_var_name(credential_id)) is not None

    def add_mapping(self, credential_id: str, env_var: str) -> None:
        """
        Register a credential_id -> environment variable mapping.

        Args:
            credential_id: The credential identifier
            env_var: The environment variable name
        """
        self._env_mapping[credential_id] = env_var
class InMemoryStorage(CredentialStorage):
    """
    In-memory storage for testing.

    Backed by a plain dict; contents vanish when the process exits.

    Example:
        storage = InMemoryStorage()
        storage.save(credential)
        credential = storage.load("test_cred")
    """

    def __init__(self, initial_data: dict[str, CredentialObject] | None = None):
        """
        Initialize in-memory storage.

        Args:
            initial_data: Optional dict of credential_id -> CredentialObject
        """
        self._data: dict[str, CredentialObject] = initial_data or {}

    def save(self, credential: CredentialObject) -> None:
        """Store the credential under its id."""
        self._data[credential.id] = credential

    def load(self, credential_id: str) -> CredentialObject | None:
        """Return the stored credential, or None."""
        return self._data.get(credential_id)

    def delete(self, credential_id: str) -> bool:
        """Remove the credential; True if it was present."""
        return self._data.pop(credential_id, None) is not None

    def list_all(self) -> list[str]:
        """All stored credential ids."""
        return list(self._data)

    def exists(self, credential_id: str) -> bool:
        """Membership check."""
        return credential_id in self._data

    def clear(self) -> None:
        """Drop all credentials."""
        self._data.clear()
class CompositeStorage(CredentialStorage):
    """
    Composite storage that reads from multiple backends.

    Writes and deletes go to the primary storage only; reads consult the
    primary first and then each fallback in order. Useful for layering,
    e.g. encrypted files with an env-var fallback:

        storage = CompositeStorage(
            primary=EncryptedFileStorage("/var/hive/credentials"),
            fallbacks=[EnvVarStorage({"brave_search": "BRAVE_SEARCH_API_KEY"})]
        )
    """

    def __init__(
        self,
        primary: CredentialStorage,
        fallbacks: list[CredentialStorage] | None = None,
    ):
        """
        Initialize composite storage.

        Args:
            primary: Storage used for writes and the first read attempt.
            fallbacks: Additional read sources, checked in order.
        """
        self._primary = primary
        self._fallbacks = fallbacks or []

    def save(self, credential: CredentialObject) -> None:
        """Write to the primary storage."""
        self._primary.save(credential)

    def load(self, credential_id: str) -> CredentialObject | None:
        """Return the first hit from primary, then each fallback in order."""
        for backend in (self._primary, *self._fallbacks):
            found = backend.load(credential_id)
            if found is not None:
                return found
        return None

    def delete(self, credential_id: str) -> bool:
        """Delete from primary storage only."""
        return self._primary.delete(credential_id)

    def list_all(self) -> list[str]:
        """Union of credential ids across all backends."""
        all_ids = set(self._primary.list_all())
        for backend in self._fallbacks:
            all_ids.update(backend.list_all())
        return list(all_ids)

    def exists(self, credential_id: str) -> bool:
        """True if any backend has the credential."""
        return self._primary.exists(credential_id) or any(
            backend.exists(credential_id) for backend in self._fallbacks
        )
+708
View File
@@ -0,0 +1,708 @@
"""
Main credential store orchestrating storage, providers, and template resolution.
The CredentialStore is the primary interface for credential management, providing:
- Multi-backend storage (file, env, vault)
- Provider-based lifecycle management (refresh, validate)
- Template resolution for {{cred.key}} patterns
- Caching with TTL for performance
- Thread-safe operations
"""
from __future__ import annotations
import logging
import threading
from datetime import UTC, datetime
from typing import Any
from pydantic import SecretStr
from .models import (
CredentialKey,
CredentialObject,
CredentialRefreshError,
CredentialUsageSpec,
)
from .provider import CredentialProvider, StaticProvider
from .storage import CredentialStorage, EnvVarStorage, InMemoryStorage
from .template import TemplateResolver
logger = logging.getLogger(__name__)
class CredentialStore:
"""
Main credential store orchestrating storage, providers, and template resolution.
Features:
- Multi-backend storage (file, env, vault)
- Provider-based lifecycle management (refresh, validate)
- Template resolution for {{cred.key}} patterns
- Caching with TTL for performance
- Thread-safe operations
Usage:
# Basic usage
store = CredentialStore(
storage=EncryptedFileStorage("/path/to/creds"),
providers=[OAuth2Provider(), StaticProvider()]
)
# Get a credential
cred = store.get_credential("github_oauth")
# Resolve templates in headers
headers = store.resolve_headers({
"Authorization": "Bearer {{github_oauth.access_token}}"
})
# Register a tool's credential requirements
store.register_usage(CredentialUsageSpec(
credential_id="brave_search",
required_keys=["api_key"],
headers={"X-Subscription-Token": "{{brave_search.api_key}}"}
))
"""
def __init__(
    self,
    storage: CredentialStorage | None = None,
    providers: list[CredentialProvider] | None = None,
    cache_ttl_seconds: int = 300,
    auto_refresh: bool = True,
):
    """
    Initialize the credential store.

    Args:
        storage: Storage backend. Defaults to EnvVarStorage for compatibility.
        providers: List of credential providers. Defaults to [StaticProvider()].
        cache_ttl_seconds: How long to cache credentials in memory (default: 5 minutes).
        auto_refresh: Whether to auto-refresh expired credentials on access.
    """
    self._storage = storage or EnvVarStorage()
    self._providers: dict[str, CredentialProvider] = {}
    self._usage_specs: dict[str, CredentialUsageSpec] = {}
    # Cache maps credential_id -> (credential, cached_at timestamp).
    self._cache: dict[str, tuple[CredentialObject, datetime]] = {}
    self._cache_ttl = cache_ttl_seconds
    # RLock: public methods may re-enter each other while holding it.
    self._lock = threading.RLock()
    self._auto_refresh = auto_refresh
    # Register providers; fall back to StaticProvider when none given.
    active_providers = providers or [StaticProvider()]
    for provider in active_providers:
        self.register_provider(provider)
    # Template resolver handles {{cred.key}} substitution.
    self._resolver = TemplateResolver(self)
# --- Provider Management ---
def register_provider(self, provider: CredentialProvider) -> None:
    """Register *provider* under its provider_id (replacing any previous one)."""
    self._providers[provider.provider_id] = provider
    logger.debug(f"Registered credential provider: {provider.provider_id}")
def get_provider(self, provider_id: str) -> CredentialProvider | None:
    """Look up a registered provider by id; None when unknown."""
    return self._providers.get(provider_id)
def get_provider_for_credential(
    self, credential: CredentialObject
) -> CredentialProvider | None:
    """
    Pick the provider that should manage *credential*.

    Preference order: the provider explicitly named by the credential's
    provider_id, then the first registered provider whose supported
    types include the credential's type. None when nothing matches.
    """
    if credential.provider_id:
        explicit = self._providers.get(credential.provider_id)
        if explicit:
            return explicit
    return next(
        (p for p in self._providers.values() if p.can_handle(credential)),
        None,
    )
# --- Usage Spec Management ---
def register_usage(self, spec: CredentialUsageSpec) -> None:
    """Record how a tool consumes a credential (keyed by credential_id)."""
    self._usage_specs[spec.credential_id] = spec
def get_usage_spec(self, credential_id: str) -> CredentialUsageSpec | None:
"""
Get the usage spec for a credential.
Args:
credential_id: The credential identifier
Returns:
The usage spec if registered, None otherwise
"""
return self._usage_specs.get(credential_id)
# --- Credential Access ---
def get_credential(
self,
credential_id: str,
refresh_if_needed: bool = True,
) -> CredentialObject | None:
"""
Get a credential by ID.
Args:
credential_id: The credential identifier
refresh_if_needed: If True, refresh expired credentials
Returns:
CredentialObject or None if not found
"""
with self._lock:
# Check cache
cached = self._get_from_cache(credential_id)
if cached is not None:
if refresh_if_needed and self._should_refresh(cached):
return self._refresh_credential(cached)
return cached
# Load from storage
credential = self._storage.load(credential_id)
if credential is None:
return None
# Refresh if needed
if refresh_if_needed and self._should_refresh(credential):
credential = self._refresh_credential(credential)
# Cache
self._add_to_cache(credential)
return credential
def get_key(self, credential_id: str, key_name: str) -> str | None:
"""
Convenience method to get a specific key value.
Args:
credential_id: The credential identifier
key_name: The key within the credential
Returns:
The key value or None if not found
"""
credential = self.get_credential(credential_id)
if credential is None:
return None
return credential.get_key(key_name)
def get(self, credential_id: str) -> str | None:
"""
Legacy compatibility: get the primary key value.
For single-key credentials, returns that key.
For multi-key, returns 'value', 'api_key', or 'access_token'.
Args:
credential_id: The credential identifier
Returns:
The primary key value or None
"""
credential = self.get_credential(credential_id)
if credential is None:
return None
return credential.get_default_key()
# --- Template Resolution ---
def resolve(self, template: str) -> str:
"""
Resolve credential templates in a string.
Args:
template: String containing {{cred.key}} patterns
Returns:
Template with all references resolved
Example:
>>> store.resolve("Bearer {{github.access_token}}")
"Bearer ghp_xxxxxxxxxxxx"
"""
return self._resolver.resolve(template)
def resolve_headers(self, headers: dict[str, str]) -> dict[str, str]:
"""
Resolve credential templates in headers dictionary.
Args:
headers: Dict of header name to template value
Returns:
Dict with all templates resolved
Example:
>>> store.resolve_headers({
... "Authorization": "Bearer {{github.access_token}}"
... })
{"Authorization": "Bearer ghp_xxx"}
"""
return self._resolver.resolve_headers(headers)
def resolve_params(self, params: dict[str, str]) -> dict[str, str]:
"""
Resolve credential templates in query parameters dictionary.
Args:
params: Dict of param name to template value
Returns:
Dict with all templates resolved
"""
return self._resolver.resolve_params(params)
def resolve_for_usage(self, credential_id: str) -> dict[str, Any]:
"""
Get resolved request kwargs for a registered usage spec.
Args:
credential_id: The credential identifier
Returns:
Dict with 'headers', 'params', etc. keys as appropriate
Raises:
ValueError: If no usage spec is registered for the credential
"""
spec = self._usage_specs.get(credential_id)
if spec is None:
raise ValueError(f"No usage spec registered for '{credential_id}'")
result: dict[str, Any] = {}
if spec.headers:
result["headers"] = self.resolve_headers(spec.headers)
if spec.query_params:
result["params"] = self.resolve_params(spec.query_params)
if spec.body_fields:
result["data"] = {key: self.resolve(value) for key, value in spec.body_fields.items()}
return result
# --- Credential Management ---
def save_credential(self, credential: CredentialObject) -> None:
"""
Save a credential to storage.
Args:
credential: The credential to save
"""
with self._lock:
self._storage.save(credential)
self._add_to_cache(credential)
logger.info(f"Saved credential '{credential.id}'")
def delete_credential(self, credential_id: str) -> bool:
"""
Delete a credential from storage.
Args:
credential_id: The credential identifier
Returns:
True if the credential existed and was deleted
"""
with self._lock:
self._remove_from_cache(credential_id)
result = self._storage.delete(credential_id)
if result:
logger.info(f"Deleted credential '{credential_id}'")
return result
def list_credentials(self) -> list[str]:
"""
List all available credential IDs.
Returns:
List of credential IDs
"""
return self._storage.list_all()
def is_available(self, credential_id: str) -> bool:
"""
Check if a credential is available.
Args:
credential_id: The credential identifier
Returns:
True if credential exists and is accessible
"""
return self.get_credential(credential_id, refresh_if_needed=False) is not None
# --- Validation ---
def validate_for_usage(self, credential_id: str) -> list[str]:
"""
Validate that a credential meets its usage spec requirements.
Args:
credential_id: The credential identifier
Returns:
List of missing keys or errors. Empty list if valid.
"""
spec = self._usage_specs.get(credential_id)
if spec is None:
return [] # No requirements registered
credential = self.get_credential(credential_id)
if credential is None:
return [f"Credential '{credential_id}' not found"]
errors = []
for key_name in spec.required_keys:
if not credential.has_key(key_name):
errors.append(f"Missing required key '{key_name}'")
return errors
def validate_all(self) -> dict[str, list[str]]:
"""
Validate all registered usage specs.
Returns:
Dict mapping credential_id to list of errors.
Only includes credentials with errors.
"""
errors = {}
for cred_id in self._usage_specs.keys():
cred_errors = self.validate_for_usage(cred_id)
if cred_errors:
errors[cred_id] = cred_errors
return errors
def validate_credential(self, credential_id: str) -> bool:
"""
Validate a credential using its provider.
Args:
credential_id: The credential identifier
Returns:
True if credential is valid
"""
credential = self.get_credential(credential_id, refresh_if_needed=False)
if credential is None:
return False
provider = self.get_provider_for_credential(credential)
if provider is None:
# No provider, assume valid if has keys
return bool(credential.keys)
return provider.validate(credential)
# --- Lifecycle Management ---
def _should_refresh(self, credential: CredentialObject) -> bool:
"""Check if credential should be refreshed."""
if not self._auto_refresh:
return False
if not credential.auto_refresh:
return False
provider = self.get_provider_for_credential(credential)
if provider is None:
return False
return provider.should_refresh(credential)
def _refresh_credential(self, credential: CredentialObject) -> CredentialObject:
"""Refresh a credential using its provider."""
provider = self.get_provider_for_credential(credential)
if provider is None:
logger.warning(f"No provider found for credential '{credential.id}'")
return credential
try:
refreshed = provider.refresh(credential)
refreshed.last_refreshed = datetime.now(UTC)
# Persist the refreshed credential
self._storage.save(refreshed)
self._add_to_cache(refreshed)
logger.info(f"Refreshed credential '{credential.id}'")
return refreshed
except CredentialRefreshError as e:
logger.error(f"Failed to refresh credential '{credential.id}': {e}")
return credential
def refresh_credential(self, credential_id: str) -> CredentialObject | None:
"""
Manually refresh a credential.
Args:
credential_id: The credential identifier
Returns:
The refreshed credential, or None if not found
Raises:
CredentialRefreshError: If refresh fails
"""
credential = self.get_credential(credential_id, refresh_if_needed=False)
if credential is None:
return None
return self._refresh_credential(credential)
# --- Caching ---
def _get_from_cache(self, credential_id: str) -> CredentialObject | None:
"""Get credential from cache if not expired."""
if credential_id not in self._cache:
return None
credential, cached_at = self._cache[credential_id]
age = (datetime.now(UTC) - cached_at).total_seconds()
if age > self._cache_ttl:
del self._cache[credential_id]
return None
return credential
def _add_to_cache(self, credential: CredentialObject) -> None:
"""Add credential to cache."""
self._cache[credential.id] = (credential, datetime.now(UTC))
def _remove_from_cache(self, credential_id: str) -> None:
"""Remove credential from cache."""
self._cache.pop(credential_id, None)
def clear_cache(self) -> None:
"""Clear the credential cache."""
with self._lock:
self._cache.clear()
# --- Factory Methods ---
@classmethod
def for_testing(
cls,
credentials: dict[str, dict[str, str]],
) -> CredentialStore:
"""
Create a credential store for testing with mock credentials.
Args:
credentials: Dict mapping credential_id to {key_name: value}
e.g., {"brave_search": {"api_key": "test-key"}}
Returns:
CredentialStore with in-memory credentials
Example:
store = CredentialStore.for_testing({
"brave_search": {"api_key": "test-brave-key"},
"github_oauth": {
"access_token": "test-token",
"refresh_token": "test-refresh"
}
})
"""
# Convert test data to CredentialObjects
cred_objects: dict[str, CredentialObject] = {}
for cred_id, keys in credentials.items():
cred_objects[cred_id] = CredentialObject(
id=cred_id,
keys={k: CredentialKey(name=k, value=SecretStr(v)) for k, v in keys.items()},
)
return cls(
storage=InMemoryStorage(cred_objects),
auto_refresh=False,
)
@classmethod
def with_encrypted_storage(
cls,
base_path: str,
providers: list[CredentialProvider] | None = None,
**kwargs: Any,
) -> CredentialStore:
"""
Create a credential store with encrypted file storage.
Args:
base_path: Directory for credential files
providers: List of credential providers
**kwargs: Additional arguments passed to CredentialStore
Returns:
CredentialStore with EncryptedFileStorage
"""
from .storage import EncryptedFileStorage
return cls(
storage=EncryptedFileStorage(base_path),
providers=providers,
**kwargs,
)
@classmethod
def with_env_storage(
cls,
env_mapping: dict[str, str] | None = None,
providers: list[CredentialProvider] | None = None,
**kwargs: Any,
) -> CredentialStore:
"""
Create a credential store with environment variable storage.
Args:
env_mapping: Map of credential_id -> env_var_name
providers: List of credential providers
**kwargs: Additional arguments passed to CredentialStore
Returns:
CredentialStore with EnvVarStorage
"""
return cls(
storage=EnvVarStorage(env_mapping),
providers=providers,
**kwargs,
)
    @classmethod
    def with_aden_sync(
        cls,
        base_url: str = "https://hive.adenhq.com",
        cache_ttl_seconds: int = 300,
        local_path: str | None = None,
        auto_sync: bool = True,
        **kwargs: Any,
    ) -> CredentialStore:
        """
        Create a credential store with Aden server sync.
        Automatically syncs OAuth2 tokens from the Aden authentication server.
        Falls back to local-only storage if ADEN_API_KEY is not set or Aden
        is unreachable.
        Args:
            base_url: Aden server URL (default: https://hive.adenhq.com)
            cache_ttl_seconds: How long to cache credentials locally (default: 5 min)
            local_path: Path for local credential storage (default: ~/.hive/credentials)
            auto_sync: Whether to sync all credentials on startup (default: True)
            **kwargs: Additional arguments passed to CredentialStore
        Returns:
            CredentialStore configured with Aden sync
        Example:
            # Simple usage - just set ADEN_API_KEY env var
            store = CredentialStore.with_aden_sync()
            # Get HubSpot token (auto-refreshed via Aden)
            token = store.get_key("hubspot", "access_token")
        """
        # Local imports keep optional dependencies out of module import time.
        import os
        from pathlib import Path
        from .storage import EncryptedFileStorage
        # Determine local storage path
        if local_path is None:
            local_path = str(Path.home() / ".hive" / "credentials")
        # Encrypted on-disk storage is always created; it is either the whole
        # store (local-only fallback) or the local layer under Aden caching.
        local_storage = EncryptedFileStorage(base_path=local_path)
        # Check if Aden is configured
        # NOTE(review): api_key is only used as a presence check here; the
        # Aden client presumably reads the key from the environment itself —
        # confirm against AdenClientConfig.
        api_key = os.environ.get("ADEN_API_KEY")
        if not api_key:
            logger.info("ADEN_API_KEY not set, using local-only credential storage")
            return cls(storage=local_storage, **kwargs)
        # Try to setup Aden sync
        try:
            from .aden import (
                AdenCachedStorage,
                AdenClientConfig,
                AdenCredentialClient,
                AdenSyncProvider,
            )
            # Create Aden client
            client = AdenCredentialClient(AdenClientConfig(base_url=base_url))
            # Create sync provider
            provider = AdenSyncProvider(client=client)
            # Use cached storage for offline resilience
            cached_storage = AdenCachedStorage(
                local_storage=local_storage,
                aden_provider=provider,
                cache_ttl_seconds=cache_ttl_seconds,
            )
            store = cls(
                storage=cached_storage,
                providers=[provider],
                auto_refresh=True,
                **kwargs,
            )
            # Initial sync
            if auto_sync:
                synced = provider.sync_all(store)
                logger.info(f"Synced {synced} credentials from Aden server")
            return store
        except ImportError:
            # Aden extras not installed: degrade gracefully to local storage.
            logger.warning("Aden components not available, using local storage")
            return cls(storage=local_storage, **kwargs)
        except Exception as e:
            # Any Aden setup/sync failure (network, auth, ...) also degrades
            # to local storage rather than failing store construction.
            logger.warning(f"Failed to setup Aden sync: {e}. Using local storage.")
            return cls(storage=local_storage, **kwargs)
+219
View File
@@ -0,0 +1,219 @@
"""
Template resolution system for credential injection.
This module handles {{cred.key}} patterns, enabling the bipartisan model
where tools specify how credentials are used in HTTP requests.
Template Syntax:
{{credential_id.key_name}} - Access specific key
{{credential_id}} - Access default key (value, api_key, or access_token)
Examples:
"Bearer {{github_oauth.access_token}}" -> "Bearer ghp_xxx"
"X-API-Key: {{brave_search.api_key}}" -> "X-API-Key: BSAKxxx"
"{{brave_search}}" -> "BSAKxxx" (uses default key)
"""
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from .models import CredentialKeyNotFoundError, CredentialNotFoundError
if TYPE_CHECKING:
from .store import CredentialStore
class TemplateResolver:
    """
    Resolves credential templates like {{cred.key}} into actual values.

    Usage:
        resolver = TemplateResolver(credential_store)

        # Resolve a single template string
        auth_header = resolver.resolve("Bearer {{github_oauth.access_token}}")

        # Resolve all headers at once
        headers = resolver.resolve_headers({
            "Authorization": "Bearer {{github_oauth.access_token}}",
            "X-API-Key": "{{brave_search.api_key}}"
        })
    """

    # Matches {{credential_id}} or {{credential_id.key_name}}
    TEMPLATE_PATTERN = re.compile(r"\{\{([a-zA-Z0-9_-]+)(?:\.([a-zA-Z0-9_-]+))?\}\}")

    def __init__(self, credential_store: CredentialStore):
        """
        Initialize the template resolver.

        Args:
            credential_store: The credential store to resolve references against
        """
        self._store = credential_store

    def resolve(self, template: str, fail_on_missing: bool = True) -> str:
        """
        Resolve all credential references in a template string.

        Args:
            template: String containing {{cred.key}} patterns
            fail_on_missing: If True, raise error on missing credentials

        Returns:
            Template with all references replaced with actual values

        Raises:
            CredentialNotFoundError: If credential doesn't exist and fail_on_missing=True
            CredentialKeyNotFoundError: If key doesn't exist in credential

        Example:
            >>> resolver.resolve("Bearer {{github_oauth.access_token}}")
            "Bearer ghp_xxxxxxxxxxxx"
        """

        def _substitute(match: re.Match) -> str:
            cred_id, key_name = match.group(1), match.group(2)
            credential = self._store.get_credential(cred_id, refresh_if_needed=True)
            if credential is None:
                if fail_on_missing:
                    raise CredentialNotFoundError(f"Credential '{cred_id}' not found")
                # Leave the unresolved placeholder intact.
                return match.group(0)
            if key_name:
                value = credential.get_key(key_name)
                if value is None:
                    raise CredentialKeyNotFoundError(
                        f"Key '{key_name}' not found in credential '{cred_id}'"
                    )
            else:
                # No key given: fall back to the credential's default key.
                value = credential.get_default_key()
                if value is None:
                    raise CredentialKeyNotFoundError(f"Credential '{cred_id}' has no keys")
            # Track that the credential was actually consumed.
            credential.record_usage()
            return value

        return self.TEMPLATE_PATTERN.sub(_substitute, template)

    def resolve_headers(
        self,
        header_templates: dict[str, str],
        fail_on_missing: bool = True,
    ) -> dict[str, str]:
        """
        Resolve templates in a headers dictionary.

        Args:
            header_templates: Dict of header name to template value
            fail_on_missing: If True, raise error on missing credentials

        Returns:
            Dict with all templates resolved to actual values

        Example:
            >>> resolver.resolve_headers({
            ...     "Authorization": "Bearer {{github_oauth.access_token}}",
            ... })
            {"Authorization": "Bearer ghp_xxx"}
        """
        resolved: dict[str, str] = {}
        for header_name, header_value in header_templates.items():
            resolved[header_name] = self.resolve(header_value, fail_on_missing)
        return resolved

    def resolve_params(
        self,
        param_templates: dict[str, str],
        fail_on_missing: bool = True,
    ) -> dict[str, str]:
        """
        Resolve templates in a query parameters dictionary.

        Args:
            param_templates: Dict of param name to template value
            fail_on_missing: If True, raise error on missing credentials

        Returns:
            Dict with all templates resolved to actual values
        """
        resolved: dict[str, str] = {}
        for param_name, param_value in param_templates.items():
            resolved[param_name] = self.resolve(param_value, fail_on_missing)
        return resolved

    def has_templates(self, text: str) -> bool:
        """
        Check if text contains any credential templates.

        Args:
            text: String to check

        Returns:
            True if text contains {{...}} patterns
        """
        return self.TEMPLATE_PATTERN.search(text) is not None

    def extract_references(self, text: str) -> list[tuple[str, str | None]]:
        """
        Extract all credential references from text.

        Args:
            text: String to extract references from

        Returns:
            List of (credential_id, key_name) tuples.
            key_name is None if only credential_id was specified.

        Example:
            >>> resolver.extract_references("{{github.token}} and {{brave_search.api_key}}")
            [("github", "token"), ("brave_search", "api_key")]
        """
        found = self.TEMPLATE_PATTERN.finditer(text)
        return [(m.group(1), m.group(2)) for m in found]

    def validate_references(self, text: str) -> list[str]:
        """
        Validate all credential references in text without resolving.

        Args:
            text: String containing template references

        Returns:
            List of error messages for invalid references.
            Empty list if all references are valid.
        """
        problems: list[str] = []
        for cred_id, key_name in self.extract_references(text):
            credential = self._store.get_credential(cred_id, refresh_if_needed=False)
            if credential is None:
                problems.append(f"Credential '{cred_id}' not found")
            elif key_name:
                if not credential.has_key(key_name):
                    problems.append(f"Key '{key_name}' not found in credential '{cred_id}'")
            elif not credential.keys:
                problems.append(f"Credential '{cred_id}' has no keys")
        return problems

    def get_required_credentials(self, text: str) -> list[str]:
        """
        Get list of credential IDs required by a template string.

        Args:
            text: String containing template references

        Returns:
            List of unique credential IDs referenced in the text
        """
        # Preserve first-seen order while de-duplicating.
        seen: dict[str, None] = {}
        for cred_id, _ in self.extract_references(text):
            seen.setdefault(cred_id, None)
        return list(seen)
@@ -0,0 +1 @@
"""Tests for the credential store module."""
@@ -0,0 +1,707 @@
"""
Comprehensive tests for the credential store module.
Tests cover:
- Core models (CredentialObject, CredentialKey, CredentialUsageSpec)
- Template resolution
- Storage backends (InMemoryStorage, EnvVarStorage, EncryptedFileStorage)
- Providers (StaticProvider, BearerTokenProvider)
- Main CredentialStore
- OAuth2 module
"""
import os
import tempfile
from datetime import UTC, datetime, timedelta
from pathlib import Path
from unittest.mock import patch
import pytest
from core.framework.credentials import (
CompositeStorage,
CredentialKey,
CredentialKeyNotFoundError,
CredentialNotFoundError,
CredentialObject,
CredentialStore,
CredentialType,
CredentialUsageSpec,
EncryptedFileStorage,
EnvVarStorage,
InMemoryStorage,
StaticProvider,
TemplateResolver,
)
from pydantic import SecretStr
class TestCredentialKey:
    """Tests for CredentialKey model."""
    def test_create_basic_key(self):
        """Test creating a basic credential key."""
        key = CredentialKey(name="api_key", value=SecretStr("test-value"))
        assert key.name == "api_key"
        assert key.get_secret_value() == "test-value"
        # A key without expires_at never reports as expired.
        assert key.expires_at is None
        assert not key.is_expired
    def test_key_with_expiration(self):
        """Test key with expiration time."""
        # A key expiring an hour from now must not report as expired yet.
        future = datetime.now(UTC) + timedelta(hours=1)
        key = CredentialKey(name="token", value=SecretStr("xxx"), expires_at=future)
        assert not key.is_expired
    def test_expired_key(self):
        """Test that expired key is detected."""
        past = datetime.now(UTC) - timedelta(hours=1)
        key = CredentialKey(name="token", value=SecretStr("xxx"), expires_at=past)
        assert key.is_expired
    def test_key_with_metadata(self):
        """Test key with metadata."""
        # Arbitrary metadata should round-trip through the model unchanged.
        key = CredentialKey(
            name="token",
            value=SecretStr("xxx"),
            metadata={"client_id": "abc", "scope": "read"},
        )
        assert key.metadata["client_id"] == "abc"
class TestCredentialObject:
    """Tests for CredentialObject model."""
    def test_create_simple_credential(self):
        """Test creating a simple API key credential."""
        cred = CredentialObject(
            id="brave_search",
            credential_type=CredentialType.API_KEY,
            keys={"api_key": CredentialKey(name="api_key", value=SecretStr("test-key"))},
        )
        assert cred.id == "brave_search"
        assert cred.credential_type == CredentialType.API_KEY
        assert cred.get_key("api_key") == "test-key"
    def test_create_multi_key_credential(self):
        """Test creating a credential with multiple keys."""
        cred = CredentialObject(
            id="github_oauth",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(name="access_token", value=SecretStr("ghp_xxx")),
                "refresh_token": CredentialKey(name="refresh_token", value=SecretStr("ghr_xxx")),
            },
        )
        assert cred.get_key("access_token") == "ghp_xxx"
        assert cred.get_key("refresh_token") == "ghr_xxx"
        # Unknown keys resolve to None rather than raising.
        assert cred.get_key("nonexistent") is None
    def test_set_key(self):
        """Test setting a key on a credential."""
        cred = CredentialObject(id="test", keys={})
        cred.set_key("new_key", "new_value")
        assert cred.get_key("new_key") == "new_value"
    def test_set_key_with_expiration(self):
        """Test setting a key with expiration."""
        cred = CredentialObject(id="test", keys={})
        expires = datetime.now(UTC) + timedelta(hours=1)
        cred.set_key("token", "xxx", expires_at=expires)
        # set_key wraps the value in a CredentialKey carrying expires_at.
        assert cred.keys["token"].expires_at == expires
    def test_needs_refresh(self):
        """Test needs_refresh property."""
        # An already-expired key should flag the whole credential for refresh.
        past = datetime.now(UTC) - timedelta(hours=1)
        cred = CredentialObject(
            id="test",
            keys={"token": CredentialKey(name="token", value=SecretStr("xxx"), expires_at=past)},
        )
        assert cred.needs_refresh
    def test_get_default_key(self):
        """Test get_default_key returns appropriate default."""
        # With api_key
        cred = CredentialObject(
            id="test",
            keys={"api_key": CredentialKey(name="api_key", value=SecretStr("key-value"))},
        )
        assert cred.get_default_key() == "key-value"
        # With access_token
        cred2 = CredentialObject(
            id="test",
            keys={
                "access_token": CredentialKey(name="access_token", value=SecretStr("token-value"))
            },
        )
        assert cred2.get_default_key() == "token-value"
    def test_record_usage(self):
        """Test recording credential usage."""
        cred = CredentialObject(id="test", keys={})
        assert cred.use_count == 0
        assert cred.last_used is None
        cred.record_usage()
        # Usage bumps the counter and stamps last_used.
        assert cred.use_count == 1
        assert cred.last_used is not None
class TestCredentialUsageSpec:
    """Tests for CredentialUsageSpec model."""
    def test_create_usage_spec(self):
        """Test creating a usage spec."""
        # Header values hold unresolved {{key}} templates at this stage.
        spec = CredentialUsageSpec(
            credential_id="brave_search",
            required_keys=["api_key"],
            headers={"X-Subscription-Token": "{{api_key}}"},
        )
        assert spec.credential_id == "brave_search"
        assert "api_key" in spec.required_keys
        assert "{{api_key}}" in spec.headers.values()
class TestInMemoryStorage:
    """Tests for InMemoryStorage."""
    def test_save_and_load(self):
        """Test saving and loading a credential."""
        storage = InMemoryStorage()
        cred = CredentialObject(
            id="test",
            keys={"key": CredentialKey(name="key", value=SecretStr("value"))},
        )
        storage.save(cred)
        loaded = storage.load("test")
        assert loaded is not None
        assert loaded.id == "test"
        assert loaded.get_key("key") == "value"
    def test_load_nonexistent(self):
        """Test loading a nonexistent credential."""
        storage = InMemoryStorage()
        assert storage.load("nonexistent") is None
    def test_delete(self):
        """Test deleting a credential."""
        storage = InMemoryStorage()
        cred = CredentialObject(id="test", keys={})
        storage.save(cred)
        assert storage.delete("test")
        assert storage.load("test") is None
        # Deleting a second time reports False (already gone).
        assert not storage.delete("test")
    def test_list_all(self):
        """Test listing all credentials."""
        storage = InMemoryStorage()
        storage.save(CredentialObject(id="a", keys={}))
        storage.save(CredentialObject(id="b", keys={}))
        ids = storage.list_all()
        assert "a" in ids
        assert "b" in ids
    def test_exists(self):
        """Test checking if credential exists."""
        storage = InMemoryStorage()
        storage.save(CredentialObject(id="test", keys={}))
        assert storage.exists("test")
        assert not storage.exists("nonexistent")
    def test_clear(self):
        """Test clearing all credentials."""
        storage = InMemoryStorage()
        storage.save(CredentialObject(id="test", keys={}))
        storage.clear()
        assert storage.list_all() == []
class TestEnvVarStorage:
    """Tests for EnvVarStorage (read-only storage backed by env vars)."""
    def test_load_from_env(self):
        """Test loading credential from environment variable."""
        with patch.dict(os.environ, {"TEST_API_KEY": "test-value"}):
            storage = EnvVarStorage(env_mapping={"test": "TEST_API_KEY"})
            cred = storage.load("test")
            assert cred is not None
            # The env value surfaces under the "api_key" key.
            assert cred.get_key("api_key") == "test-value"
    def test_load_nonexistent(self):
        """Test loading when env var is not set."""
        storage = EnvVarStorage(env_mapping={"test": "NONEXISTENT_VAR"})
        assert storage.load("test") is None
    def test_default_env_var_pattern(self):
        """Test default env var naming pattern."""
        # Without an explicit mapping, "<id upper>_API_KEY" is consulted.
        with patch.dict(os.environ, {"MY_SERVICE_API_KEY": "value"}):
            storage = EnvVarStorage()
            cred = storage.load("my_service")
            assert cred is not None
            assert cred.get_key("api_key") == "value"
    def test_save_raises(self):
        """Test that save raises NotImplementedError."""
        storage = EnvVarStorage()
        with pytest.raises(NotImplementedError):
            storage.save(CredentialObject(id="test", keys={}))
    def test_delete_raises(self):
        """Test that delete raises NotImplementedError."""
        storage = EnvVarStorage()
        with pytest.raises(NotImplementedError):
            storage.delete("test")
class TestEncryptedFileStorage:
    """Tests for EncryptedFileStorage."""
    @pytest.fixture
    def temp_dir(self):
        """Create a temporary directory for tests."""
        with tempfile.TemporaryDirectory() as tmpdir:
            yield Path(tmpdir)
    @pytest.fixture
    def storage(self, temp_dir):
        """Create EncryptedFileStorage for tests."""
        return EncryptedFileStorage(temp_dir)
    def test_save_and_load(self, storage):
        """Test saving and loading encrypted credential."""
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.API_KEY,
            keys={"api_key": CredentialKey(name="api_key", value=SecretStr("secret-value"))},
        )
        storage.save(cred)
        loaded = storage.load("test")
        assert loaded is not None
        assert loaded.id == "test"
        assert loaded.get_key("api_key") == "secret-value"
    def test_encryption_key_from_env(self, temp_dir):
        """Test using encryption key from environment variable."""
        from cryptography.fernet import Fernet
        key = Fernet.generate_key().decode()
        with patch.dict(os.environ, {"HIVE_CREDENTIAL_KEY": key}):
            storage = EncryptedFileStorage(temp_dir)
            cred = CredentialObject(
                id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}
            )
            storage.save(cred)
            # Create new storage instance with same key
            # (a different instance must decrypt what the first one wrote).
            storage2 = EncryptedFileStorage(temp_dir)
            loaded = storage2.load("test")
            assert loaded is not None
            assert loaded.get_key("k") == "v"
    def test_list_all(self, storage):
        """Test listing all credentials."""
        storage.save(CredentialObject(id="cred1", keys={}))
        storage.save(CredentialObject(id="cred2", keys={}))
        ids = storage.list_all()
        assert "cred1" in ids
        assert "cred2" in ids
    def test_delete(self, storage):
        """Test deleting a credential."""
        storage.save(CredentialObject(id="test", keys={}))
        assert storage.delete("test")
        assert storage.load("test") is None
class TestCompositeStorage:
    """Tests for CompositeStorage (primary storage with read fallbacks)."""
    def test_read_from_primary(self):
        """Test reading from primary storage."""
        primary = InMemoryStorage()
        primary.save(
            CredentialObject(
                id="test", keys={"k": CredentialKey(name="k", value=SecretStr("primary"))}
            )
        )
        fallback = InMemoryStorage()
        fallback.save(
            CredentialObject(
                id="test", keys={"k": CredentialKey(name="k", value=SecretStr("fallback"))}
            )
        )
        storage = CompositeStorage(primary, [fallback])
        cred = storage.load("test")
        # Should get from primary
        assert cred.get_key("k") == "primary"
    def test_fallback_when_not_in_primary(self):
        """Test fallback when credential not in primary."""
        primary = InMemoryStorage()
        fallback = InMemoryStorage()
        fallback.save(
            CredentialObject(
                id="test", keys={"k": CredentialKey(name="k", value=SecretStr("fallback"))}
            )
        )
        storage = CompositeStorage(primary, [fallback])
        cred = storage.load("test")
        assert cred.get_key("k") == "fallback"
    def test_write_to_primary_only(self):
        """Test that writes go to primary only."""
        primary = InMemoryStorage()
        fallback = InMemoryStorage()
        storage = CompositeStorage(primary, [fallback])
        storage.save(CredentialObject(id="test", keys={}))
        # Fallbacks are read-only from the composite's perspective.
        assert primary.exists("test")
        assert not fallback.exists("test")
class TestStaticProvider:
    """Tests for StaticProvider (non-refreshing provider for fixed keys)."""
    def test_provider_id(self):
        """Test provider ID."""
        provider = StaticProvider()
        assert provider.provider_id == "static"
    def test_supported_types(self):
        """Test supported credential types."""
        provider = StaticProvider()
        assert CredentialType.API_KEY in provider.supported_types
        assert CredentialType.CUSTOM in provider.supported_types
    def test_refresh_returns_unchanged(self):
        """Test that refresh returns credential unchanged."""
        provider = StaticProvider()
        cred = CredentialObject(
            id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}
        )
        refreshed = provider.refresh(cred)
        assert refreshed.get_key("k") == "v"
    def test_validate_with_keys(self):
        """Test validation with keys present."""
        provider = StaticProvider()
        cred = CredentialObject(
            id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}
        )
        assert provider.validate(cred)
    def test_validate_without_keys(self):
        """Test validation without keys."""
        # An empty key set is treated as invalid.
        provider = StaticProvider()
        cred = CredentialObject(id="test", keys={})
        assert not provider.validate(cred)
    def test_should_refresh(self):
        """Test that static provider never needs refresh."""
        provider = StaticProvider()
        cred = CredentialObject(id="test", keys={})
        assert not provider.should_refresh(cred)
class TestTemplateResolver:
    """Tests for TemplateResolver."""
    @pytest.fixture
    def store(self):
        """Create a test store with credentials."""
        return CredentialStore.for_testing(
            {
                "brave_search": {"api_key": "test-brave-key"},
                "github_oauth": {"access_token": "ghp_xxx", "refresh_token": "ghr_xxx"},
            }
        )
    @pytest.fixture
    def resolver(self, store):
        """Create a resolver with the test store."""
        return TemplateResolver(store)
    def test_resolve_simple(self, resolver):
        """Test resolving a simple template."""
        result = resolver.resolve("Bearer {{github_oauth.access_token}}")
        assert result == "Bearer ghp_xxx"
    def test_resolve_multiple(self, resolver):
        """Test resolving multiple templates."""
        result = resolver.resolve("{{github_oauth.access_token}} and {{brave_search.api_key}}")
        assert "ghp_xxx" in result
        assert "test-brave-key" in result
    def test_resolve_default_key(self, resolver):
        """Test resolving credential without key specified."""
        # {{brave_search}} alone should fall back to the default key.
        result = resolver.resolve("Key: {{brave_search}}")
        assert "test-brave-key" in result
    def test_resolve_headers(self, resolver):
        """Test resolving headers dict."""
        headers = resolver.resolve_headers(
            {
                "Authorization": "Bearer {{github_oauth.access_token}}",
                "X-API-Key": "{{brave_search.api_key}}",
            }
        )
        assert headers["Authorization"] == "Bearer ghp_xxx"
        assert headers["X-API-Key"] == "test-brave-key"
    def test_resolve_missing_credential(self, resolver):
        """Test error on missing credential."""
        with pytest.raises(CredentialNotFoundError):
            resolver.resolve("{{nonexistent.key}}")
    def test_resolve_missing_key(self, resolver):
        """Test error on missing key."""
        with pytest.raises(CredentialKeyNotFoundError):
            resolver.resolve("{{github_oauth.nonexistent}}")
    def test_has_templates(self, resolver):
        """Test detecting templates in text."""
        assert resolver.has_templates("{{cred.key}}")
        assert resolver.has_templates("Bearer {{token}}")
        assert not resolver.has_templates("no templates here")
    def test_extract_references(self, resolver):
        """Test extracting credential references."""
        refs = resolver.extract_references("{{github.token}} and {{brave.key}}")
        assert ("github", "token") in refs
        assert ("brave", "key") in refs
class TestCredentialStore:
    """Tests for CredentialStore."""

    def test_for_testing_factory(self):
        """Test creating store for testing."""
        store = CredentialStore.for_testing({"test": {"api_key": "value"}})
        assert store.get("test") == "value"
        assert store.get_key("test", "api_key") == "value"

    def test_get_credential(self):
        """Test getting a credential."""
        cred = CredentialStore.for_testing({"test": {"key": "value"}}).get_credential("test")
        assert cred is not None
        assert cred.get_key("key") == "value"

    def test_get_nonexistent(self):
        """Test getting nonexistent credential."""
        empty_store = CredentialStore.for_testing({})
        assert empty_store.get_credential("nonexistent") is None
        assert empty_store.get("nonexistent") is None

    def test_save_and_load(self):
        """Test saving and loading a credential."""
        store = CredentialStore.for_testing({})
        new_cred = CredentialObject(
            id="new", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}
        )
        store.save_credential(new_cred)
        round_tripped = store.get_credential("new")
        assert round_tripped is not None
        assert round_tripped.get_key("k") == "v"

    def test_delete_credential(self):
        """Test deleting a credential."""
        store = CredentialStore.for_testing({"test": {"k": "v"}})
        assert store.delete_credential("test")
        assert store.get_credential("test") is None

    def test_list_credentials(self):
        """Test listing all credentials."""
        store = CredentialStore.for_testing({"a": {"k": "v"}, "b": {"k": "v"}})
        listed = store.list_credentials()
        assert "a" in listed
        assert "b" in listed

    def test_is_available(self):
        """Test checking credential availability."""
        store = CredentialStore.for_testing({"test": {"k": "v"}})
        assert store.is_available("test")
        assert not store.is_available("nonexistent")

    def test_resolve_templates(self):
        """Test template resolution through store."""
        store = CredentialStore.for_testing({"test": {"api_key": "value"}})
        assert store.resolve("Key: {{test.api_key}}") == "Key: value"

    def test_resolve_headers(self):
        """Test resolving headers through store."""
        store = CredentialStore.for_testing({"test": {"token": "xxx"}})
        resolved = store.resolve_headers({"Authorization": "Bearer {{test.token}}"})
        assert resolved["Authorization"] == "Bearer xxx"

    def test_register_provider(self):
        """Test registering a provider."""
        store = CredentialStore.for_testing({})
        provider = StaticProvider()
        store.register_provider(provider)
        assert store.get_provider("static") is provider

    def test_register_usage_spec(self):
        """Test registering a usage spec."""
        store = CredentialStore.for_testing({})
        spec = CredentialUsageSpec(
            credential_id="test",
            required_keys=["api_key"],
            headers={"X-Key": "{{api_key}}"},
        )
        store.register_usage(spec)
        assert store.get_usage_spec("test") is spec

    def test_validate_for_usage(self):
        """Test validating credential for usage spec."""
        store = CredentialStore.for_testing({"test": {"api_key": "value"}})
        store.register_usage(
            CredentialUsageSpec(credential_id="test", required_keys=["api_key"])
        )
        assert store.validate_for_usage("test") == []

    def test_validate_for_usage_missing_key(self):
        """Test validation with missing required key."""
        store = CredentialStore.for_testing({"test": {"other_key": "value"}})
        store.register_usage(
            CredentialUsageSpec(credential_id="test", required_keys=["api_key"])
        )
        problems = store.validate_for_usage("test")
        assert "api_key" in problems[0]

    def test_caching(self):
        """Test that credentials are cached."""
        backing = InMemoryStorage()
        store = CredentialStore(storage=backing, cache_ttl_seconds=60)
        backing.save(
            CredentialObject(
                id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}
            )
        )
        store.get_credential("test")  # first load populates the cache
        backing.delete("test")  # remove from the backing storage only
        # The cached copy should still be served.
        assert store.get_credential("test") is not None

    def test_clear_cache(self):
        """Test clearing the cache."""
        backing = InMemoryStorage()
        store = CredentialStore(storage=backing)
        backing.save(CredentialObject(id="test", keys={}))
        store.get_credential("test")  # cache it
        backing.delete("test")
        store.clear_cache()
        # With the cache cleared and the storage empty, the lookup must miss.
        assert store.get_credential("test") is None
class TestOAuth2Module:
    """Tests for OAuth2 module."""

    def test_oauth2_token_from_response(self):
        """Test creating OAuth2Token from token response."""
        from core.framework.credentials.oauth2 import OAuth2Token

        payload = {
            "access_token": "xxx",
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": "yyy",
            "scope": "read write",
        }
        token = OAuth2Token.from_token_response(payload)
        assert token.access_token == "xxx"
        assert token.token_type == "Bearer"
        assert token.refresh_token == "yyy"
        assert token.scope == "read write"
        assert token.expires_at is not None

    def test_token_is_expired(self):
        """Test token expiration check."""
        from core.framework.credentials.oauth2 import OAuth2Token

        one_hour = timedelta(hours=1)
        fresh = OAuth2Token(access_token="xxx", expires_at=datetime.now(UTC) + one_hour)
        assert not fresh.is_expired
        stale = OAuth2Token(access_token="xxx", expires_at=datetime.now(UTC) - one_hour)
        assert stale.is_expired

    def test_token_can_refresh(self):
        """Test token refresh capability check."""
        from core.framework.credentials.oauth2 import OAuth2Token

        assert OAuth2Token(access_token="xxx", refresh_token="yyy").can_refresh
        assert not OAuth2Token(access_token="xxx").can_refresh

    def test_oauth2_config_validation(self):
        """Test OAuth2Config validation."""
        from core.framework.credentials.oauth2 import OAuth2Config, TokenPlacement

        # A complete config is accepted.
        config = OAuth2Config(
            token_url="https://example.com/token", client_id="id", client_secret="secret"
        )
        assert config.token_url == "https://example.com/token"
        # An empty token_url is rejected.
        with pytest.raises(ValueError):
            OAuth2Config(token_url="")
        # HEADER_CUSTOM placement requires custom_header_name.
        with pytest.raises(ValueError):
            OAuth2Config(
                token_url="https://example.com/token",
                token_placement=TokenPlacement.HEADER_CUSTOM,
            )
if __name__ == "__main__":
    # pytest.main() RETURNS the exit status rather than exiting, so the
    # original call silently discarded failures (process always exited 0).
    # Wrap it in SystemExit so direct runs report the real result to CI.
    raise SystemExit(pytest.main([__file__, "-v"]))
@@ -0,0 +1,55 @@
"""
HashiCorp Vault integration for the credential store.
This module provides enterprise-grade secret management through
HashiCorp Vault integration.
Quick Start:
from core.framework.credentials import CredentialStore
from core.framework.credentials.vault import HashiCorpVaultStorage
# Configure Vault storage
storage = HashiCorpVaultStorage(
url="https://vault.example.com:8200",
# token read from VAULT_TOKEN env var
mount_point="secret",
path_prefix="hive/agents/prod"
)
# Create credential store with Vault backend
store = CredentialStore(storage=storage)
# Use normally - credentials are stored in Vault
credential = store.get_credential("my_api")
Requirements:
pip install hvac
Authentication:
Set the VAULT_TOKEN environment variable or pass the token directly:
export VAULT_TOKEN="hvs.xxxxxxxxxxxxx"
For production, consider using Vault auth methods:
- Kubernetes auth
- AppRole auth
- AWS IAM auth
Vault Configuration:
Ensure KV v2 secrets engine is enabled:
vault secrets enable -path=secret kv-v2
Grant appropriate policies:
path "secret/data/hive/credentials/*" {
capabilities = ["create", "read", "update", "delete", "list"]
}
path "secret/metadata/hive/credentials/*" {
capabilities = ["list", "delete"]
}
"""
from .hashicorp import HashiCorpVaultStorage
__all__ = ["HashiCorpVaultStorage"]
@@ -0,0 +1,394 @@
"""
HashiCorp Vault storage adapter.
Provides integration with HashiCorp Vault for enterprise secret management.
Requires the 'hvac' package: pip install hvac
"""
from __future__ import annotations
import logging
import os
from datetime import datetime
from typing import Any
from pydantic import SecretStr
from ..models import CredentialKey, CredentialObject, CredentialType
from ..storage import CredentialStorage
logger = logging.getLogger(__name__)
class HashiCorpVaultStorage(CredentialStorage):
    """
    HashiCorp Vault storage adapter.

    Features:
    - KV v2 secrets engine support
    - Namespace support (Enterprise)
    - Automatic secret versioning
    - Audit logging via Vault

    The adapter stores credentials in Vault's KV v2 secrets engine with
    the following structure:

        {mount_point}/data/{path_prefix}/{credential_id}
            data:
                _type: "oauth2"
                access_token: "xxx"
                refresh_token: "yyy"
                _expires_access_token: "2024-01-26T12:00:00"
                _provider_id: "oauth2"

    Example:
        storage = HashiCorpVaultStorage(
            url="https://vault.example.com:8200",
            token="hvs.xxx",  # Or use VAULT_TOKEN env var
            mount_point="secret",
            path_prefix="hive/credentials"
        )
        store = CredentialStore(storage=storage)

        # Credentials are now stored in Vault
        store.save_credential(credential)
        credential = store.get_credential("my_api")

    Authentication:
        The adapter uses token-based authentication. The token can be provided:
        1. Directly via the 'token' parameter
        2. Via the VAULT_TOKEN environment variable

        For production, consider using:
        - Kubernetes auth method
        - AppRole auth method
        - AWS IAM auth method

    Requirements:
        pip install hvac
    """

    def __init__(
        self,
        url: str,
        token: str | None = None,
        mount_point: str = "secret",
        path_prefix: str = "hive/credentials",
        namespace: str | None = None,
        verify_ssl: bool = True,
    ):
        """
        Initialize Vault storage.

        Args:
            url: Vault server URL (e.g., https://vault.example.com:8200)
            token: Vault token. If None, reads from VAULT_TOKEN env var
            mount_point: KV secrets engine mount point (default: "secret")
            path_prefix: Path prefix for all credentials
            namespace: Vault namespace (Enterprise feature)
            verify_ssl: Whether to verify SSL certificates

        Raises:
            ImportError: If hvac is not installed
            ValueError: If authentication fails
        """
        # hvac is an optional dependency; import lazily so the rest of the
        # credentials package works without it installed.
        try:
            import hvac
        except ImportError as e:
            raise ImportError(
                "HashiCorp Vault support requires 'hvac'. Install with: pip install hvac"
            ) from e
        self._url = url
        # Explicit token wins; otherwise fall back to the VAULT_TOKEN env var.
        self._token = token or os.environ.get("VAULT_TOKEN")
        self._mount = mount_point
        self._prefix = path_prefix
        self._namespace = namespace
        if not self._token:
            raise ValueError(
                "Vault token required. Set VAULT_TOKEN env var or pass token parameter."
            )
        self._client = hvac.Client(
            url=url,
            token=self._token,
            namespace=namespace,
            verify=verify_ssl,
        )
        # Fail fast on a bad token/URL instead of erroring on first use.
        if not self._client.is_authenticated():
            raise ValueError("Vault authentication failed. Check token and server URL.")
        logger.info(f"Connected to HashiCorp Vault at {url}")

    def _path(self, credential_id: str) -> str:
        """Build Vault path for credential."""
        # Sanitize credential_id: slashes would create nested Vault paths,
        # so map them to underscores to keep one secret per credential.
        safe_id = credential_id.replace("/", "_").replace("\\", "_")
        return f"{self._prefix}/{safe_id}"

    def save(self, credential: CredentialObject) -> None:
        """Save credential to Vault KV v2."""
        path = self._path(credential.id)
        data = self._serialize_for_vault(credential)
        try:
            # create_or_update_secret writes a new version under KV v2.
            self._client.secrets.kv.v2.create_or_update_secret(
                path=path,
                secret=data,
                mount_point=self._mount,
            )
            logger.debug(f"Saved credential '{credential.id}' to Vault at {path}")
        except Exception as e:
            logger.error(f"Failed to save credential '{credential.id}' to Vault: {e}")
            raise

    def load(self, credential_id: str) -> CredentialObject | None:
        """Load credential from Vault."""
        path = self._path(credential_id)
        try:
            response = self._client.secrets.kv.v2.read_secret_version(
                path=path,
                mount_point=self._mount,
            )
            # KV v2 responses nest the secret under data.data.
            data = response["data"]["data"]
            return self._deserialize_from_vault(credential_id, data)
        except Exception as e:
            # Check if it's a "not found" error
            # NOTE(review): matching on the message string is brittle; hvac
            # raises InvalidPath for missing secrets — consider catching that
            # exception type instead. TODO confirm against hvac version used.
            error_str = str(e).lower()
            if "not found" in error_str or "404" in error_str:
                logger.debug(f"Credential '{credential_id}' not found in Vault")
                return None
            logger.error(f"Failed to load credential '{credential_id}' from Vault: {e}")
            raise

    def delete(self, credential_id: str) -> bool:
        """Delete credential from Vault (all versions)."""
        path = self._path(credential_id)
        try:
            # Hard delete: removes metadata and every stored version.
            self._client.secrets.kv.v2.delete_metadata_and_all_versions(
                path=path,
                mount_point=self._mount,
            )
            logger.debug(f"Deleted credential '{credential_id}' from Vault")
            return True
        except Exception as e:
            # Missing secret is reported as False, not an error.
            error_str = str(e).lower()
            if "not found" in error_str or "404" in error_str:
                return False
            logger.error(f"Failed to delete credential '{credential_id}' from Vault: {e}")
            raise

    def list_all(self) -> list[str]:
        """List all credentials under the prefix."""
        try:
            response = self._client.secrets.kv.v2.list_secrets(
                path=self._prefix,
                mount_point=self._mount,
            )
            keys = response.get("data", {}).get("keys", [])
            # Remove trailing slashes from folder names
            return [k.rstrip("/") for k in keys]
        except Exception as e:
            # An empty/missing prefix is treated as "no credentials".
            error_str = str(e).lower()
            if "not found" in error_str or "404" in error_str:
                return []
            logger.error(f"Failed to list credentials from Vault: {e}")
            raise

    def exists(self, credential_id: str) -> bool:
        """Check if credential exists in Vault."""
        # NOTE(review): any exception (including auth/network failures) is
        # reported as "does not exist" here — confirm this is intended.
        try:
            path = self._path(credential_id)
            self._client.secrets.kv.v2.read_secret_version(
                path=path,
                mount_point=self._mount,
            )
            return True
        except Exception:
            return False

    def _serialize_for_vault(self, credential: CredentialObject) -> dict[str, Any]:
        """Convert credential to Vault secret format."""
        # Metadata fields carry a leading underscore to keep them separate
        # from actual credential key names in the flat Vault secret.
        data: dict[str, Any] = {
            "_type": credential.credential_type.value,
        }
        if credential.provider_id:
            data["_provider_id"] = credential.provider_id
        if credential.description:
            data["_description"] = credential.description
        if credential.auto_refresh:
            data["_auto_refresh"] = "true"
        # Store each key
        for key_name, key in credential.keys.items():
            data[key_name] = key.get_secret_value()
            if key.expires_at:
                data[f"_expires_{key_name}"] = key.expires_at.isoformat()
            if key.metadata:
                # Stored via str(); round-tripped with ast.literal_eval in
                # _deserialize_from_vault.
                data[f"_metadata_{key_name}"] = str(key.metadata)
        return data

    def _deserialize_from_vault(self, credential_id: str, data: dict[str, Any]) -> CredentialObject:
        """Reconstruct credential from Vault secret."""
        # Extract metadata fields
        # pop() removes the control fields so only credential keys and their
        # per-key metadata remain in `data`.
        cred_type = CredentialType(data.pop("_type", "api_key"))
        provider_id = data.pop("_provider_id", None)
        description = data.pop("_description", "")
        auto_refresh = data.pop("_auto_refresh", "") == "true"
        # Build keys dict
        keys: dict[str, CredentialKey] = {}
        # Find all non-metadata keys
        key_names = [k for k in data.keys() if not k.startswith("_")]
        for key_name in key_names:
            value = data[key_name]
            # Check for expiration
            expires_at = None
            expires_key = f"_expires_{key_name}"
            if expires_key in data:
                try:
                    expires_at = datetime.fromisoformat(data[expires_key])
                except (ValueError, TypeError):
                    # Unparseable timestamp: treat the key as non-expiring.
                    pass
            # Check for metadata
            metadata: dict[str, Any] = {}
            metadata_key = f"_metadata_{key_name}"
            if metadata_key in data:
                try:
                    import ast

                    # Counterpart of str(metadata) in _serialize_for_vault;
                    # literal_eval only parses Python literals, never code.
                    metadata = ast.literal_eval(data[metadata_key])
                except (ValueError, SyntaxError):
                    pass
            keys[key_name] = CredentialKey(
                name=key_name,
                value=SecretStr(value),
                expires_at=expires_at,
                metadata=metadata,
            )
        return CredentialObject(
            id=credential_id,
            credential_type=cred_type,
            keys=keys,
            provider_id=provider_id,
            description=description,
            auto_refresh=auto_refresh,
        )

    # --- Vault-Specific Operations ---

    def get_secret_metadata(self, credential_id: str) -> dict[str, Any] | None:
        """
        Get Vault metadata for a secret (version info, timestamps, etc.).

        Args:
            credential_id: The credential identifier

        Returns:
            Metadata dict or None if not found
        """
        path = self._path(credential_id)
        try:
            response = self._client.secrets.kv.v2.read_secret_metadata(
                path=path,
                mount_point=self._mount,
            )
            return response.get("data", {})
        except Exception:
            # Any failure is reported as "no metadata available".
            return None

    def soft_delete(self, credential_id: str, versions: list[int] | None = None) -> bool:
        """
        Soft delete specific versions (can be recovered).

        Args:
            credential_id: The credential identifier
            versions: Version numbers to delete. If None, deletes latest.

        Returns:
            True if successful
        """
        path = self._path(credential_id)
        try:
            if versions:
                self._client.secrets.kv.v2.delete_secret_versions(
                    path=path,
                    versions=versions,
                    mount_point=self._mount,
                )
            else:
                # No versions given: soft-delete only the latest version.
                self._client.secrets.kv.v2.delete_latest_version_of_secret(
                    path=path,
                    mount_point=self._mount,
                )
            return True
        except Exception as e:
            logger.error(f"Soft delete failed for '{credential_id}': {e}")
            return False

    def undelete(self, credential_id: str, versions: list[int]) -> bool:
        """
        Recover soft-deleted versions.

        Args:
            credential_id: The credential identifier
            versions: Version numbers to recover

        Returns:
            True if successful
        """
        path = self._path(credential_id)
        try:
            self._client.secrets.kv.v2.undelete_secret_versions(
                path=path,
                versions=versions,
                mount_point=self._mount,
            )
            return True
        except Exception as e:
            logger.error(f"Undelete failed for '{credential_id}': {e}")
            return False

    def load_version(self, credential_id: str, version: int) -> CredentialObject | None:
        """
        Load a specific version of a credential.

        Args:
            credential_id: The credential identifier
            version: Version number to load

        Returns:
            CredentialObject or None
        """
        path = self._path(credential_id)
        try:
            response = self._client.secrets.kv.v2.read_secret_version(
                path=path,
                version=version,
                mount_point=self._mount,
            )
            data = response["data"]["data"]
            return self._deserialize_from_vault(credential_id, data)
        except Exception:
            # Missing version (or any read failure) yields None.
            return None
+17 -16
View File
@@ -1,32 +1,32 @@
"""Graph structures: Goals, Nodes, Edges, and Flexible Execution."""
from framework.graph.goal import Goal, SuccessCriterion, Constraint, GoalStatus
from framework.graph.node import NodeSpec, NodeContext, NodeResult, NodeProtocol
from framework.graph.edge import EdgeSpec, EdgeCondition
from framework.graph.code_sandbox import CodeSandbox, safe_eval, safe_exec
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.executor import GraphExecutor
from framework.graph.flexible_executor import ExecutorConfig, FlexibleGraphExecutor
from framework.graph.goal import Constraint, Goal, GoalStatus, SuccessCriterion
from framework.graph.judge import HybridJudge, create_default_judge
from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec
# Flexible execution (Worker-Judge pattern)
from framework.graph.plan import (
Plan,
PlanStep,
ActionSpec,
ActionType,
StepStatus,
Judgment,
JudgmentAction,
EvaluationRule,
PlanExecutionResult,
ExecutionStatus,
load_export,
# HITL (Human-in-the-loop)
ApprovalDecision,
ApprovalRequest,
ApprovalResult,
EvaluationRule,
ExecutionStatus,
Judgment,
JudgmentAction,
Plan,
PlanExecutionResult,
PlanStep,
StepStatus,
load_export,
)
from framework.graph.judge import HybridJudge, create_default_judge
from framework.graph.worker_node import WorkerNode, StepExecutionResult
from framework.graph.flexible_executor import FlexibleGraphExecutor, ExecutorConfig
from framework.graph.code_sandbox import CodeSandbox, safe_exec, safe_eval
from framework.graph.worker_node import StepExecutionResult, WorkerNode
__all__ = [
# Goal
@@ -42,6 +42,7 @@ __all__ = [
# Edge
"EdgeSpec",
"EdgeCondition",
"GraphSpec",
# Executor (fixed graph)
"GraphExecutor",
# Plan (flexible execution)
+14 -14
View File
@@ -13,11 +13,11 @@ Security measures:
"""
import ast
import sys
import signal
from typing import Any
from dataclasses import dataclass, field
import sys
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any
# Safe builtins whitelist
SAFE_BUILTINS = {
@@ -25,7 +25,6 @@ SAFE_BUILTINS = {
"True": True,
"False": False,
"None": None,
# Type constructors
"bool": bool,
"int": int,
@@ -36,7 +35,6 @@ SAFE_BUILTINS = {
"set": set,
"tuple": tuple,
"frozenset": frozenset,
# Basic functions
"abs": abs,
"all": all,
@@ -97,22 +95,26 @@ BLOCKED_AST_NODES = {
class CodeSandboxError(Exception):
"""Error during sandboxed code execution."""
pass
class TimeoutError(CodeSandboxError):
"""Code execution timed out."""
pass
class SecurityError(CodeSandboxError):
"""Code contains potentially dangerous operations."""
pass
@dataclass
class SandboxResult:
"""Result of sandboxed code execution."""
success: bool
result: Any = None
error: str | None = None
@@ -134,6 +136,7 @@ class RestrictedImporter:
if name not in self._cache:
import importlib
self._cache[name] = importlib.import_module(name)
return self._cache[name]
@@ -161,9 +164,8 @@ class CodeValidator:
for node in ast.walk(tree):
# Check for blocked node types
if type(node) in self.blocked_nodes:
issues.append(
f"Blocked operation: {type(node).__name__} at line {getattr(node, 'lineno', '?')}"
)
lineno = getattr(node, "lineno", "?")
issues.append(f"Blocked operation: {type(node).__name__} at line {lineno}")
# Check for dangerous attribute access
if isinstance(node, ast.Attribute):
@@ -212,11 +214,12 @@ class CodeSandbox:
@contextmanager
def _timeout_context(self, seconds: int):
"""Context manager for timeout enforcement."""
def handler(signum, frame):
raise TimeoutError(f"Code execution timed out after {seconds} seconds")
# Only works on Unix-like systems
if hasattr(signal, 'SIGALRM'):
if hasattr(signal, "SIGALRM"):
old_handler = signal.signal(signal.SIGALRM, handler)
signal.alarm(seconds)
try:
@@ -275,6 +278,7 @@ class CodeSandbox:
# Capture stdout
import io
old_stdout = sys.stdout
sys.stdout = captured_stdout = io.StringIO()
@@ -296,11 +300,7 @@ class CodeSandbox:
# Also extract any new variables (not in inputs or builtins)
for key, value in namespace.items():
if (
key not in inputs
and key not in self.safe_builtins
and not key.startswith("_")
):
if key not in inputs and key not in self.safe_builtins and not key.startswith("_"):
extracted[key] = value
return SandboxResult(
+90 -51
View File
@@ -11,9 +11,10 @@ our edges can be created dynamically by a Builder agent based on the goal.
Edge Types:
- always: Always traverse after source completes
- always: Always traverse after source completes
- on_success: Traverse only if source succeeds
- on_failure: Traverse only if source fails
- conditional: Traverse based on expression evaluation
- conditional: Traverse based on expression evaluation (SAFE SUBSET ONLY)
- llm_decide: Let LLM decide based on goal and context (goal-aware routing)
The llm_decide condition is particularly powerful for goal-driven agents,
@@ -21,19 +22,22 @@ allowing the LLM to evaluate whether proceeding along an edge makes sense
given the current goal, context, and execution state.
"""
from typing import Any
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
from framework.graph.safe_eval import safe_eval
class EdgeCondition(str, Enum):
"""When an edge should be traversed."""
ALWAYS = "always" # Always after source completes
ON_SUCCESS = "on_success" # Only if source succeeds
ON_FAILURE = "on_failure" # Only if source fails
CONDITIONAL = "conditional" # Based on expression
LLM_DECIDE = "llm_decide" # Let LLM decide based on goal and context
ALWAYS = "always" # Always after source completes
ON_SUCCESS = "on_success" # Only if source succeeds
ON_FAILURE = "on_failure" # Only if source fails
CONDITIONAL = "conditional" # Based on expression
LLM_DECIDE = "llm_decide" # Let LLM decide based on goal and context
class EdgeSpec(BaseModel):
@@ -68,6 +72,7 @@ class EdgeSpec(BaseModel):
description="Only filter if results need refinement to meet goal",
)
"""
id: str
source: str = Field(description="Source node ID")
target: str = Field(description="Target node ID")
@@ -76,20 +81,17 @@ class EdgeSpec(BaseModel):
condition: EdgeCondition = EdgeCondition.ALWAYS
condition_expr: str | None = Field(
default=None,
description="Expression for CONDITIONAL edges, e.g., 'output.confidence > 0.8'"
description="Expression for CONDITIONAL edges, e.g., 'output.confidence > 0.8'",
)
# Data flow
input_mapping: dict[str, str] = Field(
default_factory=dict,
description="Map source outputs to target inputs: {target_key: source_key}"
description="Map source outputs to target inputs: {target_key: source_key}",
)
# Priority for multiple outgoing edges
priority: int = Field(
default=0,
description="Higher priority edges are evaluated first"
)
priority: int = Field(default=0, description="Higher priority edges are evaluated first")
# Metadata
description: str = ""
@@ -164,17 +166,18 @@ class EdgeSpec(BaseModel):
"output": output,
"memory": memory,
"result": output.get("result"),
"true": True, # Allow lowercase true/false in conditions
"true": True, # Allow lowercase true/false in conditions
"false": False,
**memory, # Unpack memory keys directly into context
}
try:
# Safe evaluation (in production, use a proper expression evaluator)
return bool(eval(self.condition_expr, {"__builtins__": {}}, context))
# Safe evaluation using AST-based whitelist
return bool(safe_eval(self.condition_expr, context))
except Exception as e:
# Log the error for debugging
import logging
logger = logging.getLogger(__name__)
logger.warning(f" ⚠ Condition evaluation failed: {self.condition_expr}")
logger.warning(f" Error: {e}")
@@ -235,7 +238,8 @@ Respond with ONLY a JSON object:
# Parse response
import re
json_match = re.search(r'\{[^{}]*\}', response.content, re.DOTALL)
json_match = re.search(r"\{[^{}]*\}", response.content, re.DOTALL)
if json_match:
data = json.loads(json_match.group())
proceed = data.get("proceed", False)
@@ -243,6 +247,7 @@ Respond with ONLY a JSON object:
# Log the decision (using basic print for now)
import logging
logger = logging.getLogger(__name__)
logger.info(f" 🤔 LLM routing decision: {'PROCEED' if proceed else 'SKIP'}")
logger.info(f" Reason: {reasoning}")
@@ -252,6 +257,7 @@ Respond with ONLY a JSON object:
except Exception as e:
# Fallback: proceed on success
import logging
logger = logging.getLogger(__name__)
logger.warning(f" ⚠ LLM routing failed, defaulting to on_success: {e}")
return source_success
@@ -304,28 +310,24 @@ class AsyncEntryPointSpec(BaseModel):
isolation_level="shared",
)
"""
id: str = Field(description="Unique identifier for this entry point")
name: str = Field(description="Human-readable name")
entry_node: str = Field(description="Node ID to start execution from")
trigger_type: str = Field(
default="manual",
description="How this entry point is triggered: webhook, api, timer, event, manual"
description="How this entry point is triggered: webhook, api, timer, event, manual",
)
trigger_config: dict[str, Any] = Field(
default_factory=dict,
description="Trigger-specific configuration (e.g., webhook URL, timer interval)"
description="Trigger-specific configuration (e.g., webhook URL, timer interval)",
)
isolation_level: str = Field(
default="shared",
description="State isolation: isolated, shared, or synchronized"
)
priority: int = Field(
default=0,
description="Execution priority (higher = more priority)"
default="shared", description="State isolation: isolated, shared, or synchronized"
)
priority: int = Field(default=0, description="Execution priority (higher = more priority)")
max_concurrent: int = Field(
default=10,
description="Maximum concurrent executions for this entry point"
default=10, description="Maximum concurrent executions for this entry point"
)
model_config = {"extra": "allow"}
@@ -370,6 +372,7 @@ class GraphSpec(BaseModel):
edges=[...],
)
"""
id: str
goal_id: str
version: str = "1.0.0"
@@ -378,46 +381,43 @@ class GraphSpec(BaseModel):
entry_node: str = Field(description="ID of the first node to execute")
entry_points: dict[str, str] = Field(
default_factory=dict,
description="Named entry points for resuming execution. Format: {name: node_id}"
description="Named entry points for resuming execution. Format: {name: node_id}",
)
async_entry_points: list[AsyncEntryPointSpec] = Field(
default_factory=list,
description="Asynchronous entry points for concurrent execution streams (used with AgentRuntime)"
description=(
"Asynchronous entry points for concurrent execution streams (used with AgentRuntime)"
),
)
terminal_nodes: list[str] = Field(
default_factory=list,
description="IDs of nodes that end execution"
default_factory=list, description="IDs of nodes that end execution"
)
pause_nodes: list[str] = Field(
default_factory=list,
description="IDs of nodes that pause execution for HITL input"
default_factory=list, description="IDs of nodes that pause execution for HITL input"
)
# Components
nodes: list[Any] = Field( # NodeSpec, but avoiding circular import
default_factory=list,
description="All node specifications"
)
edges: list[EdgeSpec] = Field(
default_factory=list,
description="All edge specifications"
default_factory=list, description="All node specifications"
)
edges: list[EdgeSpec] = Field(default_factory=list, description="All edge specifications")
# Shared memory keys
memory_keys: list[str] = Field(
default_factory=list,
description="Keys available in shared memory"
default_factory=list, description="Keys available in shared memory"
)
# Default LLM settings
default_model: str = "claude-haiku-4-5-20251001"
max_tokens: int = 1024
# Cleanup LLM for JSON extraction fallback (fast/cheap model preferred)
# If not set, uses CEREBRAS_API_KEY -> cerebras/llama-3.3-70b or
# ANTHROPIC_API_KEY -> claude-3-5-haiku as fallback
cleanup_llm_model: str | None = None
# Execution limits
max_steps: int = Field(
default=100,
description="Maximum node executions before timeout"
)
max_steps: int = Field(default=100, description="Maximum node executions before timeout")
max_retries_per_node: int = 3
# Metadata
@@ -453,6 +453,42 @@ class GraphSpec(BaseModel):
"""Get all edges entering a node."""
return [e for e in self.edges if e.target == node_id]
def detect_fan_out_nodes(self) -> dict[str, list[str]]:
"""
Detect nodes that fan-out to multiple targets.
A fan-out occurs when a node has multiple outgoing edges with the same
condition (typically ON_SUCCESS) that should execute in parallel.
Returns:
Dict mapping source_node_id -> list of parallel target_node_ids
"""
fan_outs: dict[str, list[str]] = {}
for node in self.nodes:
outgoing = self.get_outgoing_edges(node.id)
# Fan-out: multiple edges with ON_SUCCESS condition
success_edges = [e for e in outgoing if e.condition == EdgeCondition.ON_SUCCESS]
if len(success_edges) > 1:
fan_outs[node.id] = [e.target for e in success_edges]
return fan_outs
def detect_fan_in_nodes(self) -> dict[str, list[str]]:
"""
Detect nodes that receive from multiple sources (fan-in / convergence).
A fan-in occurs when a node has multiple incoming edges, meaning
it should wait for all predecessor branches to complete.
Returns:
Dict mapping target_node_id -> list of source_node_ids
"""
fan_ins: dict[str, list[str]] = {}
for node in self.nodes:
incoming = self.get_incoming_edges(node.id)
if len(incoming) > 1:
fan_ins[node.id] = [e.source for e in incoming]
return fan_ins
def get_entry_point(self, session_state: dict | None = None) -> str:
"""
Get the appropriate entry point based on session state.
@@ -504,7 +540,8 @@ class GraphSpec(BaseModel):
# Check entry node exists
if not self.get_node(entry_point.entry_node):
errors.append(
f"Async entry point '{entry_point.id}' references missing node '{entry_point.entry_node}'"
f"Async entry point '{entry_point.id}' references "
f"missing node '{entry_point.entry_node}'"
)
# Validate isolation level
@@ -562,11 +599,13 @@ class GraphSpec(BaseModel):
for node in self.nodes:
if node.id not in reachable:
# Skip this error if the node is a pause node, entry point target, or async entry point
# (pause/resume architecture and async entry points make these reachable)
if (node.id in self.pause_nodes or
node.id in self.entry_points.values() or
node.id in async_entry_nodes):
# Skip if node is a pause node, entry point target, or async entry
# (pause/resume architecture and async entry points make reachable)
if (
node.id in self.pause_nodes
or node.id in self.entry_points.values()
or node.id in async_entry_nodes
):
continue
errors.append(f"Node '{node.id}' is unreachable from entry")
+425 -47
View File
@@ -9,31 +9,34 @@ The executor:
5. Returns the final result
"""
import asyncio
import logging
from typing import Any, Callable
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
from framework.runtime.core import Runtime
from framework.graph.edge import EdgeSpec, GraphSpec
from framework.graph.goal import Goal
from framework.graph.node import (
NodeSpec,
NodeContext,
NodeResult,
NodeProtocol,
SharedMemory,
LLMNode,
RouterNode,
FunctionNode,
LLMNode,
NodeContext,
NodeProtocol,
NodeResult,
NodeSpec,
RouterNode,
SharedMemory,
)
from framework.graph.edge import GraphSpec
from framework.graph.output_cleaner import CleansingConfig, OutputCleaner
from framework.graph.validator import OutputValidator
from framework.graph.output_cleaner import OutputCleaner, CleansingConfig
from framework.llm.provider import LLMProvider, Tool
from framework.runtime.core import Runtime
@dataclass
class ExecutionResult:
"""Result of executing a graph."""
success: bool
output: dict[str, Any] = field(default_factory=dict)
error: str | None = None
@@ -45,6 +48,35 @@ class ExecutionResult:
session_state: dict[str, Any] = field(default_factory=dict) # State to resume from
@dataclass
class ParallelBranch:
    """Tracks a single branch in parallel fan-out execution.

    One instance is created per traversable edge when a fan-out is
    detected; the executor mutates `status`/`result`/`error` in place
    as the branch runs.
    """

    # Unique id for this branch, formatted "<source>_to_<target>".
    branch_id: str
    # Target node this branch executes.
    node_id: str
    # The edge that was followed to reach this branch.
    edge: EdgeSpec
    # Node result once the branch finishes (None while pending/running).
    result: "NodeResult | None" = None
    status: str = "pending"  # pending, running, completed, failed
    # Number of retry attempts consumed so far (0-based attempt index).
    retry_count: int = 0
    # Error description when status == "failed".
    error: str | None = None
@dataclass
class ParallelExecutionConfig:
    """Configuration for parallel execution behavior."""

    # Error handling: "fail_all" cancels all on first failure,
    # "continue_others" lets remaining branches complete,
    # "wait_all" waits for all and reports all failures
    on_branch_failure: str = "fail_all"

    # Memory conflict handling when branches write same key
    # NOTE(review): not referenced by the parallel execution path visible
    # here — confirm it is enforced elsewhere before relying on it.
    memory_conflict_strategy: str = "last_wins"  # "last_wins", "first_wins", "error"

    # Timeout per branch in seconds
    # NOTE(review): likewise not applied in the visible branch runner —
    # verify a timeout wrapper exists before assuming branches time out.
    branch_timeout_seconds: float = 300.0
class GraphExecutor:
"""
Executes agent graphs.
@@ -73,6 +105,8 @@ class GraphExecutor:
node_registry: dict[str, NodeProtocol] | None = None,
approval_callback: Callable | None = None,
cleansing_config: CleansingConfig | None = None,
enable_parallel_execution: bool = True,
parallel_config: ParallelExecutionConfig | None = None,
):
"""
Initialize the executor.
@@ -85,6 +119,8 @@ class GraphExecutor:
node_registry: Custom node implementations by ID
approval_callback: Optional callback for human-in-the-loop approval
cleansing_config: Optional output cleansing configuration
enable_parallel_execution: Enable parallel fan-out execution (default True)
parallel_config: Configuration for parallel execution behavior
"""
self.runtime = runtime
self.llm = llm
@@ -102,6 +138,10 @@ class GraphExecutor:
llm_provider=llm,
)
# Parallel execution settings
self.enable_parallel_execution = enable_parallel_execution
self._parallel_config = parallel_config or ParallelExecutionConfig()
def _validate_tools(self, graph: GraphSpec) -> list[str]:
"""
Validate that all tools declared by nodes are available.
@@ -116,14 +156,15 @@ class GraphExecutor:
if node.tools:
missing = set(node.tools) - available_tool_names
if missing:
available = sorted(available_tool_names) if available_tool_names else "none"
errors.append(
f"Node '{node.name}' (id={node.id}) requires tools {sorted(missing)} "
f"but they are not registered. Available tools: {sorted(available_tool_names) if available_tool_names else 'none'}"
f"Node '{node.name}' (id={node.id}) requires tools "
f"{sorted(missing)} but they are not registered. "
f"Available tools: {available}"
)
return errors
async def execute(
self,
graph: GraphSpec,
@@ -159,7 +200,10 @@ class GraphExecutor:
self.logger.error(f"{err}")
return ExecutionResult(
success=False,
error=f"Missing tools: {'; '.join(tool_errors)}. Register tools via ToolRegistry or remove tool declarations from nodes.",
error=(
f"Missing tools: {'; '.join(tool_errors)}. "
"Register tools via ToolRegistry or remove tool declarations from nodes."
),
)
# Initialize execution state
@@ -167,10 +211,18 @@ class GraphExecutor:
# Restore session state if provided
if session_state and "memory" in session_state:
# Restore memory from previous session
for key, value in session_state["memory"].items():
memory.write(key, value)
self.logger.info(f"📥 Restored session state with {len(session_state['memory'])} memory keys")
memory_data = session_state["memory"]
# [RESTORED] Type safety check
if not isinstance(memory_data, dict):
self.logger.warning(
f"⚠️ Invalid memory data type in session state: "
f"{type(memory_data).__name__}, expected dict"
)
else:
# Restore memory from previous session
for key, value in memory_data.items():
memory.write(key, value)
self.logger.info(f"📥 Restored session state with {len(memory_data)} memory keys")
# Write new input data to memory (each key individually)
if input_data:
@@ -227,6 +279,7 @@ class GraphExecutor:
memory=memory,
goal=goal,
input_data=input_data or {},
max_tokens=graph.max_tokens,
)
# Log actual input data being read
@@ -242,7 +295,7 @@ class GraphExecutor:
self.logger.info(f" {key}: {value_str}")
# Get or create node implementation
node_impl = self._get_node_implementation(node_spec)
node_impl = self._get_node_implementation(node_spec, graph.cleanup_llm_model)
# Validate inputs
validation_errors = node_impl.validate_input(ctx)
@@ -264,6 +317,7 @@ class GraphExecutor:
output=result.output,
expected_keys=node_spec.output_keys,
check_hallucination=True,
nullable_keys=node_spec.nullable_output_keys,
)
if not validation.success:
self.logger.error(f" ✗ Output validation failed: {validation.error}")
@@ -276,7 +330,10 @@ class GraphExecutor:
)
if result.success:
self.logger.info(f" ✓ Success (tokens: {result.tokens_used}, latency: {result.latency_ms}ms)")
self.logger.info(
f" ✓ Success (tokens: {result.tokens_used}, "
f"latency: {result.latency_ms}ms)"
)
# Generate and log human-readable summary
summary = result.to_summary(node_spec)
@@ -299,28 +356,55 @@ class GraphExecutor:
# Handle failure
if not result.success:
# Track retries per node
node_retry_counts[current_node_id] = node_retry_counts.get(current_node_id, 0) + 1
node_retry_counts[current_node_id] = (
node_retry_counts.get(current_node_id, 0) + 1
)
if node_retry_counts[current_node_id] < node_spec.max_retries:
# [CORRECTED] Use node_spec.max_retries instead of hardcoded 3
max_retries = getattr(node_spec, "max_retries", 3)
if node_retry_counts[current_node_id] < max_retries:
# Retry - don't increment steps for retries
steps -= 1
self.logger.info(f" ↻ Retrying ({node_retry_counts[current_node_id]}/{node_spec.max_retries})...")
# --- EXPONENTIAL BACKOFF ---
retry_count = node_retry_counts[current_node_id]
# Backoff formula: 1.0 * (2^(retry - 1)) -> 1s, 2s, 4s...
delay = 1.0 * (2 ** (retry_count - 1))
self.logger.info(f" Using backoff: Sleeping {delay}s before retry...")
await asyncio.sleep(delay)
# --------------------------------------
self.logger.info(
f" ↻ Retrying ({node_retry_counts[current_node_id]}/{max_retries})..."
)
continue
else:
# Max retries exceeded - fail the execution
self.logger.error(f" ✗ Max retries ({node_spec.max_retries}) exceeded for node {current_node_id}")
self.logger.error(
f" ✗ Max retries ({max_retries}) exceeded for node {current_node_id}"
)
self.runtime.report_problem(
severity="critical",
description=f"Node {current_node_id} failed after {node_spec.max_retries} attempts: {result.error}",
description=(
f"Node {current_node_id} failed after "
f"{max_retries} attempts: {result.error}"
),
)
self.runtime.end_run(
success=False,
output_data=memory.read_all(),
narrative=f"Failed at {node_spec.name} after {node_spec.max_retries} retries: {result.error}",
narrative=(
f"Failed at {node_spec.name} after "
f"{max_retries} retries: {result.error}"
),
)
return ExecutionResult(
success=False,
error=f"Node '{node_spec.name}' failed after {node_spec.max_retries} attempts: {result.error}",
error=(
f"Node '{node_spec.name}' failed after "
f"{max_retries} attempts: {result.error}"
),
output=memory.read_all(),
steps_executed=steps,
total_tokens=total_tokens,
@@ -368,8 +452,8 @@ class GraphExecutor:
self.logger.info(f" → Router directing to: {result.next_node}")
current_node_id = result.next_node
else:
# Follow edges
next_node = self._follow_edges(
# Get all traversable edges for fan-out detection
traversable_edges = self._get_all_traversable_edges(
graph=graph,
goal=goal,
current_node_id=current_node_id,
@@ -377,12 +461,59 @@ class GraphExecutor:
result=result,
memory=memory,
)
if next_node is None:
if not traversable_edges:
self.logger.info(" → No more edges, ending execution")
break # No valid edge, end execution
next_spec = graph.get_node(next_node)
self.logger.info(f" → Next: {next_spec.name if next_spec else next_node}")
current_node_id = next_node
# Check for fan-out (multiple traversable edges)
if self.enable_parallel_execution and len(traversable_edges) > 1:
# Find convergence point (fan-in node)
targets = [e.target for e in traversable_edges]
fan_in_node = self._find_convergence_node(graph, targets)
# Execute branches in parallel
(
_branch_results,
branch_tokens,
branch_latency,
) = await self._execute_parallel_branches(
graph=graph,
goal=goal,
edges=traversable_edges,
memory=memory,
source_result=result,
source_node_spec=node_spec,
path=path,
)
total_tokens += branch_tokens
total_latency += branch_latency
# Continue from fan-in node
if fan_in_node:
self.logger.info(f" ⑃ Fan-in: converging at {fan_in_node}")
current_node_id = fan_in_node
else:
# No convergence point - branches are terminal
self.logger.info(" → Parallel branches completed (no convergence)")
break
else:
# Sequential: follow single edge (existing logic via _follow_edges)
next_node = self._follow_edges(
graph=graph,
goal=goal,
current_node_id=current_node_id,
current_node_spec=node_spec,
result=result,
memory=memory,
)
if next_node is None:
self.logger.info(" → No more edges, ending execution")
break
next_spec = graph.get_node(next_node)
self.logger.info(f" → Next: {next_spec.name if next_spec else next_node}")
current_node_id = next_node
# Update input_data for next node
input_data = result.output
@@ -433,6 +564,7 @@ class GraphExecutor:
memory: SharedMemory,
goal: Goal,
input_data: dict[str, Any],
max_tokens: int = 4096,
) -> NodeContext:
"""Build execution context for a node."""
# Filter tools to those available to this node
@@ -456,12 +588,15 @@ class GraphExecutor:
available_tools=available_tools,
goal_context=goal.to_prompt_context(),
goal=goal, # Pass Goal object for LLM-powered routers
max_tokens=max_tokens,
)
# Valid node types - no ambiguous "llm" type allowed
VALID_NODE_TYPES = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
def _get_node_implementation(self, node_spec: NodeSpec) -> NodeProtocol:
def _get_node_implementation(
self, node_spec: NodeSpec, cleanup_llm_model: str | None = None
) -> NodeProtocol:
"""Get or create a node implementation."""
# Check registry first
if node_spec.id in self.node_registry:
@@ -482,10 +617,18 @@ class GraphExecutor:
f"Node '{node_spec.id}' is type 'llm_tool_use' but declares no tools. "
"Either add tools to the node or change type to 'llm_generate'."
)
return LLMNode(tool_executor=self.tool_executor, require_tools=True)
return LLMNode(
tool_executor=self.tool_executor,
require_tools=True,
cleanup_llm_model=cleanup_llm_model,
)
if node_spec.node_type == "llm_generate":
return LLMNode(tool_executor=None, require_tools=False)
return LLMNode(
tool_executor=None,
require_tools=False,
cleanup_llm_model=cleanup_llm_model,
)
if node_spec.node_type == "router":
return RouterNode()
@@ -493,13 +636,16 @@ class GraphExecutor:
if node_spec.node_type == "function":
# Function nodes need explicit registration
raise RuntimeError(
f"Function node '{node_spec.id}' not registered. "
"Register with node_registry."
f"Function node '{node_spec.id}' not registered. Register with node_registry."
)
if node_spec.node_type == "human_input":
# Human input nodes are handled specially by HITL mechanism
return LLMNode(tool_executor=None, require_tools=False)
return LLMNode(
tool_executor=None,
require_tools=False,
cleanup_llm_model=cleanup_llm_model,
)
# Should never reach here due to validation above
raise RuntimeError(f"Unhandled node type: {node_spec.node_type}")
@@ -539,9 +685,7 @@ class GraphExecutor:
)
if not validation.valid:
self.logger.warning(
f"⚠ Output validation failed: {validation.errors}"
)
self.logger.warning(f"⚠ Output validation failed: {validation.errors}")
# Clean the output
cleaned_output = self.output_cleaner.clean_output(
@@ -554,9 +698,9 @@ class GraphExecutor:
# Update result with cleaned output
result.output = cleaned_output
# Write cleaned output back to memory
# Write cleaned output back to memory (skip validation for LLM output)
for key, value in cleaned_output.items():
memory.write(key, value)
memory.write(key, value, validate=False)
# Revalidate
revalidation = self.output_cleaner.validate_output(
@@ -573,15 +717,249 @@ class GraphExecutor:
)
# Continue anyway if fallback_to_raw is True
# Map inputs
# Map inputs (skip validation for processed LLM output)
mapped = edge.map_inputs(result.output, memory.read_all())
for key, value in mapped.items():
memory.write(key, value)
memory.write(key, value, validate=False)
return edge.target
return None
def _get_all_traversable_edges(
self,
graph: GraphSpec,
goal: Goal,
current_node_id: str,
current_node_spec: Any,
result: NodeResult,
memory: SharedMemory,
) -> list[EdgeSpec]:
"""
Get ALL edges that should be traversed (for fan-out detection).
Unlike _follow_edges which returns the first match, this returns
all matching edges to enable parallel execution.
"""
edges = graph.get_outgoing_edges(current_node_id)
traversable = []
for edge in edges:
target_node_spec = graph.get_node(edge.target)
if edge.should_traverse(
source_success=result.success,
source_output=result.output,
memory=memory.read_all(),
llm=self.llm,
goal=goal,
source_node_name=current_node_spec.name if current_node_spec else current_node_id,
target_node_name=target_node_spec.name if target_node_spec else edge.target,
):
traversable.append(edge)
return traversable
def _find_convergence_node(
self,
graph: GraphSpec,
parallel_targets: list[str],
) -> str | None:
"""
Find the common target node where parallel branches converge (fan-in).
Args:
graph: The graph specification
parallel_targets: List of node IDs that are running in parallel
Returns:
Node ID where all branches converge, or None if no convergence
"""
# Get all nodes that parallel branches lead to
next_nodes: dict[str, int] = {} # node_id -> count of branches leading to it
for target in parallel_targets:
outgoing = graph.get_outgoing_edges(target)
for edge in outgoing:
next_nodes[edge.target] = next_nodes.get(edge.target, 0) + 1
# Convergence node is where ALL branches lead
for node_id, count in next_nodes.items():
if count == len(parallel_targets):
return node_id
# Fallback: return most common target if any
if next_nodes:
return max(next_nodes.keys(), key=lambda k: next_nodes[k])
return None
async def _execute_parallel_branches(
    self,
    graph: GraphSpec,
    goal: Goal,
    edges: list[EdgeSpec],
    memory: SharedMemory,
    source_result: NodeResult,
    source_node_spec: Any,
    path: list[str],
) -> tuple[dict[str, NodeResult], int, int]:
    """
    Execute multiple branches in parallel using asyncio.gather.

    Each edge becomes a ParallelBranch; all branches run concurrently,
    sharing `memory` (writes use the async variant). Per-branch retries
    follow the target node's max_retries. Failure policy comes from
    self._parallel_config.on_branch_failure.

    Args:
        graph: The graph specification
        goal: The execution goal
        edges: List of edges to follow in parallel
        memory: Shared memory instance
        source_result: Result from the source node
        source_node_spec: Spec of the source node
        path: Execution path list to update

    Returns:
        Tuple of (branch_results dict, total_tokens, total_latency)

    Raises:
        RuntimeError: if any branch fails and policy is "fail_all".
    """
    branches: dict[str, ParallelBranch] = {}

    # Create branches for each edge
    for edge in edges:
        branch_id = f"{edge.source}_to_{edge.target}"
        branches[branch_id] = ParallelBranch(
            branch_id=branch_id,
            node_id=edge.target,
            edge=edge,
        )

    self.logger.info(f" ⑂ Fan-out: executing {len(branches)} branches in parallel")
    for branch in branches.values():
        target_spec = graph.get_node(branch.node_id)
        self.logger.info(f"{target_spec.name if target_spec else branch.node_id}")

    async def execute_single_branch(
        branch: ParallelBranch,
    ) -> tuple[ParallelBranch, NodeResult | Exception]:
        """Execute a single branch with retry logic."""
        node_spec = graph.get_node(branch.node_id)
        if node_spec is None:
            branch.status = "failed"
            branch.error = f"Node {branch.node_id} not found in graph"
            return branch, RuntimeError(branch.error)

        branch.status = "running"
        try:
            # Validate and clean output before mapping inputs (same as _follow_edges)
            if self.cleansing_config.enabled and node_spec:
                validation = self.output_cleaner.validate_output(
                    output=source_result.output,
                    source_node_id=source_node_spec.id if source_node_spec else "unknown",
                    target_node_spec=node_spec,
                )
                if not validation.valid:
                    self.logger.warning(
                        f"⚠ Output validation failed for branch "
                        f"{branch.node_id}: {validation.errors}"
                    )
                cleaned_output = self.output_cleaner.clean_output(
                    output=source_result.output,
                    source_node_id=source_node_spec.id if source_node_spec else "unknown",
                    target_node_spec=node_spec,
                    validation_errors=validation.errors,
                )
                # Write cleaned output to memory
                for key, value in cleaned_output.items():
                    await memory.write_async(key, value)

            # Map inputs via edge
            mapped = branch.edge.map_inputs(source_result.output, memory.read_all())
            for key, value in mapped.items():
                await memory.write_async(key, value)

            # Execute with retries
            last_result = None
            for attempt in range(node_spec.max_retries):
                branch.retry_count = attempt

                # Build a fresh context for each attempt so it sees the
                # latest shared memory state.
                ctx = self._build_context(node_spec, memory, goal, mapped, graph.max_tokens)
                node_impl = self._get_node_implementation(node_spec, graph.cleanup_llm_model)

                self.logger.info(
                    f" ▶ Branch {node_spec.name}: executing (attempt {attempt + 1})"
                )
                result = await node_impl.execute(ctx)
                last_result = result

                if result.success:
                    # Write outputs to shared memory using async write
                    for key, value in result.output.items():
                        await memory.write_async(key, value)
                    branch.result = result
                    branch.status = "completed"
                    self.logger.info(
                        f" ✓ Branch {node_spec.name}: success "
                        f"(tokens: {result.tokens_used}, latency: {result.latency_ms}ms)"
                    )
                    return branch, result

                self.logger.warning(
                    f" ↻ Branch {node_spec.name}: "
                    f"retry {attempt + 1}/{node_spec.max_retries}"
                )

            # All retries exhausted
            branch.status = "failed"
            branch.error = last_result.error if last_result else "Unknown error"
            branch.result = last_result
            self.logger.error(
                f" ✗ Branch {node_spec.name}: "
                f"failed after {node_spec.max_retries} attempts"
            )
            return branch, last_result
        except Exception as e:
            branch.status = "failed"
            branch.error = str(e)
            self.logger.error(f" ✗ Branch {branch.node_id}: exception - {e}")
            return branch, e

    # Execute all branches concurrently
    tasks = [execute_single_branch(b) for b in branches.values()]
    results = await asyncio.gather(*tasks, return_exceptions=False)

    # Process results
    total_tokens = 0
    total_latency = 0
    branch_results: dict[str, NodeResult] = {}
    failed_branches: list[ParallelBranch] = []

    for branch, result in results:
        path.append(branch.node_id)
        if isinstance(result, Exception):
            failed_branches.append(branch)
        elif result is None or not result.success:
            failed_branches.append(branch)
        else:
            total_tokens += result.tokens_used
            total_latency += result.latency_ms
            branch_results[branch.branch_id] = result

    # Handle failures based on config
    if failed_branches:
        # [FIXED] get_node may return None (e.g. the "node not found"
        # failure above) — fall back to the raw node id instead of
        # crashing with AttributeError on .name.
        failed_names = [
            spec.name if (spec := graph.get_node(b.node_id)) else b.node_id
            for b in failed_branches
        ]
        if self._parallel_config.on_branch_failure == "fail_all":
            raise RuntimeError(f"Parallel execution failed: branches {failed_names} failed")
        elif self._parallel_config.on_branch_failure == "continue_others":
            self.logger.warning(
                f"⚠ Some branches failed ({failed_names}), continuing with successful ones"
            )
        else:
            # "wait_all" (and any unrecognized policy): all branches have
            # already completed via gather — report every failure, then
            # continue with the successful results.
            for branch in failed_branches:
                self.logger.error(
                    f"✗ Branch {branch.branch_id} failed: {branch.error}"
                )

    self.logger.info(
        f" ⑃ Fan-out complete: {len(branch_results)}/{len(branches)} branches succeeded"
    )
    return branch_results, total_tokens, total_latency
def register_node(self, node_id: str, implementation: NodeProtocol) -> None:
    """Install a custom implementation for *node_id*, replacing any existing entry."""
    self.node_registry.update({node_id: implementation})
+31 -17
View File
@@ -15,28 +15,29 @@ using a Worker-Judge loop:
This keeps planning external while execution/evaluation is internal.
"""
from typing import Any, Callable
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from framework.runtime.core import Runtime
from framework.graph.code_sandbox import CodeSandbox
from framework.graph.goal import Goal
from framework.graph.judge import HybridJudge, create_default_judge
from framework.graph.plan import (
Plan,
PlanStep,
PlanExecutionResult,
ExecutionStatus,
StepStatus,
Judgment,
JudgmentAction,
ApprovalDecision,
ApprovalRequest,
ApprovalResult,
ApprovalDecision,
ExecutionStatus,
Judgment,
JudgmentAction,
Plan,
PlanExecutionResult,
PlanStep,
StepStatus,
)
from framework.graph.judge import HybridJudge, create_default_judge
from framework.graph.worker_node import WorkerNode, StepExecutionResult
from framework.graph.code_sandbox import CodeSandbox
from framework.graph.worker_node import StepExecutionResult, WorkerNode
from framework.llm.provider import LLMProvider, Tool
from framework.runtime.core import Runtime
# Type alias for approval callback
ApprovalCallback = Callable[[ApprovalRequest], ApprovalResult]
@@ -45,6 +46,7 @@ ApprovalCallback = Callable[[ApprovalRequest], ApprovalResult]
@dataclass
class ExecutorConfig:
"""Configuration for FlexibleGraphExecutor."""
max_retries_per_step: int = 3
max_total_steps: int = 100
timeout_seconds: int = 300
@@ -165,7 +167,10 @@ class FlexibleGraphExecutor:
status=ExecutionStatus.NEEDS_REPLAN,
plan=plan,
context=context,
feedback="No executable steps available but plan not complete. Check dependencies.",
feedback=(
"No executable steps available but plan not complete. "
"Check dependencies."
),
steps_executed=steps_executed,
total_tokens=total_tokens,
total_latency=total_latency,
@@ -174,7 +179,8 @@ class FlexibleGraphExecutor:
# Execute next step (for now, sequential; could be parallel)
step = ready_steps[0]
# Debug: show ready steps
# print(f" [DEBUG] Ready steps: {[s.id for s in ready_steps]}, executing: {step.id}")
# ready_ids = [s.id for s in ready_steps]
# print(f" [DEBUG] Ready steps: {ready_ids}, executing: {step.id}")
# APPROVAL CHECK - before execution
if step.requires_approval:
@@ -360,7 +366,10 @@ class FlexibleGraphExecutor:
status=ExecutionStatus.NEEDS_REPLAN,
plan=plan,
context=context,
feedback=f"Step '{step.id}' failed after {step.attempts} attempts: {judgment.feedback}",
feedback=(
f"Step '{step.id}' failed after {step.attempts} attempts: "
f"{judgment.feedback}"
),
steps_executed=steps_executed,
total_tokens=total_tokens,
total_latency=total_latency,
@@ -450,12 +459,17 @@ class FlexibleGraphExecutor:
preview_parts.append(f"Tool: {step.action.tool_name}")
if step.action.tool_args:
import json
args_preview = json.dumps(step.action.tool_args, indent=2, default=str)
if len(args_preview) > 500:
args_preview = args_preview[:500] + "..."
preview_parts.append(f"Args: {args_preview}")
elif step.action.prompt:
prompt_preview = step.action.prompt[:300] + "..." if len(step.action.prompt) > 300 else step.action.prompt
prompt_preview = (
step.action.prompt[:300] + "..."
if len(step.action.prompt) > 300
else step.action.prompt
)
preview_parts.append(f"Prompt: {prompt_preview}")
# Include step inputs resolved from context (what will be sent/used)
+20 -31
View File
@@ -12,20 +12,21 @@ Goals are:
"""
from datetime import datetime
from typing import Any
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class GoalStatus(str, Enum):
"""Lifecycle status of a goal."""
DRAFT = "draft" # Being defined
READY = "ready" # Ready for agent creation
ACTIVE = "active" # Has an agent graph, can execute
COMPLETED = "completed" # Achieved
FAILED = "failed" # Could not be achieved
SUSPENDED = "suspended" # Paused for revision
DRAFT = "draft" # Being defined
READY = "ready" # Ready for agent creation
ACTIVE = "active" # Has an agent graph, can execute
COMPLETED = "completed" # Achieved
FAILED = "failed" # Could not be achieved
SUSPENDED = "suspended" # Paused for revision
class SuccessCriterion(BaseModel):
@@ -37,22 +38,14 @@ class SuccessCriterion(BaseModel):
- Measurable: Can be evaluated programmatically or by LLM
- Achievable: Within the agent's capabilities
"""
id: str
description: str = Field(
description="Human-readable description of what success looks like"
)
description: str = Field(description="Human-readable description of what success looks like")
metric: str = Field(
description="How to measure: 'output_contains', 'output_equals', 'llm_judge', 'custom'"
)
target: Any = Field(
description="The target value or condition"
)
weight: float = Field(
default=1.0,
ge=0.0,
le=1.0,
description="Relative importance (0-1)"
)
target: Any = Field(description="The target value or condition")
weight: float = Field(default=1.0, ge=0.0, le=1.0, description="Relative importance (0-1)")
met: bool = False
model_config = {"extra": "allow"}
@@ -66,18 +59,17 @@ class Constraint(BaseModel):
- Hard: Violation means failure
- Soft: Violation is discouraged but allowed
"""
id: str
description: str
constraint_type: str = Field(
description="Type: 'hard' (must not violate) or 'soft' (prefer not to violate)"
)
category: str = Field(
default="general",
description="Category: 'time', 'cost', 'safety', 'scope', 'quality'"
default="general", description="Category: 'time', 'cost', 'safety', 'scope', 'quality'"
)
check: str = Field(
default="",
description="How to check: expression, function name, or 'llm_judge'"
default="", description="How to check: expression, function name, or 'llm_judge'"
)
model_config = {"extra": "allow"}
@@ -119,6 +111,7 @@ class Goal(BaseModel):
]
)
"""
id: str
name: str
description: str
@@ -133,23 +126,19 @@ class Goal(BaseModel):
# Context for the agent
context: dict[str, Any] = Field(
default_factory=dict,
description="Additional context: domain knowledge, user preferences, etc."
description="Additional context: domain knowledge, user preferences, etc.",
)
# Capabilities required
required_capabilities: list[str] = Field(
default_factory=list,
description="What the agent needs: 'llm', 'web_search', 'code_execution', etc."
description="What the agent needs: 'llm', 'web_search', 'code_execution', etc.",
)
# Input/output schema
input_schema: dict[str, Any] = Field(
default_factory=dict,
description="Expected input format"
)
input_schema: dict[str, Any] = Field(default_factory=dict, description="Expected input format")
output_schema: dict[str, Any] = Field(
default_factory=dict,
description="Expected output format"
default_factory=dict, description="Expected output format"
)
# Versioning for evolution
+12 -7
View File
@@ -12,6 +12,7 @@ from typing import Any
class HITLInputType(str, Enum):
"""Type of input expected from human."""
FREE_TEXT = "free_text" # Open-ended text response
STRUCTURED = "structured" # Specific fields to fill
SELECTION = "selection" # Choose from options
@@ -22,6 +23,7 @@ class HITLInputType(str, Enum):
@dataclass
class HITLQuestion:
"""A single question to ask the human."""
id: str
question: str
input_type: HITLInputType = HITLInputType.FREE_TEXT
@@ -44,6 +46,7 @@ class HITLRequest:
This is what the agent produces when it needs human input.
"""
# Context
objective: str # What we're trying to accomplish
current_state: str # Where we are in the process
@@ -92,6 +95,7 @@ class HITLResponse:
This is what gets passed back when resuming from a pause.
"""
# Original request reference
request_id: str
@@ -170,13 +174,13 @@ class HITLProtocol:
# Use Haiku to extract answers
try:
import anthropic
import json
questions_str = "\n".join([
f"{i+1}. {q.question} (id: {q.id})"
for i, q in enumerate(request.questions)
])
import anthropic
questions_str = "\n".join(
[f"{i + 1}. {q.question} (id: {q.id})" for i, q in enumerate(request.questions)]
)
prompt = f"""Parse the user's response and extract answers for each question.
@@ -195,13 +199,14 @@ Example format:
message = client.messages.create(
model="claude-3-5-haiku-20241022",
max_tokens=500,
messages=[{"role": "user", "content": prompt}]
messages=[{"role": "user", "content": prompt}],
)
# Parse Haiku's response
import re
response_text = message.content[0].text.strip()
json_match = re.search(r'\{[^{}]*\}', response_text, re.DOTALL)
json_match = re.search(r"\{[^{}]*\}", response_text, re.DOTALL)
if json_match:
parsed = json.loads(json_match.group())
+64 -47
View File
@@ -8,23 +8,24 @@ The HybridJudge evaluates step execution results using:
Escalation path: rules LLM human
"""
from typing import Any
from dataclasses import dataclass, field
from typing import Any
from framework.graph.code_sandbox import safe_eval
from framework.graph.goal import Goal
from framework.graph.plan import (
PlanStep,
EvaluationRule,
Judgment,
JudgmentAction,
EvaluationRule,
PlanStep,
)
from framework.graph.goal import Goal
from framework.graph.code_sandbox import safe_eval
from framework.llm.provider import LLMProvider
@dataclass
class RuleEvaluationResult:
"""Result of rule-based evaluation."""
is_definitive: bool # True if a rule matched definitively
judgment: Judgment | None = None
context: dict[str, Any] = field(default_factory=dict)
@@ -136,9 +137,9 @@ class HybridJudge:
# Build evaluation context
eval_context = {
"step": step.model_dump() if hasattr(step, 'model_dump') else step,
"step": step.model_dump() if hasattr(step, "model_dump") else step,
"result": result,
"goal": goal.model_dump() if hasattr(goal, 'model_dump') else goal,
"goal": goal.model_dump() if hasattr(goal, "model_dump") else goal,
"context": context,
"success": isinstance(result, dict) and result.get("success", False),
"error": isinstance(result, dict) and result.get("error"),
@@ -216,7 +217,10 @@ class HybridJudge:
# Low confidence - escalate
return Judgment(
action=JudgmentAction.ESCALATE,
reasoning=f"LLM confidence ({judgment.confidence:.2f}) below threshold ({self.llm_confidence_threshold})",
reasoning=(
f"LLM confidence ({judgment.confidence:.2f}) "
f"below threshold ({self.llm_confidence_threshold})"
),
feedback=judgment.feedback,
confidence=judgment.confidence,
llm_used=True,
@@ -338,52 +342,65 @@ def create_default_judge(llm: LLMProvider | None = None) -> HybridJudge:
judge = HybridJudge(llm=llm)
# Rule: Accept on explicit success flag
judge.add_rule(EvaluationRule(
id="explicit_success",
description="Step explicitly marked as successful",
condition="isinstance(result, dict) and result.get('success') == True",
action=JudgmentAction.ACCEPT,
priority=100,
))
judge.add_rule(
EvaluationRule(
id="explicit_success",
description="Step explicitly marked as successful",
condition="isinstance(result, dict) and result.get('success') == True",
action=JudgmentAction.ACCEPT,
priority=100,
)
)
# Rule: Retry on transient errors
judge.add_rule(EvaluationRule(
id="transient_error_retry",
description="Transient error that can be retried",
condition="isinstance(result, dict) and result.get('error_type') in ['timeout', 'rate_limit', 'connection_error']",
action=JudgmentAction.RETRY,
feedback_template="Transient error: {result[error]}. Please retry.",
priority=90,
))
judge.add_rule(
EvaluationRule(
id="transient_error_retry",
description="Transient error that can be retried",
condition=(
"isinstance(result, dict) and "
"result.get('error_type') in ['timeout', 'rate_limit', 'connection_error']"
),
action=JudgmentAction.RETRY,
feedback_template="Transient error: {result[error]}. Please retry.",
priority=90,
)
)
# Rule: Replan on missing data
judge.add_rule(EvaluationRule(
id="missing_data_replan",
description="Required data not available",
condition="isinstance(result, dict) and result.get('error_type') == 'missing_data'",
action=JudgmentAction.REPLAN,
feedback_template="Missing required data: {result[error]}. Plan needs adjustment.",
priority=80,
))
judge.add_rule(
EvaluationRule(
id="missing_data_replan",
description="Required data not available",
condition="isinstance(result, dict) and result.get('error_type') == 'missing_data'",
action=JudgmentAction.REPLAN,
feedback_template="Missing required data: {result[error]}. Plan needs adjustment.",
priority=80,
)
)
# Rule: Escalate on security issues
judge.add_rule(EvaluationRule(
id="security_escalate",
description="Security issue detected",
condition="isinstance(result, dict) and result.get('error_type') == 'security'",
action=JudgmentAction.ESCALATE,
feedback_template="Security issue detected: {result[error]}",
priority=200,
))
judge.add_rule(
EvaluationRule(
id="security_escalate",
description="Security issue detected",
condition="isinstance(result, dict) and result.get('error_type') == 'security'",
action=JudgmentAction.ESCALATE,
feedback_template="Security issue detected: {result[error]}",
priority=200,
)
)
# Rule: Fail on max retries exceeded
judge.add_rule(EvaluationRule(
id="max_retries_fail",
description="Maximum retries exceeded",
condition="step.get('attempts', 0) >= step.get('max_retries', 3)",
action=JudgmentAction.REPLAN,
feedback_template="Step '{step[id]}' failed after {step[attempts]} attempts",
priority=150,
))
judge.add_rule(
EvaluationRule(
id="max_retries_fail",
description="Maximum retries exceeded",
condition="step.get('attempts', 0) >= step.get('max_retries', 3)",
action=JudgmentAction.REPLAN,
feedback_template="Step '{step[id]}' failed after {step[attempts]} attempts",
priority=150,
)
)
return judge
File diff suppressed because it is too large Load Diff
+87 -57
View File
@@ -16,6 +16,50 @@ from typing import Any
logger = logging.getLogger(__name__)
def _heuristic_repair(text: str) -> dict | None:
"""
Attempt to repair JSON without an LLM call.
Handles common errors:
- Markdown code blocks
- Python booleans/None (True -> true)
- Single quotes instead of double quotes
"""
if not isinstance(text, str):
return None
# 1. Strip Markdown code blocks
text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.MULTILINE)
text = re.sub(r"\s*```$", "", text, flags=re.MULTILINE)
text = text.strip()
# 2. Find outermost JSON-like structure (greedy match)
match = re.search(r"(\{.*\}|\[.*\])", text, re.DOTALL)
if match:
candidate = match.group(1)
# 3. Common fixes
# Fix Python constants
candidate = re.sub(r"\bTrue\b", "true", candidate)
candidate = re.sub(r"\bFalse\b", "false", candidate)
candidate = re.sub(r"\bNone\b", "null", candidate)
# 4. Attempt load
try:
return json.loads(candidate)
except json.JSONDecodeError:
# 5. Advanced: Try swapping single quotes if double quotes fail
# This is risky but effective for simple dicts
try:
if "'" in candidate and '"' not in candidate:
candidate_swapped = candidate.replace("'", '"')
return json.loads(candidate_swapped)
except json.JSONDecodeError:
pass
return None
@dataclass
class CleansingConfig:
"""Configuration for output cleansing."""
@@ -42,30 +86,8 @@ class OutputCleaner:
"""
Framework-level output validation and cleaning.
Uses fast LLM (llama-3.3-70b) to clean malformed outputs
Uses heuristics and fast LLM to clean malformed outputs
before they flow to the next node.
Example:
cleaner = OutputCleaner(
config=CleansingConfig(enabled=True),
llm_provider=llm,
)
# Validate output
validation = cleaner.validate_output(
output=node_output,
source_node_id="analyze",
target_node_spec=next_node_spec,
)
if not validation.valid:
# Clean the output
cleaned = cleaner.clean_output(
output=node_output,
source_node_id="analyze",
target_node_spec=next_node_spec,
validation_errors=validation.errors,
)
"""
def __init__(self, config: CleansingConfig, llm_provider=None):
@@ -74,8 +96,7 @@ class OutputCleaner:
Args:
config: Cleansing configuration
llm_provider: Optional LLM provider. If None and cleaning is enabled,
will create a LiteLLMProvider with the configured fast_model.
llm_provider: Optional LLM provider.
"""
self.config = config
self.success_cache: dict[str, Any] = {} # Cache successful patterns
@@ -88,9 +109,10 @@ class OutputCleaner:
elif config.enabled:
# Create dedicated fast LLM provider for cleaning
try:
from framework.llm.litellm import LiteLLMProvider
import os
from framework.llm.litellm import LiteLLMProvider
api_key = os.environ.get("CEREBRAS_API_KEY")
if api_key:
self.llm = LiteLLMProvider(
@@ -98,13 +120,9 @@ class OutputCleaner:
model=config.fast_model,
temperature=0.0, # Deterministic cleaning
)
logger.info(
f"✓ Initialized OutputCleaner with {config.fast_model}"
)
logger.info(f"✓ Initialized OutputCleaner with {config.fast_model}")
else:
logger.warning(
"⚠ CEREBRAS_API_KEY not found, output cleaning will be disabled"
)
logger.warning("⚠ CEREBRAS_API_KEY not found, output cleaning will be disabled")
self.llm = None
except ImportError:
logger.warning("⚠ LiteLLMProvider not available, output cleaning disabled")
@@ -121,11 +139,6 @@ class OutputCleaner:
"""
Validate output matches target node's expected input schema.
Args:
output: Output from source node
source_node_id: ID of source node
target_node_spec: Spec of target node (for input_keys)
Returns:
ValidationResult with errors and optionally cleaned output
"""
@@ -199,7 +212,7 @@ class OutputCleaner:
validation_errors: list[str],
) -> dict[str, Any]:
"""
Use fast LLM to clean malformed output.
Use heuristics and fast LLM to clean malformed output.
Args:
output: Raw output from source node
@@ -209,14 +222,36 @@ class OutputCleaner:
Returns:
Cleaned output matching target schema
Raises:
Exception: If cleaning fails and fallback_to_raw is False
"""
if not self.config.enabled:
logger.warning("⚠ Output cleansing disabled in config")
return output
# --- PHASE 1: Fast Heuristic Repair (Avoids LLM call) ---
# Often the output is just a string containing JSON, or has minor syntax errors
# If output is a dictionary but malformed, we might need to serialize it first
# to try and fix the underlying string representation if it came from raw text
# Heuristic: Check if any value is actually a JSON string that should be promoted
# This handles the "JSON Parsing Trap" where LLM returns {"key": "{\"nested\": ...}"}
heuristic_fixed = False
fixed_output = output.copy()
for key, value in output.items():
if isinstance(value, str):
repaired = _heuristic_repair(value)
if repaired and isinstance(repaired, dict | list):
# Check if this repaired structure looks like what we want
# e.g. if the key is 'data' and the string contained valid JSON
fixed_output[key] = repaired
heuristic_fixed = True
# If we fixed something, re-validate manually to see if it's enough
if heuristic_fixed:
logger.info("⚡ Heuristic repair applied (nested JSON expansion)")
return fixed_output
# --- PHASE 2: LLM-based Repair ---
if not self.llm:
logger.warning("⚠ No LLM provider available for cleansing")
return output
@@ -253,22 +288,21 @@ Return ONLY valid JSON matching the expected schema. No explanations, no markdow
response = self.llm.complete(
messages=[{"role": "user", "content": prompt}],
system="You clean malformed agent outputs. Return only valid JSON matching the schema.",
system=(
"You clean malformed agent outputs. Return only valid JSON matching the schema."
),
max_tokens=2048, # Sufficient for cleaning most outputs
)
# Parse cleaned output
cleaned_text = response.content.strip()
# Remove markdown if present
if cleaned_text.startswith("```"):
match = re.search(
r"```(?:json)?\s*\n?(.*?)\n?```", cleaned_text, re.DOTALL
)
if match:
cleaned_text = match.group(1).strip()
# Apply heuristic repair to the LLM's output too (just in case)
cleaned = _heuristic_repair(cleaned_text)
cleaned = json.loads(cleaned_text)
if not cleaned:
# Fallback to standard load if heuristic returns None (unlikely for LLM output)
cleaned = json.loads(cleaned_text)
if isinstance(cleaned, dict):
self.cleansing_count += 1
@@ -278,15 +312,11 @@ Return ONLY valid JSON matching the expected schema. No explanations, no markdow
)
return cleaned
else:
logger.warning(
f"⚠ Cleaned output is not a dict: {type(cleaned)}"
)
logger.warning(f"⚠ Cleaned output is not a dict: {type(cleaned)}")
if self.config.fallback_to_raw:
return output
else:
raise ValueError(
f"Cleaning produced {type(cleaned)}, expected dict"
)
raise ValueError(f"Cleaning produced {type(cleaned)}, expected dict")
except json.JSONDecodeError as e:
logger.error(f"✗ Failed to parse cleaned JSON: {e}")
@@ -318,7 +348,7 @@ Return ONLY valid JSON matching the expected schema. No explanations, no markdow
line = f' "{key}": {type_hint}'
if description:
line += f' // {description}'
line += f" // {description}"
if required:
line += " (required)"
lines.append(line + ",")
+98 -31
View File
@@ -10,24 +10,26 @@ The Plan is the contract between the external planner and the executor:
- If replanning needed, returns feedback to external planner
"""
from typing import Any
from enum import Enum
from datetime import datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class ActionType(str, Enum):
"""Types of actions a PlanStep can perform."""
LLM_CALL = "llm_call" # Call LLM for generation
TOOL_USE = "tool_use" # Use a registered tool
SUB_GRAPH = "sub_graph" # Execute a sub-graph
FUNCTION = "function" # Call a Python function
LLM_CALL = "llm_call" # Call LLM for generation
TOOL_USE = "tool_use" # Use a registered tool
SUB_GRAPH = "sub_graph" # Execute a sub-graph
FUNCTION = "function" # Call a Python function
CODE_EXECUTION = "code_execution" # Execute dynamic code (sandboxed)
class StepStatus(str, Enum):
"""Status of a plan step."""
PENDING = "pending"
AWAITING_APPROVAL = "awaiting_approval" # Waiting for human approval
IN_PROGRESS = "in_progress"
@@ -36,17 +38,36 @@ class StepStatus(str, Enum):
SKIPPED = "skipped"
REJECTED = "rejected" # Human rejected execution
def is_terminal(self) -> bool:
"""Check if this status represents a terminal (finished) state.
Terminal states are states where the step will not execute further,
either because it completed successfully or failed/was skipped.
"""
return self in (
StepStatus.COMPLETED,
StepStatus.FAILED,
StepStatus.SKIPPED,
StepStatus.REJECTED,
)
def is_successful(self) -> bool:
"""Check if this status represents successful completion."""
return self == StepStatus.COMPLETED
class ApprovalDecision(str, Enum):
"""Human decision on a step requiring approval."""
APPROVE = "approve" # Execute as planned
REJECT = "reject" # Skip this step
MODIFY = "modify" # Execute with modifications
ABORT = "abort" # Stop entire execution
APPROVE = "approve" # Execute as planned
REJECT = "reject" # Skip this step
MODIFY = "modify" # Execute with modifications
ABORT = "abort" # Stop entire execution
class ApprovalRequest(BaseModel):
"""Request for human approval before executing a step."""
step_id: str
step_description: str
action_type: str
@@ -62,6 +83,7 @@ class ApprovalRequest(BaseModel):
class ApprovalResult(BaseModel):
"""Result of human approval decision."""
decision: ApprovalDecision
reason: str | None = None
modifications: dict[str, Any] = Field(default_factory=dict)
@@ -71,10 +93,11 @@ class ApprovalResult(BaseModel):
class JudgmentAction(str, Enum):
"""Actions the judge can take after evaluating a step."""
ACCEPT = "accept" # Step completed successfully, continue
RETRY = "retry" # Retry the step with feedback
REPLAN = "replan" # Return to external planner for new plan
ESCALATE = "escalate" # Request human intervention
ACCEPT = "accept" # Step completed successfully, continue
RETRY = "retry" # Retry the step with feedback
REPLAN = "replan" # Return to external planner for new plan
ESCALATE = "escalate" # Request human intervention
class ActionSpec(BaseModel):
@@ -83,6 +106,7 @@ class ActionSpec(BaseModel):
This is the "what to do" part of a PlanStep.
"""
action_type: ActionType
# For LLM_CALL
@@ -114,6 +138,7 @@ class PlanStep(BaseModel):
Created by external planner, executed by Worker, evaluated by Judge.
"""
id: str
description: str
action: ActionSpec
@@ -121,27 +146,23 @@ class PlanStep(BaseModel):
# Data flow
inputs: dict[str, Any] = Field(
default_factory=dict,
description="Input data for this step (can reference previous step outputs)"
description="Input data for this step (can reference previous step outputs)",
)
expected_outputs: list[str] = Field(
default_factory=list,
description="Keys this step should produce"
default_factory=list, description="Keys this step should produce"
)
# Dependencies
dependencies: list[str] = Field(
default_factory=list,
description="IDs of steps that must complete before this one"
default_factory=list, description="IDs of steps that must complete before this one"
)
# Human-in-the-loop (HITL)
requires_approval: bool = Field(
default=False,
description="If True, requires human approval before execution"
default=False, description="If True, requires human approval before execution"
)
approval_message: str | None = Field(
default=None,
description="Message to show human when requesting approval"
default=None, description="Message to show human when requesting approval"
)
# Execution state
@@ -157,11 +178,23 @@ class PlanStep(BaseModel):
model_config = {"extra": "allow"}
def is_ready(self, completed_step_ids: set[str]) -> bool:
"""Check if this step is ready to execute (all dependencies met)."""
def is_ready(self, terminal_step_ids: set[str]) -> bool:
"""Check if this step is ready to execute (all dependencies finished).
A step is ready when:
1. Its status is PENDING (not yet started)
2. All its dependencies are in a terminal state (completed, failed, skipped, or rejected)
Note: This allows dependent steps to become "ready" even if their dependencies
failed. The executor should check if any dependencies failed and handle
accordingly (e.g., skip the step or mark it as blocked).
Args:
terminal_step_ids: Set of step IDs that are in a terminal state
"""
if self.status != StepStatus.PENDING:
return False
return all(dep in completed_step_ids for dep in self.dependencies)
return all(dep in terminal_step_ids for dep in self.dependencies)
class Judgment(BaseModel):
@@ -170,6 +203,7 @@ class Judgment(BaseModel):
The Judge evaluates step results and decides what to do next.
"""
action: JudgmentAction
reasoning: str
feedback: str | None = None # For retry/replan - what went wrong
@@ -193,6 +227,7 @@ class EvaluationRule(BaseModel):
Rules are checked before falling back to LLM evaluation.
"""
id: str
description: str
@@ -216,6 +251,7 @@ class Plan(BaseModel):
Created by external planner (Claude Code, etc).
Executed by FlexibleGraphExecutor.
"""
id: str
goal_id: str
description: str
@@ -320,18 +356,46 @@ class Plan(BaseModel):
return None
def get_ready_steps(self) -> list[PlanStep]:
"""Get all steps that are ready to execute."""
completed_ids = {s.id for s in self.steps if s.status == StepStatus.COMPLETED}
return [s for s in self.steps if s.is_ready(completed_ids)]
"""Get all steps that are ready to execute.
A step is ready when all its dependencies are in terminal states
(completed, failed, skipped, or rejected).
"""
terminal_ids = {s.id for s in self.steps if s.status.is_terminal()}
return [s for s in self.steps if s.is_ready(terminal_ids)]
def get_completed_steps(self) -> list[PlanStep]:
"""Get all completed steps."""
return [s for s in self.steps if s.status == StepStatus.COMPLETED]
def is_complete(self) -> bool:
"""Check if all steps are completed."""
"""Check if all steps are in terminal states (finished executing).
Returns True when all steps have reached a terminal state, regardless
of whether they succeeded or failed. Use has_failed_steps() to check
if any steps failed.
"""
return all(s.status.is_terminal() for s in self.steps)
def is_successful(self) -> bool:
"""Check if all steps completed successfully."""
return all(s.status == StepStatus.COMPLETED for s in self.steps)
def has_failed_steps(self) -> bool:
"""Check if any steps failed, were skipped, or were rejected."""
return any(
s.status in (StepStatus.FAILED, StepStatus.SKIPPED, StepStatus.REJECTED)
for s in self.steps
)
def get_failed_steps(self) -> list[PlanStep]:
"""Get all steps that failed, were skipped, or were rejected."""
return [
s
for s in self.steps
if s.status in (StepStatus.FAILED, StepStatus.SKIPPED, StepStatus.REJECTED)
]
def to_feedback_context(self) -> dict[str, Any]:
"""Create context for replanning."""
return {
@@ -361,12 +425,13 @@ class Plan(BaseModel):
class ExecutionStatus(str, Enum):
"""Status of plan execution."""
COMPLETED = "completed"
AWAITING_APPROVAL = "awaiting_approval" # Paused for human approval
NEEDS_REPLAN = "needs_replan"
NEEDS_ESCALATION = "needs_escalation"
REJECTED = "rejected" # Human rejected a step
ABORTED = "aborted" # Human aborted execution
ABORTED = "aborted" # Human aborted execution
FAILED = "failed"
@@ -376,6 +441,7 @@ class PlanExecutionResult(BaseModel):
Returned to external planner with status and feedback.
"""
status: ExecutionStatus
# Results from completed steps
@@ -421,6 +487,7 @@ def load_export(data: str | dict) -> tuple["Plan", Any]:
result = await executor.execute_plan(plan, goal, context)
"""
import json as json_module
from framework.graph.goal import Goal
if isinstance(data, str):
+262
View File
@@ -0,0 +1,262 @@
import ast
import operator
from typing import Any
# Safe operators whitelist
SAFE_OPERATORS = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.FloorDiv: operator.floordiv,
ast.Mod: operator.mod,
ast.Pow: operator.pow,
ast.LShift: operator.lshift,
ast.RShift: operator.rshift,
ast.BitOr: operator.or_,
ast.BitXor: operator.xor,
ast.BitAnd: operator.and_,
ast.Eq: operator.eq,
ast.NotEq: operator.ne,
ast.Lt: operator.lt,
ast.LtE: operator.le,
ast.Gt: operator.gt,
ast.GtE: operator.ge,
ast.Is: operator.is_,
ast.IsNot: operator.is_not,
ast.In: lambda x, y: x in y,
ast.NotIn: lambda x, y: x not in y,
ast.USub: operator.neg,
ast.UAdd: operator.pos,
ast.Not: operator.not_,
ast.Invert: operator.inv,
}
# Safe functions whitelist
SAFE_FUNCTIONS = {
"len": len,
"int": int,
"float": float,
"str": str,
"bool": bool,
"list": list,
"dict": dict,
"tuple": tuple,
"set": set,
"min": min,
"max": max,
"sum": sum,
"abs": abs,
"round": round,
"all": all,
"any": any,
}
class SafeEvalVisitor(ast.NodeVisitor):
def __init__(self, context: dict[str, Any]):
self.context = context
def visit(self, node: ast.AST) -> Any:
# Override visit to prevent default behavior and ensure only explicitly allowed nodes work
method = "visit_" + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
return visitor(node)
def generic_visit(self, node: ast.AST):
raise ValueError(f"Use of {node.__class__.__name__} is not allowed")
def visit_Expression(self, node: ast.Expression) -> Any:
return self.visit(node.body)
def visit_Expr(self, node: ast.Expr) -> Any:
return self.visit(node.value)
def visit_Constant(self, node: ast.Constant) -> Any:
return node.value
# --- Number/String/Bytes/NameConstant (Python < 3.8 compat if needed) ---
def visit_Num(self, node: ast.Num) -> Any:
return node.n
def visit_Str(self, node: ast.Str) -> Any:
return node.s
def visit_NameConstant(self, node: ast.NameConstant) -> Any:
return node.value
# --- Data Structures ---
def visit_List(self, node: ast.List) -> list:
return [self.visit(elt) for elt in node.elts]
def visit_Tuple(self, node: ast.Tuple) -> tuple:
return tuple(self.visit(elt) for elt in node.elts)
def visit_Dict(self, node: ast.Dict) -> dict:
return {
self.visit(k): self.visit(v)
for k, v in zip(node.keys, node.values, strict=False)
if k is not None
}
# --- Operations ---
def visit_BinOp(self, node: ast.BinOp) -> Any:
op_func = SAFE_OPERATORS.get(type(node.op))
if op_func is None:
raise ValueError(f"Operator {type(node.op).__name__} is not allowed")
return op_func(self.visit(node.left), self.visit(node.right))
def visit_UnaryOp(self, node: ast.UnaryOp) -> Any:
op_func = SAFE_OPERATORS.get(type(node.op))
if op_func is None:
raise ValueError(f"Operator {type(node.op).__name__} is not allowed")
return op_func(self.visit(node.operand))
def visit_Compare(self, node: ast.Compare) -> Any:
left = self.visit(node.left)
for op, comparator in zip(node.ops, node.comparators, strict=False):
op_func = SAFE_OPERATORS.get(type(op))
if op_func is None:
raise ValueError(f"Operator {type(op).__name__} is not allowed")
right = self.visit(comparator)
if not op_func(left, right):
return False
left = right # Chain comparisons
return True
def visit_BoolOp(self, node: ast.BoolOp) -> Any:
values = [self.visit(v) for v in node.values]
if isinstance(node.op, ast.And):
return all(values)
elif isinstance(node.op, ast.Or):
return any(values)
raise ValueError(f"Boolean operator {type(node.op).__name__} is not allowed")
def visit_IfExp(self, node: ast.IfExp) -> Any:
# Ternary: true_val if test else false_val
if self.visit(node.test):
return self.visit(node.body)
else:
return self.visit(node.orelse)
# --- Variables and Attributes ---
def visit_Name(self, node: ast.Name) -> Any:
if isinstance(node.ctx, ast.Load):
if node.id in self.context:
return self.context[node.id]
raise NameError(f"Name '{node.id}' is not defined")
raise ValueError("Only reading variables is allowed")
def visit_Subscript(self, node: ast.Subscript) -> Any:
# value[slice]
val = self.visit(node.value)
idx = self.visit(node.slice)
return val[idx]
def visit_Attribute(self, node: ast.Attribute) -> Any:
# value.attr
# STIRCT CHECK: No access to private attributes (starting with _)
if node.attr.startswith("_"):
raise ValueError(f"Access to private attribute '{node.attr}' is not allowed")
val = self.visit(node.value)
# Safe attribute access: only allow if it's in the dict (if val is dict)
# or it's a safe property of a basic type?
# Actually, for flexibility, people often use dot access for dicts in these expressions.
# But standard Python dict doesn't support dot access.
# If val is a dict, Attribute access usually fails in Python unless wrapped.
# If the user context provides objects, we might want to allow attribute access.
# BUT we must be careful not to allow access to dangerous things like __class__ etc.
# The check starts_with("_") covers __class__, __init__, etc.
try:
return getattr(val, node.attr)
except AttributeError:
# Fallback: maybe it's a dict and they want dot access?
# (Only if we want to support that sugar, usually not standard python)
# Let's stick to standard python behavior + strict private check.
pass
raise AttributeError(f"Object has no attribute '{node.attr}'")
def visit_Call(self, node: ast.Call) -> Any:
# Only allow calling whitelisted functions
func = self.visit(node.func)
# Check if the function object itself is in our whitelist values
# This is tricky because `func` is the actual function object,
# but we also want to verify it came from a safe place.
# Easier: Check if node.func is a Name and that name is in SAFE_FUNCTIONS.
is_safe = False
if isinstance(node.func, ast.Name):
if node.func.id in SAFE_FUNCTIONS:
is_safe = True
# Also allow methods on objects if they are safe?
# E.g. "somestring".lower() or list.append() (if we allowed mutation, but we don't for now)
# For now, restrict to SAFE_FUNCTIONS whitelist for global calls and deny method calls
# unless we explicitly add safe methods.
# Allowing method calls on strings/lists (split, join, get) is commonly needed.
if isinstance(node.func, ast.Attribute):
# Method call.
# Allow basic safe methods?
# For security, start strict. Only helper functions.
# Re-visiting: User might want 'output.get("key")'.
method_name = node.func.attr
if method_name in [
"get",
"keys",
"values",
"items",
"lower",
"upper",
"strip",
"split",
]:
is_safe = True
if not is_safe and func not in SAFE_FUNCTIONS.values():
raise ValueError("Call to function/method is not allowed")
args = [self.visit(arg) for arg in node.args]
keywords = {kw.arg: self.visit(kw.value) for kw in node.keywords}
return func(*args, **keywords)
def visit_Index(self, node: ast.Index) -> Any:
# Python < 3.9
return self.visit(node.value)
def safe_eval(expr: str, context: dict[str, Any] | None = None) -> Any:
    """
    Safely evaluate a python expression string.

    Args:
        expr: The expression string to evaluate.
        context: Dictionary of variables available in the expression.

    Returns:
        The result of the evaluation.

    Raises:
        ValueError: If unsafe operations or syntax are detected.
        SyntaxError: If the expression is invalid Python.
    """
    # Caller-supplied names first, safe builtins layered on top: on a name
    # clash the safe builtin wins, so the whitelist cannot be shadowed.
    full_context: dict[str, Any] = {**(context or {}), **SAFE_FUNCTIONS}

    try:
        tree = ast.parse(expr, mode="eval")
    except SyntaxError as e:
        raise SyntaxError(f"Invalid syntax in expression: {e}") from e

    return SafeEvalVisitor(full_context).visit(tree)
@@ -6,8 +6,9 @@ Demonstrates how OutputCleaner fixes the JSON parsing trap using llama-3.3-70b.
import json
import os
from framework.graph.output_cleaner import OutputCleaner, CleansingConfig
from framework.graph.node import NodeSpec
from framework.graph.output_cleaner import CleansingConfig, OutputCleaner
from framework.llm.litellm import LiteLLMProvider
@@ -42,7 +43,10 @@ def test_cleaning_with_cerebras():
# Scenario 1: JSON parsing trap (entire response in one key)
print("\n--- Scenario 1: JSON Parsing Trap ---")
malformed_output = {
"recommendation": '{\n "approval_decision": "APPROVED",\n "risk_score": 3.5,\n "reason": "Standard terms, low risk"\n}',
"recommendation": (
'{\n "approval_decision": "APPROVED",\n "risk_score": 3.5,\n '
'"reason": "Standard terms, low risk"\n}'
),
}
target_spec = NodeSpec(
@@ -84,14 +88,17 @@ def test_cleaning_with_cerebras():
print(json.dumps(cleaned, indent=2))
assert isinstance(cleaned, dict), "Should return dict"
assert "approval_decision" in str(cleaned) or isinstance(
cleaned.get("recommendation"), dict
), "Should have recommendation structure"
assert "approval_decision" in str(cleaned) or isinstance(cleaned.get("recommendation"), dict), (
"Should have recommendation structure"
)
# Scenario 2: Multiple keys with JSON string
print("\n\n--- Scenario 2: Multiple Keys, JSON String ---")
malformed_output2 = {
"analysis": '{"high_risk_clauses": ["unlimited liability"], "compliance_issues": [], "category": "high-risk"}',
"analysis": (
'{"high_risk_clauses": ["unlimited liability"], '
'"compliance_issues": [], "category": "high-risk"}'
),
"risk_score": "7.5", # String instead of number
}
@@ -131,9 +138,7 @@ def test_cleaning_with_cerebras():
assert isinstance(cleaned2, dict), "Should return dict"
assert isinstance(cleaned2.get("analysis"), dict), "analysis should be dict"
assert isinstance(
cleaned2.get("risk_score"), (int, float)
), "risk_score should be number"
assert isinstance(cleaned2.get("risk_score"), (int, float)), "risk_score should be number"
# Stats
stats = cleaner.get_stats()
+144 -13
View File
@@ -8,12 +8,15 @@ import logging
from dataclasses import dataclass
from typing import Any
from pydantic import BaseModel, ValidationError
logger = logging.getLogger(__name__)
@dataclass
class ValidationResult:
"""Result of validating an output."""
success: bool
errors: list[str]
@@ -30,11 +33,76 @@ class OutputValidator:
Used by the executor to catch bad outputs before they pollute memory.
"""
def _contains_code_indicators(self, value: str) -> bool:
"""
Check for code patterns in a string using sampling for efficiency.
For strings under 10KB, checks the entire content.
For longer strings, samples at strategic positions to balance
performance with detection accuracy.
Args:
value: The string to check for code indicators
Returns:
True if code indicators are found, False otherwise
"""
code_indicators = [
# Python
"def ",
"class ",
"import ",
"from ",
"if __name__",
"async def ",
"await ",
"try:",
"except:",
# JavaScript/TypeScript
"function ",
"const ",
"let ",
"=> {",
"require(",
"export ",
# SQL
"SELECT ",
"INSERT ",
"UPDATE ",
"DELETE ",
"DROP ",
# HTML/Script injection
"<script",
"<?php",
"<%",
]
# For strings under 10KB, check the entire content
if len(value) < 10000:
return any(indicator in value for indicator in code_indicators)
# For longer strings, sample at strategic positions
sample_positions = [
0, # Start
len(value) // 4, # 25%
len(value) // 2, # 50%
3 * len(value) // 4, # 75%
max(0, len(value) - 2000), # Near end
]
for pos in sample_positions:
chunk = value[pos : pos + 2000]
if any(indicator in chunk for indicator in code_indicators):
return True
return False
def validate_output_keys(
self,
output: dict[str, Any],
expected_keys: list[str],
allow_empty: bool = False,
nullable_keys: list[str] | None = None,
) -> ValidationResult:
"""
Validate that all expected keys are present and non-empty.
@@ -43,16 +111,17 @@ class OutputValidator:
output: The output dict to validate
expected_keys: Keys that must be present
allow_empty: If True, allow empty string values
nullable_keys: Keys that are allowed to be None
Returns:
ValidationResult with success status and any errors
"""
errors = []
nullable_keys = nullable_keys or []
if not isinstance(output, dict):
return ValidationResult(
success=False,
errors=[f"Output is not a dict, got {type(output).__name__}"]
success=False, errors=[f"Output is not a dict, got {type(output).__name__}"]
)
for key in expected_keys:
@@ -61,12 +130,78 @@ class OutputValidator:
elif not allow_empty:
value = output[key]
if value is None:
errors.append(f"Output key '{key}' is None")
if key not in nullable_keys:
errors.append(f"Output key '{key}' is None")
elif isinstance(value, str) and len(value.strip()) == 0:
errors.append(f"Output key '{key}' is empty string")
return ValidationResult(success=len(errors) == 0, errors=errors)
def validate_with_pydantic(
self,
output: dict[str, Any],
model: type[BaseModel],
) -> tuple[ValidationResult, BaseModel | None]:
"""
Validate output against a Pydantic model.
Args:
output: The output dict to validate
model: Pydantic model class to validate against
Returns:
Tuple of (ValidationResult, validated_model_instance or None)
"""
try:
validated = model.model_validate(output)
return ValidationResult(success=True, errors=[]), validated
except ValidationError as e:
errors = []
for error in e.errors():
field_path = ".".join(str(loc) for loc in error["loc"])
msg = error["msg"]
error_type = error["type"]
errors.append(f"{field_path}: {msg} (type: {error_type})")
return ValidationResult(success=False, errors=errors), None
def format_validation_feedback(
self,
validation_result: ValidationResult,
model: type[BaseModel],
) -> str:
"""
Format validation errors as feedback for LLM retry.
Args:
validation_result: The failed validation result
model: The Pydantic model that was used for validation
Returns:
Formatted feedback string to include in retry prompt
"""
# Get the model's JSON schema for reference
schema = model.model_json_schema()
feedback = "Your previous response had validation errors:\n\n"
feedback += "ERRORS:\n"
for error in validation_result.errors:
feedback += f" - {error}\n"
feedback += "\nEXPECTED SCHEMA:\n"
feedback += f" Model: {model.__name__}\n"
if "properties" in schema:
feedback += " Required fields:\n"
required = schema.get("required", [])
for prop_name, prop_info in schema["properties"].items():
req_marker = " (required)" if prop_name in required else ""
prop_type = prop_info.get("type", "any")
feedback += f" - {prop_name}: {prop_type}{req_marker}\n"
feedback += "\nPlease fix the errors and respond with valid JSON matching the schema."
return feedback
def validate_no_hallucination(
self,
output: dict[str, Any],
@@ -93,16 +228,10 @@ class OutputValidator:
if not isinstance(value, str):
continue
# Check for Python-like code
code_indicators = [
"def ", "class ", "import ", "from ", "if __name__",
"async def ", "await ", "try:", "except:"
]
if any(indicator in value[:500] for indicator in code_indicators):
# Check for code patterns in the entire string, not just first 500 chars
if self._contains_code_indicators(value):
# Could be legitimate, but warn
logger.warning(
f"Output key '{key}' may contain code - verify this is expected"
)
logger.warning(f"Output key '{key}' may contain code - verify this is expected")
# Check for overly long values
if len(value) > max_length:
@@ -148,6 +277,7 @@ class OutputValidator:
expected_keys: list[str] | None = None,
schema: dict[str, Any] | None = None,
check_hallucination: bool = True,
nullable_keys: list[str] | None = None,
) -> ValidationResult:
"""
Run all applicable validations on output.
@@ -157,6 +287,7 @@ class OutputValidator:
expected_keys: Optional list of required keys
schema: Optional JSON schema
check_hallucination: Whether to check for hallucination patterns
nullable_keys: Keys that are allowed to be None
Returns:
Combined ValidationResult
@@ -165,7 +296,7 @@ class OutputValidator:
# Validate keys if provided
if expected_keys:
result = self.validate_output_keys(output, expected_keys)
result = self.validate_output_keys(output, expected_keys, nullable_keys=nullable_keys)
all_errors.extend(result.errors)
# Validate schema if provided
+38 -19
View File
@@ -10,20 +10,24 @@ appropriate executor based on action type:
- Code execution (sandboxed)
"""
from typing import Any, Callable
from dataclasses import dataclass, field
import time
import json
import logging
import re
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
from framework.graph.code_sandbox import CodeSandbox
from framework.graph.plan import (
PlanStep,
ActionSpec,
ActionType,
PlanStep,
)
from framework.graph.code_sandbox import CodeSandbox
from framework.runtime.core import Runtime
from framework.llm.provider import LLMProvider, Tool
from framework.runtime.core import Runtime
logger = logging.getLogger(__name__)
def parse_llm_json_response(text: str) -> tuple[Any | None, str]:
@@ -50,7 +54,7 @@ def parse_llm_json_response(text: str) -> tuple[Any | None, str]:
# Try to extract JSON from markdown code blocks
# Pattern: ```json ... ``` or ``` ... ```
code_block_pattern = r'```(?:json)?\s*([\s\S]*?)\s*```'
code_block_pattern = r"```(?:json)?\s*([\s\S]*?)\s*```"
matches = re.findall(code_block_pattern, cleaned)
if matches:
@@ -59,34 +63,46 @@ def parse_llm_json_response(text: str) -> tuple[Any | None, str]:
try:
parsed = json.loads(match.strip())
return parsed, match.strip()
except json.JSONDecodeError:
except json.JSONDecodeError as e:
logger.debug(
f"Failed to parse JSON from code block: {e}. "
f"Content preview: {match.strip()[:100]}..."
)
continue
# No code blocks or parsing failed - try parsing the whole response
try:
parsed = json.loads(cleaned)
return parsed, cleaned
except json.JSONDecodeError:
pass
except json.JSONDecodeError as e:
logger.debug(
f"Failed to parse entire response as JSON: {e}. Content preview: {cleaned[:100]}..."
)
# Try to find JSON-like content (starts with { or [)
json_start_pattern = r'(\{[\s\S]*\}|\[[\s\S]*\])'
json_start_pattern = r"(\{[\s\S]*\}|\[[\s\S]*\])"
json_matches = re.findall(json_start_pattern, cleaned)
for match in json_matches:
try:
parsed = json.loads(match)
return parsed, match
except json.JSONDecodeError:
except json.JSONDecodeError as e:
logger.debug(f"Failed to parse JSON pattern: {e}. Content preview: {match[:100]}...")
continue
# Could not parse as JSON
# Could not parse as JSON - log warning
logger.warning(
f"Could not parse LLM response as JSON after trying all strategies. "
f"Response preview: {cleaned[:200]}..."
)
return None, cleaned
@dataclass
class StepExecutionResult:
"""Result of executing a plan step."""
success: bool
outputs: dict[str, Any] = field(default_factory=dict)
error: str | None = None
@@ -160,11 +176,13 @@ class WorkerNode:
# Record decision
decision_id = self.runtime.decide(
intent=f"Execute plan step: {step.description}",
options=[{
"id": step.action.action_type.value,
"description": f"Execute {step.action.action_type.value} action",
"action_type": step.action.action_type.value,
}],
options=[
{
"id": step.action.action_type.value,
"description": f"Execute {step.action.action_type.value} action",
"action_type": step.action.action_type.value,
}
],
chosen=step.action.action_type.value,
reasoning=f"Step requires {step.action.action_type.value}",
context={"step_id": step.id, "inputs": step.inputs},
@@ -288,7 +306,7 @@ class WorkerNode:
if inputs:
context_section = "\n\n--- Context Data ---\n"
for key, value in inputs.items():
if isinstance(value, (dict, list)):
if isinstance(value, dict | list):
context_section += f"{key}: {json.dumps(value, indent=2)}\n"
else:
context_section += f"{key}: {value}\n"
@@ -414,6 +432,7 @@ class WorkerNode:
try:
# Execute tool via formal executor
from framework.llm.provider import ToolUse
tool_use = ToolUse(
id=f"step_{tool_name}",
name=tool_name,
+22 -3
View File
@@ -1,7 +1,26 @@
"""LLM provider abstraction."""
from framework.llm.provider import LLMProvider, LLMResponse
from framework.llm.anthropic import AnthropicProvider
from framework.llm.litellm import LiteLLMProvider
__all__ = ["LLMProvider", "LLMResponse", "AnthropicProvider", "LiteLLMProvider"]
__all__ = ["LLMProvider", "LLMResponse"]
try:
from framework.llm.anthropic import AnthropicProvider # noqa: F401
__all__.append("AnthropicProvider")
except ImportError:
pass
try:
from framework.llm.litellm import LiteLLMProvider # noqa: F401
__all__.append("LiteLLMProvider")
except ImportError:
pass
try:
from framework.llm.mock import MockLLMProvider # noqa: F401
__all__.append("MockLLMProvider")
except ImportError:
pass
+4 -3
View File
@@ -1,10 +1,11 @@
"""Anthropic Claude LLM provider - backward compatible wrapper around LiteLLM."""
import os
from collections.abc import Callable
from typing import Any
from framework.llm.provider import LLMProvider, LLMResponse, Tool
from framework.llm.litellm import LiteLLMProvider
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
def _get_api_key_from_credential_manager() -> str | None:
@@ -55,7 +56,7 @@ class AnthropicProvider(LLMProvider):
)
self.model = model
self._provider = LiteLLMProvider(
model=model,
api_key=self.api_key,
@@ -85,7 +86,7 @@ class AnthropicProvider(LLMProvider):
messages: list[dict[str, Any]],
system: str,
tools: list[Tool],
tool_executor: callable,
tool_executor: Callable[[ToolUse], ToolResult],
max_iterations: int = 10,
) -> LLMResponse:
"""Run a tool-use loop until Claude produces a final response (via LiteLLM)."""
+45 -29
View File
@@ -8,11 +8,15 @@ See: https://docs.litellm.ai/docs/providers
"""
import json
from collections.abc import Callable
from typing import Any
import litellm
try:
import litellm
except ImportError:
litellm = None # type: ignore[assignment]
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolUse
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
class LiteLLMProvider(LLMProvider):
@@ -23,6 +27,7 @@ class LiteLLMProvider(LLMProvider):
- OpenAI: gpt-4o, gpt-4o-mini, gpt-4-turbo, gpt-3.5-turbo
- Anthropic: claude-3-opus, claude-3-sonnet, claude-3-haiku
- Google: gemini-pro, gemini-1.5-pro, gemini-1.5-flash
- DeepSeek: deepseek-chat, deepseek-coder, deepseek-reasoner
- Mistral: mistral-large, mistral-medium, mistral-small
- Groq: llama3-70b, mixtral-8x7b
- Local: ollama/llama3, ollama/mistral
@@ -38,6 +43,9 @@ class LiteLLMProvider(LLMProvider):
# Google Gemini
provider = LiteLLMProvider(model="gemini/gemini-1.5-flash")
# DeepSeek
provider = LiteLLMProvider(model="deepseek/deepseek-chat")
# Local Ollama
provider = LiteLLMProvider(model="ollama/llama3")
@@ -72,6 +80,11 @@ class LiteLLMProvider(LLMProvider):
self.api_base = api_base
self.extra_kwargs = kwargs
if litellm is None:
raise ImportError(
"LiteLLM is not installed. Please install it with: pip install litellm"
)
def complete(
self,
messages: list[dict[str, Any]],
@@ -90,9 +103,7 @@ class LiteLLMProvider(LLMProvider):
# Add JSON mode via prompt engineering (works across all providers)
if json_mode:
json_instruction = (
"\n\nPlease respond with a valid JSON object."
)
json_instruction = "\n\nPlease respond with a valid JSON object."
# Append to system message if present, otherwise add as system message
if full_messages and full_messages[0]["role"] == "system":
full_messages[0]["content"] += json_instruction
@@ -122,7 +133,7 @@ class LiteLLMProvider(LLMProvider):
kwargs["response_format"] = response_format
# Make the call
response = litellm.completion(**kwargs)
response = litellm.completion(**kwargs) # type: ignore[union-attr]
# Extract content
content = response.choices[0].message.content or ""
@@ -146,8 +157,9 @@ class LiteLLMProvider(LLMProvider):
messages: list[dict[str, Any]],
system: str,
tools: list[Tool],
tool_executor: callable,
tool_executor: Callable[[ToolUse], ToolResult],
max_iterations: int = 10,
max_tokens: int = 4096,
) -> LLMResponse:
"""Run a tool-use loop until the LLM produces a final response."""
# Prepare messages with system prompt
@@ -167,7 +179,7 @@ class LiteLLMProvider(LLMProvider):
kwargs: dict[str, Any] = {
"model": self.model,
"messages": current_messages,
"max_tokens": 1024,
"max_tokens": max_tokens,
"tools": openai_tools,
**self.extra_kwargs,
}
@@ -177,7 +189,7 @@ class LiteLLMProvider(LLMProvider):
if self.api_base:
kwargs["api_base"] = self.api_base
response = litellm.completion(**kwargs)
response = litellm.completion(**kwargs) # type: ignore[union-attr]
# Track tokens
usage = response.usage
@@ -201,21 +213,23 @@ class LiteLLMProvider(LLMProvider):
# Process tool calls.
# Add assistant message with tool calls.
current_messages.append({
"role": "assistant",
"content": message.content,
"tool_calls": [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in message.tool_calls
],
})
current_messages.append(
{
"role": "assistant",
"content": message.content,
"tool_calls": [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in message.tool_calls
],
}
)
# Execute tools and add results.
for tool_call in message.tool_calls:
@@ -234,11 +248,13 @@ class LiteLLMProvider(LLMProvider):
result = tool_executor(tool_use)
# Add tool result message
current_messages.append({
"role": "tool",
"tool_call_id": result.tool_use_id,
"content": result.content,
})
current_messages.append(
{
"role": "tool",
"tool_call_id": result.tool_use_id,
"content": result.content,
}
)
# Max iterations reached
return LLMResponse(
+177
View File
@@ -0,0 +1,177 @@
"""Mock LLM Provider for testing and structural validation without real LLM calls."""
import json
import re
from collections.abc import Callable
from typing import Any
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
class MockLLMProvider(LLMProvider):
    """
    Mock LLM provider for testing agents without making real API calls.

    This provider generates placeholder responses based on the expected output
    structure, allowing structural validation and graph execution testing
    without incurring costs or requiring API keys.

    Example:
        llm = MockLLMProvider()
        response = llm.complete(
            messages=[{"role": "user", "content": "test"}],
            system="Generate JSON with keys: name, age",
            json_mode=True,
        )
        # response.content == '{"name": "mock_name_value", "age": "mock_age_value"}'
        # (pretty-printed with indent=2)
    """

    def __init__(self, model: str = "mock-model"):
        """
        Initialize the mock LLM provider.

        Args:
            model: Model name to report in responses (default: "mock-model")
        """
        self.model = model

    def _extract_output_keys(self, system: str) -> list[str]:
        """
        Extract expected output keys from the system prompt.

        Tries, in order:
        1. "output_keys: [key1, key2]"
        2. "keys: key1, key2" / "Generate JSON with keys: key1, key2"
        3. JSON-schema-like snippets in the prompt, e.g. {"name": ...}

        Args:
            system: System prompt text

        Returns:
            List of extracted key names; empty list if no pattern matches.
            Order is deterministic (first occurrence wins).
        """
        # Pattern 1: output_keys: [key1, key2]
        match = re.search(r"output_keys:\s*\[(.*?)\]", system, re.IGNORECASE)
        if match:
            # Filter out empty entries so "output_keys: []" yields no keys
            # rather than a single empty-string key.
            return [
                stripped
                for part in match.group(1).split(",")
                if (stripped := part.strip().strip("\"'"))
            ]

        # Pattern 2: "keys: key1, key2" or "Generate JSON with keys: key1, key2"
        match = re.search(
            r"(?:keys|with keys):\s*([a-zA-Z0-9_,\s]+)", system, re.IGNORECASE
        )
        if match:
            return [k.strip() for k in match.group(1).split(",") if k.strip()]

        # Pattern 3: collect keys from any JSON-like structure in the prompt.
        # Deduplicate while preserving first-seen order; using set() here would
        # make the mock output nondeterministic across runs, which defeats the
        # purpose of a deterministic test double.
        all_matches = re.findall(r'"([a-zA-Z0-9_]+)":\s*', system)
        if all_matches:
            return list(dict.fromkeys(all_matches))

        return []

    def _generate_mock_response(
        self,
        system: str = "",
        json_mode: bool = False,
    ) -> str:
        """
        Generate a mock response based on the system prompt and mode.

        Args:
            system: System prompt (may contain output key hints)
            json_mode: If True, generate a JSON response

        Returns:
            Mock response string
        """
        if not json_mode:
            # Plain text mock response
            return "This is a mock response for testing purposes."

        # Try to extract expected keys from system prompt
        keys = self._extract_output_keys(system)
        if keys:
            # Generate JSON with the expected keys
            mock_data = {key: f"mock_{key}_value" for key in keys}
            return json.dumps(mock_data, indent=2)
        # Fallback: generic mock response
        return json.dumps({"result": "mock_result_value"}, indent=2)

    def complete(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 1024,
        response_format: dict[str, Any] | None = None,
        json_mode: bool = False,
    ) -> LLMResponse:
        """
        Generate a mock completion without calling a real LLM.

        Args:
            messages: Conversation history (ignored in mock mode)
            system: System prompt (used to extract expected output keys)
            tools: Available tools (ignored in mock mode)
            max_tokens: Maximum tokens (ignored in mock mode)
            response_format: Response format (ignored in mock mode)
            json_mode: If True, generate JSON response

        Returns:
            LLMResponse with mock content and zero token counts
        """
        content = self._generate_mock_response(system=system, json_mode=json_mode)
        return LLMResponse(
            content=content,
            model=self.model,
            input_tokens=0,
            output_tokens=0,
            stop_reason="mock_complete",
        )

    def complete_with_tools(
        self,
        messages: list[dict[str, Any]],
        system: str,
        tools: list[Tool],
        tool_executor: Callable[[ToolUse], ToolResult],
        max_iterations: int = 10,
    ) -> LLMResponse:
        """
        Generate a mock completion without tool use.

        In mock mode, tool execution is skipped entirely and a final response
        is returned immediately.

        Args:
            messages: Initial conversation (ignored in mock mode)
            system: System prompt (used to extract expected output keys)
            tools: Available tools (ignored in mock mode)
            tool_executor: Tool executor function (never called in mock mode)
            max_iterations: Max iterations (ignored in mock mode)

        Returns:
            LLMResponse with mock content and zero token counts
        """
        # Heuristic: emit JSON if the system prompt suggests structured output.
        json_mode = "json" in system.lower() or "output_keys" in system.lower()
        content = self._generate_mock_response(system=system, json_mode=json_mode)
        return LLMResponse(
            content=content,
            model=self.model,
            input_tokens=0,
            output_tokens=0,
            stop_reason="mock_complete",
        )
+6 -1
View File
@@ -1,6 +1,7 @@
"""LLM Provider abstraction for pluggable LLM backends."""
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
@@ -8,6 +9,7 @@ from typing import Any
@dataclass
class LLMResponse:
"""Response from an LLM call."""
content: str
model: str
input_tokens: int = 0
@@ -19,6 +21,7 @@ class LLMResponse:
@dataclass
class Tool:
"""A tool the LLM can use."""
name: str
description: str
parameters: dict[str, Any] = field(default_factory=dict)
@@ -27,6 +30,7 @@ class Tool:
@dataclass
class ToolUse:
"""A tool call requested by the LLM."""
id: str
name: str
input: dict[str, Any]
@@ -35,6 +39,7 @@ class ToolUse:
@dataclass
class ToolResult:
"""Result of executing a tool."""
tool_use_id: str
content: str
is_error: bool = False
@@ -86,7 +91,7 @@ class LLMProvider(ABC):
messages: list[dict[str, Any]],
system: str,
tools: list[Tool],
tool_executor: callable,
tool_executor: Callable[["ToolUse"], "ToolResult"],
max_iterations: int = 10,
) -> LLMResponse:
"""
File diff suppressed because it is too large Load Diff
+3 -3
View File
@@ -1,15 +1,15 @@
"""Agent Runner - load and run exported agents."""
from framework.runner.runner import AgentRunner, AgentInfo, ValidationResult
from framework.runner.tool_registry import ToolRegistry, tool
from framework.runner.orchestrator import AgentOrchestrator
from framework.runner.protocol import (
AgentMessage,
MessageType,
CapabilityLevel,
CapabilityResponse,
MessageType,
OrchestratorResult,
)
from framework.runner.runner import AgentInfo, AgentRunner, ValidationResult
from framework.runner.tool_registry import ToolRegistry, tool
__all__ = [
# Single agent
+103 -59
View File
@@ -22,12 +22,14 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
help="Path to agent folder (containing agent.json)",
)
run_parser.add_argument(
"--input", "-i",
"--input",
"-i",
type=str,
help="Input context as JSON string",
)
run_parser.add_argument(
"--input-file", "-f",
"--input-file",
"-f",
type=str,
help="Input context from JSON file",
)
@@ -37,17 +39,20 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
help="Run in mock mode (no real LLM calls)",
)
run_parser.add_argument(
"--output", "-o",
"--output",
"-o",
type=str,
help="Write results to file instead of stdout",
)
run_parser.add_argument(
"--quiet", "-q",
"--quiet",
"-q",
action="store_true",
help="Only output the final result JSON",
)
run_parser.add_argument(
"--verbose", "-v",
"--verbose",
"-v",
action="store_true",
help="Show detailed execution logs (steps, LLM calls, etc.)",
)
@@ -113,7 +118,8 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
help="Directory containing agent folders (default: exports)",
)
dispatch_parser.add_argument(
"--input", "-i",
"--input",
"-i",
type=str,
required=True,
help="Input context as JSON string",
@@ -124,13 +130,15 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
help="Description of what you want to accomplish",
)
dispatch_parser.add_argument(
"--agents", "-a",
"--agents",
"-a",
type=str,
nargs="+",
help="Specific agent names to use (default: all in directory)",
)
dispatch_parser.add_argument(
"--quiet", "-q",
"--quiet",
"-q",
action="store_true",
help="Only output the final result JSON",
)
@@ -170,15 +178,16 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
def cmd_run(args: argparse.Namespace) -> int:
"""Run an exported agent."""
import logging
from framework.runner import AgentRunner
# Set logging level (quiet by default for cleaner output)
if args.quiet:
logging.basicConfig(level=logging.ERROR, format='%(message)s')
elif getattr(args, 'verbose', False):
logging.basicConfig(level=logging.INFO, format='%(message)s')
logging.basicConfig(level=logging.ERROR, format="%(message)s")
elif getattr(args, "verbose", False):
logging.basicConfig(level=logging.INFO, format="%(message)s")
else:
logging.basicConfig(level=logging.WARNING, format='%(message)s')
logging.basicConfig(level=logging.WARNING, format="%(message)s")
# Load input context
context = {}
@@ -211,6 +220,7 @@ def cmd_run(args: argparse.Namespace) -> int:
entry_input_keys = runner.graph.nodes[0].input_keys if runner.graph.nodes else []
if "user_id" in entry_input_keys and context.get("user_id") is None:
import os
context["user_id"] = os.environ.get("USER", "default_user")
if not args.quiet:
@@ -279,7 +289,13 @@ def cmd_run(args: argparse.Namespace) -> int:
# If no meaningful key found, show all non-internal keys
if not shown:
for key, value in result.output.items():
if not key.startswith("_") and key not in ["user_id", "request", "memory_loaded", "user_profile", "recent_context"]:
if not key.startswith("_") and key not in [
"user_id",
"request",
"memory_loaded",
"user_profile",
"recent_context",
]:
if isinstance(value, (dict, list)):
print(f"\n{key}:")
value_str = json.dumps(value, indent=2, default=str)
@@ -311,19 +327,24 @@ def cmd_info(args: argparse.Namespace) -> int:
info = runner.info()
if args.json:
print(json.dumps({
"name": info.name,
"description": info.description,
"goal_name": info.goal_name,
"goal_description": info.goal_description,
"node_count": info.node_count,
"nodes": info.nodes,
"edges": info.edges,
"success_criteria": info.success_criteria,
"constraints": info.constraints,
"required_tools": info.required_tools,
"has_tools_module": info.has_tools_module,
}, indent=2))
print(
json.dumps(
{
"name": info.name,
"description": info.description,
"goal_name": info.goal_name,
"goal_description": info.goal_description,
"node_count": info.node_count,
"nodes": info.nodes,
"edges": info.edges,
"success_criteria": info.success_criteria,
"constraints": info.constraints,
"required_tools": info.required_tools,
"has_tools_module": info.has_tools_module,
},
indent=2,
)
)
else:
print(f"Agent: {info.name}")
print(f"Description: {info.description}")
@@ -333,8 +354,8 @@ def cmd_info(args: argparse.Namespace) -> int:
print()
print(f"Nodes ({info.node_count}):")
for node in info.nodes:
inputs = f" [in: {', '.join(node['input_keys'])}]" if node.get('input_keys') else ""
outputs = f" [out: {', '.join(node['output_keys'])}]" if node.get('output_keys') else ""
inputs = f" [in: {', '.join(node['input_keys'])}]" if node.get("input_keys") else ""
outputs = f" [out: {', '.join(node['output_keys'])}]" if node.get("output_keys") else ""
print(f" - {node['id']}: {node['name']}{inputs}{outputs}")
print()
print(f"Success Criteria ({len(info.success_criteria)}):")
@@ -396,8 +417,9 @@ def cmd_list(args: argparse.Namespace) -> int:
directory = Path(args.directory)
if not directory.exists():
print(f"Directory not found: {directory}", file=sys.stderr)
return 1
# FIX: Handle missing directory gracefully on fresh install
print(f"No agents found in {directory}")
return 0
agents = []
for path in directory.iterdir():
@@ -405,19 +427,25 @@ def cmd_list(args: argparse.Namespace) -> int:
try:
runner = AgentRunner.load(path)
info = runner.info()
agents.append({
"path": str(path),
"name": info.name,
"description": info.description[:60] + "..." if len(info.description) > 60 else info.description,
"nodes": info.node_count,
"tools": len(info.required_tools),
})
agents.append(
{
"path": str(path),
"name": info.name,
"description": info.description[:60] + "..."
if len(info.description) > 60
else info.description,
"nodes": info.node_count,
"tools": len(info.required_tools),
}
)
runner.cleanup()
except Exception as e:
agents.append({
"path": str(path),
"error": str(e),
})
agents.append(
{
"path": str(path),
"error": str(e),
}
)
if not agents:
print(f"No agents found in {directory}")
@@ -540,7 +568,7 @@ def cmd_dispatch(args: argparse.Namespace) -> int:
def _interactive_approval(request):
"""Interactive approval callback for HITL mode."""
from framework.graph import ApprovalResult, ApprovalDecision
from framework.graph import ApprovalDecision, ApprovalResult
print()
print("=" * 60)
@@ -561,6 +589,7 @@ def _interactive_approval(request):
print(f"\n[{key}]:")
if isinstance(value, (dict, list)):
import json
value_str = json.dumps(value, indent=2, default=str)
# Show more content for approval - up to 2000 chars
if len(value_str) > 2000:
@@ -605,11 +634,14 @@ def _interactive_approval(request):
print("Invalid choice. Please enter a, r, s, or x.")
def _format_natural_language_to_json(user_input: str, input_keys: list[str], agent_description: str, session_context: dict = None) -> dict:
def _format_natural_language_to_json(
user_input: str, input_keys: list[str], agent_description: str, session_context: dict = None
) -> dict:
"""Use Haiku to convert natural language input to JSON based on agent's input schema."""
import anthropic
import os
import anthropic
client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
# Build prompt for Haiku
@@ -619,17 +651,22 @@ def _format_natural_language_to_json(user_input: str, input_keys: list[str], age
main_field = input_keys[0] if input_keys else "objective"
existing_value = session_context.get(main_field, "")
session_info = f"\n\nExisting {main_field}: \"{existing_value}\"\n\nThe user is providing ADDITIONAL information. Append this new information to the existing {main_field} to create an enriched, more detailed version."
session_info = (
f'\n\nExisting {main_field}: "{existing_value}"\n\n'
f"The user is providing ADDITIONAL information. Append this new "
f"information to the existing {main_field} to create an enriched, "
"more detailed version."
)
prompt = f"""You are formatting user input for an agent that requires specific input fields.
Agent: {agent_description}
Required input fields: {', '.join(input_keys)}{session_info}
Required input fields: {", ".join(input_keys)}{session_info}
User input: {user_input}
{"If this is a follow-up message, APPEND the new information to the existing field value to create a more complete, detailed version. Do not create new fields." if session_context else ""}
{"If this is a follow-up, APPEND new info to the existing field value." if session_context else ""}
Output ONLY valid JSON, no explanation:"""
@@ -637,7 +674,7 @@ Output ONLY valid JSON, no explanation:"""
message = client.messages.create(
model="claude-3-5-haiku-20241022", # Fast and cheap
max_tokens=500,
messages=[{"role": "user", "content": prompt}]
messages=[{"role": "user", "content": prompt}],
)
json_str = message.content[0].text.strip()
@@ -661,12 +698,13 @@ Output ONLY valid JSON, no explanation:"""
def cmd_shell(args: argparse.Namespace) -> int:
"""Start an interactive agent session."""
import logging
from framework.runner import AgentRunner
# Configure logging to show runtime visibility
logging.basicConfig(
level=logging.INFO,
format='%(message)s', # Simple format for clean output
format="%(message)s", # Simple format for clean output
)
agents_dir = Path(args.agents_dir)
@@ -690,7 +728,7 @@ def cmd_shell(args: argparse.Namespace) -> int:
return 1
# Set up approval callback by default (unless --no-approve is set)
if not getattr(args, 'no_approve', False):
if not getattr(args, "no_approve", False):
runner.set_approval_callback(_interactive_approval)
print("\n🔔 Human-in-the-loop mode enabled")
print(" Steps marked for approval will pause for your review")
@@ -748,8 +786,10 @@ def cmd_shell(args: argparse.Namespace) -> int:
if user_input == "/nodes":
print("\nAgent nodes:")
for node in info.nodes:
inputs = f" [in: {', '.join(node['input_keys'])}]" if node.get('input_keys') else ""
outputs = f" [out: {', '.join(node['output_keys'])}]" if node.get('output_keys') else ""
inputs = f" [in: {', '.join(node['input_keys'])}]" if node.get("input_keys") else ""
outputs = (
f" [out: {', '.join(node['output_keys'])}]" if node.get("output_keys") else ""
)
print(f" {node['id']}: {node['name']}{inputs}{outputs}")
print(f" {node['description']}")
print()
@@ -784,7 +824,7 @@ def cmd_shell(args: argparse.Namespace) -> int:
user_input,
entry_input_keys,
info.description,
session_context=session_memory
session_context=session_memory,
)
print(f"✓ Formatted to: {json.dumps(context)}")
except Exception as e:
@@ -807,6 +847,7 @@ def cmd_shell(args: argparse.Namespace) -> int:
# Auto-inject user_id if missing (for personal assistant agents)
if "user_id" in entry_input_keys and run_context.get("user_id") is None:
import os
run_context["user_id"] = os.environ.get("USER", "default_user")
# Add conversation history to context if agent expects it
@@ -872,12 +913,14 @@ def cmd_shell(args: argparse.Namespace) -> int:
session_memory[key] = value
# Track conversation history
conversation_history.append({
"input": context,
"output": result.output if result.output else {},
"status": "success" if result.success else "failed",
"paused_at": result.paused_at
})
conversation_history.append(
{
"input": context,
"output": result.output if result.output else {},
"status": "success" if result.success else "failed",
"paused_at": result.paused_at,
}
)
print()
@@ -904,6 +947,7 @@ def _select_agent(agents_dir: Path) -> str | None:
for i, agent_path in enumerate(agents, 1):
try:
from framework.runner import AgentRunner
runner = AgentRunner.load(agent_path)
info = runner.info()
desc = info.description[:50] + "..." if len(info.description) > 50 else info.description
+24 -13
View File
@@ -146,6 +146,7 @@ class MCPClient:
try:
import threading
from mcp import StdioServerParameters
# Create server parameters
@@ -180,7 +181,10 @@ class MCPClient:
# Create persistent stdio client context
self._stdio_context = stdio_client(server_params)
self._read_stream, self._write_stream = await self._stdio_context.__aenter__()
(
self._read_stream,
self._write_stream,
) = await self._stdio_context.__aenter__()
# Create persistent session
self._session = ClientSession(self._read_stream, self._write_stream)
@@ -215,7 +219,7 @@ class MCPClient:
logger.info(f"Connected to MCP server '{self.config.name}' via STDIO (persistent)")
except Exception as e:
raise RuntimeError(f"Failed to connect to MCP server: {e}")
raise RuntimeError(f"Failed to connect to MCP server: {e}") from e
def _connect_http(self) -> None:
"""Connect to MCP server via HTTP transport."""
@@ -232,7 +236,9 @@ class MCPClient:
try:
response = self._http_client.get("/health")
response.raise_for_status()
logger.info(f"Connected to MCP server '{self.config.name}' via HTTP at {self.config.url}")
logger.info(
f"Connected to MCP server '{self.config.name}' via HTTP at {self.config.url}"
)
except Exception as e:
logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}")
# Continue anyway, server might not have health endpoint
@@ -255,7 +261,10 @@ class MCPClient:
)
self._tools[tool.name] = tool
logger.info(f"Discovered {len(self._tools)} tools from '{self.config.name}': {list(self._tools.keys())}")
tool_names = list(self._tools.keys())
logger.info(
f"Discovered {len(self._tools)} tools from '{self.config.name}': {tool_names}"
)
except Exception as e:
logger.error(f"Failed to discover tools from '{self.config.name}': {e}")
raise
@@ -271,11 +280,13 @@ class MCPClient:
# Convert tools to dict format
tools_list = []
for tool in response.tools:
tools_list.append({
"name": tool.name,
"description": tool.description,
"inputSchema": tool.inputSchema,
})
tools_list.append(
{
"name": tool.name,
"description": tool.description,
"inputSchema": tool.inputSchema,
}
)
return tools_list
@@ -303,7 +314,7 @@ class MCPClient:
return data.get("result", {}).get("tools", [])
except Exception as e:
raise RuntimeError(f"Failed to list tools via HTTP: {e}")
raise RuntimeError(f"Failed to list tools via HTTP: {e}") from e
def list_tools(self) -> list[MCPTool]:
"""
@@ -353,9 +364,9 @@ class MCPClient:
if len(result.content) > 0:
content_item = result.content[0]
# Check if it's a text content item
if hasattr(content_item, 'text'):
if hasattr(content_item, "text"):
return content_item.text
elif hasattr(content_item, 'data'):
elif hasattr(content_item, "data"):
return content_item.data
return result.content
@@ -387,7 +398,7 @@ class MCPClient:
return data.get("result", {}).get("content", [])
except Exception as e:
raise RuntimeError(f"Failed to call tool via HTTP: {e}")
raise RuntimeError(f"Failed to call tool via HTTP: {e}") from e
def disconnect(self) -> None:
"""Disconnect from the MCP server."""
+7 -6
View File
@@ -72,6 +72,7 @@ class AgentOrchestrator:
# Auto-create LLM - LiteLLM auto-detects provider and API key from model name
if self._llm is None:
from framework.llm.litellm import LiteLLMProvider
self._llm = LiteLLMProvider(model=self._model)
def register(
@@ -205,7 +206,7 @@ class AgentOrchestrator:
responses = await asyncio.gather(*tasks, return_exceptions=True)
for agent_name, response in zip(routing.selected_agents, responses):
for agent_name, response in zip(routing.selected_agents, responses, strict=False):
if isinstance(response, Exception):
results[agent_name] = {"error": str(response)}
else:
@@ -326,7 +327,7 @@ class AgentOrchestrator:
results = await asyncio.gather(*tasks, return_exceptions=True)
for name, result in zip(agent_names, results):
for name, result in zip(agent_names, results, strict=False):
if isinstance(result, Exception):
responses[name] = AgentMessage(
type=MessageType.RESPONSE,
@@ -355,7 +356,7 @@ class AgentOrchestrator:
results = await asyncio.gather(*tasks, return_exceptions=True)
capabilities = {}
for name, result in zip(agent_names, results):
for name, result in zip(agent_names, results, strict=False):
if isinstance(result, Exception):
capabilities[name] = CapabilityResponse(
agent_name=name,
@@ -429,8 +430,7 @@ class AgentOrchestrator:
"""Use LLM to decide routing when multiple agents are capable."""
agents_info = "\n".join(
f"- {name}: {cap.reasoning} (confidence: {cap.confidence:.2f})"
for name, cap in capable
f"- {name}: {cap.reasoning} (confidence: {cap.confidence:.2f})" for name, cap in capable
)
prompt = f"""Multiple agents can handle this request. Decide the best routing.
@@ -463,7 +463,8 @@ Respond with JSON only:
)
import re
json_match = re.search(r'\{[^{}]*\}', response.content, re.DOTALL)
json_match = re.search(r"\{[^{}]*\}", response.content, re.DOTALL)
if json_match:
data = json.loads(json_match.group())
selected = data.get("selected", [])
+1 -1
View File
@@ -1,10 +1,10 @@
"""Message protocol for multi-agent communication."""
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any
import uuid
class MessageType(Enum):
+74 -46
View File
@@ -2,24 +2,25 @@
import json
import os
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Any
from typing import TYPE_CHECKING, Any
from framework.graph import Goal
from framework.graph.edge import GraphSpec, EdgeSpec, EdgeCondition, AsyncEntryPointSpec
from framework.graph.edge import AsyncEntryPointSpec, EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.executor import ExecutionResult, GraphExecutor
from framework.graph.node import NodeSpec
from framework.graph.executor import GraphExecutor, ExecutionResult
from framework.llm.provider import LLMProvider, Tool
from framework.runner.tool_registry import ToolRegistry
from framework.runtime.core import Runtime
# Multi-entry-point runtime imports
from framework.runtime.agent_runtime import AgentRuntime, AgentRuntimeConfig, create_agent_runtime
from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime
from framework.runtime.core import Runtime
from framework.runtime.execution_stream import EntryPointSpec
if TYPE_CHECKING:
from framework.runner.protocol import CapabilityResponse, AgentMessage
from framework.runner.protocol import AgentMessage, CapabilityResponse
@dataclass
@@ -87,6 +88,7 @@ def load_agent_export(data: str | dict) -> tuple[GraphSpec, Goal]:
"on_success": EdgeCondition.ON_SUCCESS,
"on_failure": EdgeCondition.ON_FAILURE,
"conditional": EdgeCondition.CONDITIONAL,
"llm_decide": EdgeCondition.LLM_DECIDE,
}
edge = EdgeSpec(
id=edge_data["id"],
@@ -102,16 +104,18 @@ def load_agent_export(data: str | dict) -> tuple[GraphSpec, Goal]:
# Build AsyncEntryPointSpec objects for multi-entry-point support
async_entry_points = []
for aep_data in graph_data.get("async_entry_points", []):
async_entry_points.append(AsyncEntryPointSpec(
id=aep_data["id"],
name=aep_data.get("name", aep_data["id"]),
entry_node=aep_data["entry_node"],
trigger_type=aep_data.get("trigger_type", "manual"),
trigger_config=aep_data.get("trigger_config", {}),
isolation_level=aep_data.get("isolation_level", "shared"),
priority=aep_data.get("priority", 0),
max_concurrent=aep_data.get("max_concurrent", 10),
))
async_entry_points.append(
AsyncEntryPointSpec(
id=aep_data["id"],
name=aep_data.get("name", aep_data["id"]),
entry_node=aep_data["entry_node"],
trigger_type=aep_data.get("trigger_type", "manual"),
trigger_config=aep_data.get("trigger_config", {}),
isolation_level=aep_data.get("isolation_level", "shared"),
priority=aep_data.get("priority", 0),
max_concurrent=aep_data.get("max_concurrent", 10),
)
)
# Build GraphSpec
graph = GraphSpec(
@@ -131,27 +135,31 @@ def load_agent_export(data: str | dict) -> tuple[GraphSpec, Goal]:
)
# Build Goal
from framework.graph.goal import SuccessCriterion, Constraint
from framework.graph.goal import Constraint, SuccessCriterion
success_criteria = []
for sc_data in goal_data.get("success_criteria", []):
success_criteria.append(SuccessCriterion(
id=sc_data["id"],
description=sc_data["description"],
metric=sc_data.get("metric", ""),
target=sc_data.get("target", ""),
weight=sc_data.get("weight", 1.0),
))
success_criteria.append(
SuccessCriterion(
id=sc_data["id"],
description=sc_data["description"],
metric=sc_data.get("metric", ""),
target=sc_data.get("target", ""),
weight=sc_data.get("weight", 1.0),
)
)
constraints = []
for c_data in goal_data.get("constraints", []):
constraints.append(Constraint(
id=c_data["id"],
description=c_data["description"],
constraint_type=c_data.get("constraint_type", "hard"),
category=c_data.get("category", "safety"),
check=c_data.get("check", ""),
))
constraints.append(
Constraint(
id=c_data["id"],
description=c_data["description"],
constraint_type=c_data.get("constraint_type", "hard"),
category=c_data.get("category", "safety"),
check=c_data.get("check", ""),
)
)
goal = Goal(
id=goal_data.get("id", ""),
@@ -379,7 +387,8 @@ class AgentRunner:
try:
self._tool_registry.register_mcp_server(server_config)
except Exception as e:
print(f"Warning: Failed to register MCP server '{server_config.get('name', 'unknown')}': {e}")
server_name = server_config.get("name", "unknown")
print(f"Warning: Failed to register MCP server '{server_name}': {e}")
except Exception as e:
print(f"Warning: Failed to load MCP servers config from {config_path}: {e}")
@@ -409,13 +418,19 @@ class AgentRunner:
session_id=session_id,
)
# Create LLM provider (if not mock mode and API key available)
# Create LLM provider
# Uses LiteLLM which auto-detects the provider from model name
if not self.mock_mode:
if self.mock_mode:
# Use mock LLM for testing without real API calls
from framework.llm.mock import MockLLMProvider
self._llm = MockLLMProvider(model=self.model)
else:
# Detect required API key from model name
api_key_env = self._get_api_key_env_var(self.model)
if api_key_env and os.environ.get(api_key_env):
from framework.llm.litellm import LiteLLMProvider
self._llm = LiteLLMProvider(model=self.model)
elif api_key_env:
print(f"Warning: {api_key_env} not set. LLM calls will fail.")
@@ -760,7 +775,12 @@ class AgentRunner:
entry_node=self.graph.entry_node,
terminal_nodes=self.graph.terminal_nodes,
success_criteria=[
{"id": sc.id, "description": sc.description, "metric": sc.metric, "target": sc.target}
{
"id": sc.id,
"description": sc.description,
"metric": sc.metric,
"target": sc.target,
}
for sc in self.goal.success_criteria
],
constraints=[
@@ -810,7 +830,7 @@ class AgentRunner:
# Check tool credentials (Tier 2)
missing_creds = cred_manager.get_missing_for_tools(info.required_tools)
for cred_name, spec in missing_creds:
for _, spec in missing_creds:
missing_credentials.append(spec.env_var)
affected_tools = [t for t in info.required_tools if t in spec.tools]
tools_str = ", ".join(affected_tools)
@@ -820,9 +840,9 @@ class AgentRunner:
warnings.append(warning_msg)
# Check node type credentials (e.g., ANTHROPIC_API_KEY for LLM nodes)
node_types = list(set(node.node_type for node in self.graph.nodes))
node_types = list({node.node_type for node in self.graph.nodes})
missing_node_creds = cred_manager.get_missing_for_node_types(node_types)
for cred_name, spec in missing_node_creds:
for _, spec in missing_node_creds:
if spec.env_var not in missing_credentials: # Avoid duplicates
missing_credentials.append(spec.env_var)
affected_types = [t for t in node_types if t in spec.node_types]
@@ -834,11 +854,16 @@ class AgentRunner:
except ImportError:
# aden_tools not installed - fall back to direct check
has_llm_nodes = any(
node.node_type in ("llm_generate", "llm_tool_use")
for node in self.graph.nodes
node.node_type in ("llm_generate", "llm_tool_use") for node in self.graph.nodes
)
if has_llm_nodes and not os.environ.get("ANTHROPIC_API_KEY"):
warnings.append("Agent has LLM nodes but ANTHROPIC_API_KEY not set")
if has_llm_nodes:
api_key_env = self._get_api_key_env_var(self.model)
if api_key_env and not os.environ.get(api_key_env):
if api_key_env not in missing_credentials:
missing_credentials.append(api_key_env)
warnings.append(
f"Agent has LLM nodes but {api_key_env} not set (model: {self.model})"
)
return ValidationResult(
valid=len(errors) == 0,
@@ -848,7 +873,9 @@ class AgentRunner:
missing_credentials=missing_credentials,
)
async def can_handle(self, request: dict, llm: LLMProvider | None = None) -> "CapabilityResponse":
async def can_handle(
self, request: dict, llm: LLMProvider | None = None
) -> "CapabilityResponse":
"""
Ask the agent if it can handle this request.
@@ -861,7 +888,7 @@ class AgentRunner:
Returns:
CapabilityResponse with level, confidence, and reasoning
"""
from framework.runner.protocol import CapabilityResponse, CapabilityLevel
from framework.runner.protocol import CapabilityLevel, CapabilityResponse
# Use provided LLM or set up our own
eval_llm = llm
@@ -918,7 +945,8 @@ Respond with JSON only:
# Parse response
import re
json_match = re.search(r'\{[^{}]*\}', response.content, re.DOTALL)
json_match = re.search(r"\{[^{}]*\}", response.content, re.DOTALL)
if json_match:
data = json.loads(json_match.group())
level_map = {
@@ -942,7 +970,7 @@ Respond with JSON only:
def _keyword_capability_check(self, request: dict) -> "CapabilityResponse":
"""Simple keyword-based capability check (fallback when no LLM)."""
from framework.runner.protocol import CapabilityResponse, CapabilityLevel
from framework.runner.protocol import CapabilityLevel, CapabilityResponse
info = self.info()
request_str = json.dumps(request).lower()
+4 -3
View File
@@ -4,11 +4,12 @@ import importlib.util
import inspect
import json
import logging
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable
from typing import Any
from framework.llm.provider import Tool, ToolUse, ToolResult
from framework.llm.provider import Tool, ToolResult, ToolUse
logger = logging.getLogger(__name__)
@@ -142,7 +143,7 @@ class ToolRegistry:
# Check for TOOLS dict
if hasattr(module, "TOOLS"):
tools_dict = getattr(module, "TOOLS")
tools_dict = module.TOOLS
executor_func = getattr(module, "tool_executor", None)
for name, tool in tools_dict.items():
+15 -6
View File
@@ -7,15 +7,16 @@ while preserving the goal-driven approach.
import asyncio
import logging
from dataclasses import dataclass, field
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from framework.graph.executor import ExecutionResult
from framework.runtime.shared_state import SharedStateManager
from framework.runtime.outcome_aggregator import OutcomeAggregator
from framework.runtime.event_bus import EventBus
from framework.runtime.execution_stream import ExecutionStream, EntryPointSpec
from framework.runtime.execution_stream import EntryPointSpec, ExecutionStream
from framework.runtime.outcome_aggregator import OutcomeAggregator
from framework.runtime.shared_state import SharedStateManager
from framework.storage.concurrent import ConcurrentStorage
if TYPE_CHECKING:
@@ -29,10 +30,13 @@ logger = logging.getLogger(__name__)
@dataclass
class AgentRuntimeConfig:
"""Configuration for AgentRuntime."""
max_concurrent_executions: int = 100
cache_ttl: float = 60.0
batch_interval: float = 0.1
max_history: int = 1000
execution_result_max: int = 1000
execution_result_ttl_seconds: float | None = None
class AgentRuntime:
@@ -206,6 +210,8 @@ class AgentRuntime:
llm=self._llm,
tools=self._tools,
tool_executor=self._tool_executor,
result_retention_max=self._config.execution_result_max,
result_retention_ttl_seconds=self._config.execution_result_ttl_seconds,
)
await stream.start()
self._streams[ep_id] = stream
@@ -285,7 +291,9 @@ class AgentRuntime:
ExecutionResult or None if timeout
"""
exec_id = await self.trigger(entry_point_id, input_data, session_state=session_state)
stream = self._streams[entry_point_id]
stream = self._streams.get(entry_point_id)
if stream is None:
raise ValueError(f"Entry point '{entry_point_id}' not found")
return await stream.wait_for_completion(exec_id, timeout)
async def get_goal_progress(self) -> dict[str, Any]:
@@ -411,6 +419,7 @@ class AgentRuntime:
# === CONVENIENCE FACTORY ===
def create_agent_runtime(
graph: "GraphSpec",
goal: "Goal",
+31 -22
View File
@@ -6,13 +6,14 @@ that Builder can analyze. The agent calls simple methods, and the runtime
handles all the structured logging.
"""
from datetime import datetime
from typing import Any
from pathlib import Path
import logging
import uuid
from collections.abc import Callable
from datetime import datetime
from pathlib import Path
from typing import Any
from framework.schemas.decision import Decision, Option, Outcome, DecisionType
from framework.schemas.decision import Decision, DecisionType, Option, Outcome
from framework.schemas.run import Run, RunStatus
from framework.storage.backend import FileStorage
@@ -164,7 +165,7 @@ class Runtime:
context: Additional context available when deciding
Returns:
The decision ID (use this to record outcome later), or empty string if no run in progress
The decision ID (use to record outcome later), or empty string if no run
"""
if self._current_run is None:
# Gracefully handle case where run ended during exception handling
@@ -174,15 +175,17 @@ class Runtime:
# Build Option objects
option_objects = []
for opt in options:
option_objects.append(Option(
id=opt["id"],
description=opt.get("description", ""),
action_type=opt.get("action_type", "unknown"),
action_params=opt.get("action_params", {}),
pros=opt.get("pros", []),
cons=opt.get("cons", []),
confidence=opt.get("confidence", 0.5),
))
option_objects.append(
Option(
id=opt["id"],
description=opt.get("description", ""),
action_type=opt.get("action_type", "unknown"),
action_params=opt.get("action_params", {}),
pros=opt.get("pros", []),
cons=opt.get("cons", []),
confidence=opt.get("confidence", 0.5),
)
)
# Create decision
decision_id = f"dec_{len(self._current_run.decisions)}"
@@ -230,7 +233,9 @@ class Runtime:
if self._current_run is None:
# Gracefully handle case where run ended during exception handling
# This can happen in cascading error scenarios
logger.warning(f"record_outcome called but no run in progress (decision_id={decision_id})")
logger.warning(
f"record_outcome called but no run in progress (decision_id={decision_id})"
)
return
outcome = Outcome(
@@ -274,7 +279,9 @@ class Runtime:
if self._current_run is None:
# Gracefully handle case where run ended during exception handling
# Log the problem since we can't store it, then return empty ID
logger.warning(f"report_problem called but no run in progress: [{severity}] {description}")
logger.warning(
f"report_problem called but no run in progress: [{severity}] {description}"
)
return ""
return self._current_run.add_problem(
@@ -293,7 +300,7 @@ class Runtime:
options: list[dict[str, Any]],
chosen: str,
reasoning: str,
executor: callable,
executor: Callable,
**kwargs,
) -> tuple[str, Any]:
"""
@@ -370,11 +377,13 @@ class Runtime:
"""
return self.decide(
intent=intent,
options=[{
"id": "action",
"description": action,
"action_type": "execute",
}],
options=[
{
"id": "action",
"description": action,
"action_type": "execute",
}
],
chosen="action",
reasoning=reasoning,
node_id=node_id,
+67 -53
View File
@@ -9,11 +9,11 @@ Allows streams to:
import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Awaitable, Callable
from typing import Any
logger = logging.getLogger(__name__)
@@ -48,6 +48,7 @@ class EventType(str, Enum):
@dataclass
class AgentEvent:
"""An event in the agent system."""
type: EventType
stream_id: str
execution_id: str | None = None
@@ -74,6 +75,7 @@ EventHandler = Callable[[AgentEvent], Awaitable[None]]
@dataclass
class Subscription:
"""A subscription to events."""
id: str
event_types: set[EventType]
handler: EventHandler
@@ -193,7 +195,7 @@ class EventBus:
async with self._lock:
self._event_history.append(event)
if len(self._event_history) > self._max_history:
self._event_history = self._event_history[-self._max_history:]
self._event_history = self._event_history[-self._max_history :]
# Find matching subscriptions
matching_handlers: list[EventHandler] = []
@@ -249,13 +251,15 @@ class EventBus:
correlation_id: str | None = None,
) -> None:
"""Emit execution started event."""
await self.publish(AgentEvent(
type=EventType.EXECUTION_STARTED,
stream_id=stream_id,
execution_id=execution_id,
data={"input": input_data or {}},
correlation_id=correlation_id,
))
await self.publish(
AgentEvent(
type=EventType.EXECUTION_STARTED,
stream_id=stream_id,
execution_id=execution_id,
data={"input": input_data or {}},
correlation_id=correlation_id,
)
)
async def emit_execution_completed(
self,
@@ -265,13 +269,15 @@ class EventBus:
correlation_id: str | None = None,
) -> None:
"""Emit execution completed event."""
await self.publish(AgentEvent(
type=EventType.EXECUTION_COMPLETED,
stream_id=stream_id,
execution_id=execution_id,
data={"output": output or {}},
correlation_id=correlation_id,
))
await self.publish(
AgentEvent(
type=EventType.EXECUTION_COMPLETED,
stream_id=stream_id,
execution_id=execution_id,
data={"output": output or {}},
correlation_id=correlation_id,
)
)
async def emit_execution_failed(
self,
@@ -281,13 +287,15 @@ class EventBus:
correlation_id: str | None = None,
) -> None:
"""Emit execution failed event."""
await self.publish(AgentEvent(
type=EventType.EXECUTION_FAILED,
stream_id=stream_id,
execution_id=execution_id,
data={"error": error},
correlation_id=correlation_id,
))
await self.publish(
AgentEvent(
type=EventType.EXECUTION_FAILED,
stream_id=stream_id,
execution_id=execution_id,
data={"error": error},
correlation_id=correlation_id,
)
)
async def emit_goal_progress(
self,
@@ -296,14 +304,16 @@ class EventBus:
criteria_status: dict[str, Any],
) -> None:
"""Emit goal progress event."""
await self.publish(AgentEvent(
type=EventType.GOAL_PROGRESS,
stream_id=stream_id,
data={
"progress": progress,
"criteria_status": criteria_status,
},
))
await self.publish(
AgentEvent(
type=EventType.GOAL_PROGRESS,
stream_id=stream_id,
data={
"progress": progress,
"criteria_status": criteria_status,
},
)
)
async def emit_constraint_violation(
self,
@@ -313,15 +323,17 @@ class EventBus:
description: str,
) -> None:
"""Emit constraint violation event."""
await self.publish(AgentEvent(
type=EventType.CONSTRAINT_VIOLATION,
stream_id=stream_id,
execution_id=execution_id,
data={
"constraint_id": constraint_id,
"description": description,
},
))
await self.publish(
AgentEvent(
type=EventType.CONSTRAINT_VIOLATION,
stream_id=stream_id,
execution_id=execution_id,
data={
"constraint_id": constraint_id,
"description": description,
},
)
)
async def emit_state_changed(
self,
@@ -333,17 +345,19 @@ class EventBus:
scope: str,
) -> None:
"""Emit state changed event."""
await self.publish(AgentEvent(
type=EventType.STATE_CHANGED,
stream_id=stream_id,
execution_id=execution_id,
data={
"key": key,
"old_value": old_value,
"new_value": new_value,
"scope": scope,
},
))
await self.publish(
AgentEvent(
type=EventType.STATE_CHANGED,
stream_id=stream_id,
execution_id=execution_id,
data={
"key": key,
"old_value": old_value,
"new_value": new_value,
"scope": scope,
},
)
)
# === QUERY OPERATIONS ===
@@ -432,7 +446,7 @@ class EventBus:
if timeout:
try:
await asyncio.wait_for(event_received.wait(), timeout=timeout)
except asyncio.TimeoutError:
except TimeoutError:
return None
else:
await event_received.wait()
+83 -32
View File
@@ -9,22 +9,25 @@ Each stream has:
import asyncio
import logging
import time
import uuid
from collections import OrderedDict
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Callable, TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from framework.graph.executor import GraphExecutor, ExecutionResult
from framework.graph.executor import ExecutionResult, GraphExecutor
from framework.runtime.shared_state import IsolationLevel, SharedStateManager
from framework.runtime.stream_runtime import StreamRuntime, StreamRuntimeAdapter
from framework.runtime.shared_state import SharedStateManager, IsolationLevel, StreamMemory
if TYPE_CHECKING:
from framework.graph.edge import GraphSpec
from framework.graph.goal import Goal
from framework.storage.concurrent import ConcurrentStorage
from framework.runtime.outcome_aggregator import OutcomeAggregator
from framework.runtime.event_bus import EventBus
from framework.llm.provider import LLMProvider, Tool
from framework.runtime.event_bus import EventBus
from framework.runtime.outcome_aggregator import OutcomeAggregator
from framework.storage.concurrent import ConcurrentStorage
logger = logging.getLogger(__name__)
@@ -32,6 +35,7 @@ logger = logging.getLogger(__name__)
@dataclass
class EntryPointSpec:
"""Specification for an entry point."""
id: str
name: str
entry_node: str # Node ID to start from
@@ -49,6 +53,7 @@ class EntryPointSpec:
@dataclass
class ExecutionContext:
"""Context for a single execution."""
id: str
correlation_id: str
stream_id: str
@@ -105,6 +110,8 @@ class ExecutionStream:
llm: "LLMProvider | None" = None,
tools: list["Tool"] | None = None,
tool_executor: Callable | None = None,
result_retention_max: int | None = 1000,
result_retention_ttl_seconds: float | None = None,
):
"""
Initialize execution stream.
@@ -133,6 +140,8 @@ class ExecutionStream:
self._llm = llm
self._tools = tools or []
self._tool_executor = tool_executor
self._result_retention_max = result_retention_max
self._result_retention_ttl_seconds = result_retention_ttl_seconds
# Create stream-scoped runtime
self._runtime = StreamRuntime(
@@ -144,7 +153,8 @@ class ExecutionStream:
# Execution tracking
self._active_executions: dict[str, ExecutionContext] = {}
self._execution_tasks: dict[str, asyncio.Task] = {}
self._execution_results: dict[str, ExecutionResult] = {}
self._execution_results: OrderedDict[str, ExecutionResult] = OrderedDict()
self._execution_result_times: dict[str, float] = {}
self._completion_events: dict[str, asyncio.Event] = {}
# Concurrency control
@@ -164,12 +174,36 @@ class ExecutionStream:
# Emit stream started event
if self._event_bus:
from framework.runtime.event_bus import EventType, AgentEvent
await self._event_bus.publish(AgentEvent(
type=EventType.STREAM_STARTED,
stream_id=self.stream_id,
data={"entry_point": self.entry_spec.id},
))
from framework.runtime.event_bus import AgentEvent, EventType
await self._event_bus.publish(
AgentEvent(
type=EventType.STREAM_STARTED,
stream_id=self.stream_id,
data={"entry_point": self.entry_spec.id},
)
)
def _record_execution_result(self, execution_id: str, result: ExecutionResult) -> None:
"""Record a completed execution result with retention pruning."""
self._execution_results[execution_id] = result
self._execution_results.move_to_end(execution_id)
self._execution_result_times[execution_id] = time.time()
self._prune_execution_results()
def _prune_execution_results(self) -> None:
"""Prune completed results based on TTL and max retention."""
if self._result_retention_ttl_seconds is not None:
cutoff = time.time() - self._result_retention_ttl_seconds
for exec_id, recorded_at in list(self._execution_result_times.items()):
if recorded_at < cutoff:
self._execution_result_times.pop(exec_id, None)
self._execution_results.pop(exec_id, None)
if self._result_retention_max is not None:
while len(self._execution_results) > self._result_retention_max:
old_exec_id, _ = self._execution_results.popitem(last=False)
self._execution_result_times.pop(old_exec_id, None)
async def stop(self) -> None:
"""Stop the execution stream and cancel active executions."""
@@ -179,7 +213,7 @@ class ExecutionStream:
self._running = False
# Cancel all active executions
for exec_id, task in self._execution_tasks.items():
for _, task in self._execution_tasks.items():
if not task.done():
task.cancel()
try:
@@ -194,11 +228,14 @@ class ExecutionStream:
# Emit stream stopped event
if self._event_bus:
from framework.runtime.event_bus import EventType, AgentEvent
await self._event_bus.publish(AgentEvent(
type=EventType.STREAM_STOPPED,
stream_id=self.stream_id,
))
from framework.runtime.event_bus import AgentEvent, EventType
await self._event_bus.publish(
AgentEvent(
type=EventType.STREAM_STOPPED,
stream_id=self.stream_id,
)
)
async def execute(
self,
@@ -268,7 +305,7 @@ class ExecutionStream:
)
# Create execution-scoped memory
memory = self._state_manager.create_memory(
self._state_manager.create_memory(
execution_id=execution_id,
stream_id=self.stream_id,
isolation=ctx.isolation_level,
@@ -297,8 +334,8 @@ class ExecutionStream:
session_state=ctx.session_state,
)
# Store result
self._execution_results[execution_id] = result
# Store result with retention
self._record_execution_result(execution_id, result)
# Update context
ctx.completed_at = datetime.now()
@@ -333,10 +370,13 @@ class ExecutionStream:
ctx.status = "failed"
logger.error(f"Execution {execution_id} failed: {e}")
# Store error result
self._execution_results[execution_id] = ExecutionResult(
success=False,
error=str(e),
# Store error result with retention
self._record_execution_result(
execution_id,
ExecutionResult(
success=False,
error=str(e),
),
)
# Emit failure event
@@ -356,6 +396,12 @@ class ExecutionStream:
if execution_id in self._completion_events:
self._completion_events[execution_id].set()
# Remove in-flight bookkeeping
async with self._lock:
self._active_executions.pop(execution_id, None)
self._completion_events.pop(execution_id, None)
self._execution_tasks.pop(execution_id, None)
def _create_modified_graph(self) -> "GraphSpec":
"""Create a graph with the entry point overridden."""
# Use the existing graph but override entry_node
@@ -378,6 +424,7 @@ class ExecutionStream:
default_model=self.graph.default_model,
max_tokens=self.graph.max_tokens,
max_steps=self.graph.max_steps,
cleanup_llm_model=self.graph.cleanup_llm_model,
)
async def wait_for_completion(
@@ -398,6 +445,7 @@ class ExecutionStream:
event = self._completion_events.get(execution_id)
if event is None:
# Execution not found or already cleaned up
self._prune_execution_results()
return self._execution_results.get(execution_id)
try:
@@ -406,13 +454,15 @@ class ExecutionStream:
else:
await event.wait()
self._prune_execution_results()
return self._execution_results.get(execution_id)
except asyncio.TimeoutError:
except TimeoutError:
return None
def get_result(self, execution_id: str) -> ExecutionResult | None:
"""Get result of a completed execution."""
self._prune_execution_results()
return self._execution_results.get(execution_id)
def get_context(self, execution_id: str) -> ExecutionContext | None:
@@ -443,10 +493,7 @@ class ExecutionStream:
def get_active_count(self) -> int:
"""Get count of active executions."""
return len([
ctx for ctx in self._active_executions.values()
if ctx.status == "running"
])
return len([ctx for ctx in self._active_executions.values() if ctx.status == "running"])
def get_stats(self) -> dict:
"""Get stream statistics."""
@@ -454,6 +501,10 @@ class ExecutionStream:
for ctx in self._active_executions.values():
statuses[ctx.status] = statuses.get(ctx.status, 0) + 1
# Calculate available slots from running count instead of accessing private _value
running_count = statuses.get("running", 0)
available_slots = self.entry_spec.max_concurrent - running_count
return {
"stream_id": self.stream_id,
"entry_point": self.entry_spec.id,
@@ -462,5 +513,5 @@ class ExecutionStream:
"completed_executions": len(self._execution_results),
"status_counts": statuses,
"max_concurrent": self.entry_spec.max_concurrent,
"available_slots": self._semaphore._value,
"available_slots": available_slots,
}
+18 -13
View File
@@ -9,7 +9,7 @@ import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from framework.schemas.decision import Decision, Outcome
@@ -23,6 +23,7 @@ logger = logging.getLogger(__name__)
@dataclass
class CriterionStatus:
"""Status of a success criterion."""
criterion_id: str
description: str
met: bool
@@ -34,6 +35,7 @@ class CriterionStatus:
@dataclass
class ConstraintCheck:
"""Result of a constraint check."""
constraint_id: str
description: str
violated: bool
@@ -46,6 +48,7 @@ class ConstraintCheck:
@dataclass
class DecisionRecord:
"""Record of a decision for aggregation."""
stream_id: str
execution_id: str
decision: Decision
@@ -284,10 +287,11 @@ class OutcomeAggregator:
"successful_outcomes": self._successful_outcomes,
"failed_outcomes": self._failed_outcomes,
"success_rate": (
self._successful_outcomes / max(1, self._successful_outcomes + self._failed_outcomes)
self._successful_outcomes
/ max(1, self._successful_outcomes + self._failed_outcomes)
),
"streams_active": len(set(d.stream_id for d in self._decisions)),
"executions_total": len(set((d.stream_id, d.execution_id) for d in self._decisions)),
"streams_active": len({d.stream_id for d in self._decisions}),
"executions_total": len({(d.stream_id, d.execution_id) for d in self._decisions}),
}
# Determine recommendation
@@ -296,7 +300,7 @@ class OutcomeAggregator:
# Publish progress event
if self._event_bus:
# Get any stream ID for the event
stream_ids = set(d.stream_id for d in self._decisions)
stream_ids = {d.stream_id for d in self._decisions}
if stream_ids:
await self._event_bus.emit_goal_progress(
stream_id=list(stream_ids)[0],
@@ -323,7 +327,8 @@ class OutcomeAggregator:
# Get relevant decisions (those mentioning this criterion or related intents)
relevant_decisions = [
d for d in self._decisions
d
for d in self._decisions
if criterion.id in str(d.decision.active_constraints)
or self._is_related_to_criterion(d.decision, criterion)
]
@@ -341,7 +346,9 @@ class OutcomeAggregator:
# Add evidence
for d in relevant_decisions[:5]: # Limit evidence
if d.outcome:
evidence = f"{d.decision.intent}: {'success' if d.outcome.success else 'failed'}"
evidence = (
f"{d.decision.intent}: {'success' if d.outcome.success else 'failed'}"
)
status.evidence.append(evidence)
# Check if criterion is met based on target
@@ -373,10 +380,7 @@ class OutcomeAggregator:
violations = result["constraint_violations"]
# Check for hard constraint violations
hard_violations = [
v for v in violations
if self._is_hard_constraint(v["constraint_id"])
]
hard_violations = [v for v in violations if self._is_hard_constraint(v["constraint_id"])]
if hard_violations:
return "adjust" # Must address violations
@@ -409,7 +413,8 @@ class OutcomeAggregator:
) -> list[DecisionRecord]:
"""Get all decisions from a specific execution."""
return [
d for d in self._decisions
d
for d in self._decisions
if d.stream_id == stream_id and d.execution_id == execution_id
]
@@ -429,7 +434,7 @@ class OutcomeAggregator:
"failed_outcomes": self._failed_outcomes,
"constraint_violations": len(self._constraint_violations),
"criteria_tracked": len(self._criterion_status),
"streams_seen": len(set(d.stream_id for d in self._decisions)),
"streams_seen": len({d.stream_id for d in self._decisions}),
}
# === RESET OPERATIONS ===
+20 -15
View File
@@ -19,21 +19,24 @@ logger = logging.getLogger(__name__)
class IsolationLevel(str, Enum):
"""State isolation level for concurrent executions."""
ISOLATED = "isolated" # Private state per execution
SHARED = "shared" # Shared state (eventual consistency)
SYNCHRONIZED = "synchronized" # Shared with write locks (strong consistency)
ISOLATED = "isolated" # Private state per execution
SHARED = "shared" # Shared state (eventual consistency)
SYNCHRONIZED = "synchronized" # Shared with write locks (strong consistency)
class StateScope(str, Enum):
"""Scope for state operations."""
EXECUTION = "execution" # Local to a single execution
STREAM = "stream" # Shared within a stream
GLOBAL = "global" # Shared across all streams
EXECUTION = "execution" # Local to a single execution
STREAM = "stream" # Shared within a stream
GLOBAL = "global" # Shared across all streams
@dataclass
class StateChange:
"""Record of a state change."""
key: str
old_value: Any
new_value: Any
@@ -212,14 +215,16 @@ class SharedStateManager:
await self._write_direct(key, value, execution_id, stream_id, scope)
# Record change
self._record_change(StateChange(
key=key,
old_value=old_value,
new_value=value,
scope=scope,
execution_id=execution_id,
stream_id=stream_id,
))
self._record_change(
StateChange(
key=key,
old_value=old_value,
new_value=value,
scope=scope,
execution_id=execution_id,
stream_id=stream_id,
)
)
async def _write_direct(
self,
@@ -278,7 +283,7 @@ class SharedStateManager:
# Trim history if too long
if len(self._change_history) > self._max_history:
self._change_history = self._change_history[-self._max_history:]
self._change_history = self._change_history[-self._max_history :]
# === BULK OPERATIONS ===
+29 -19
View File
@@ -10,9 +10,9 @@ import asyncio
import logging
import uuid
from datetime import datetime
from typing import Any, TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from framework.schemas.decision import Decision, Option, Outcome, DecisionType
from framework.schemas.decision import Decision, DecisionType, Option, Outcome
from framework.schemas.run import Run, RunStatus
from framework.storage.concurrent import ConcurrentStorage
@@ -117,7 +117,8 @@ class StreamRuntime:
Returns:
The run ID
"""
run_id = f"run_{self.stream_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_id = f"run_{self.stream_id}_{timestamp}_{uuid.uuid4().hex[:8]}"
run = Run(
id=run_id,
@@ -130,7 +131,9 @@ class StreamRuntime:
self._run_locks[execution_id] = asyncio.Lock()
self._current_nodes[execution_id] = "unknown"
logger.debug(f"Started run {run_id} for execution {execution_id} in stream {self.stream_id}")
logger.debug(
f"Started run {run_id} for execution {execution_id} in stream {self.stream_id}"
)
return run_id
def end_run(
@@ -224,15 +227,17 @@ class StreamRuntime:
# Build Option objects
option_objects = []
for opt in options:
option_objects.append(Option(
id=opt["id"],
description=opt.get("description", ""),
action_type=opt.get("action_type", "unknown"),
action_params=opt.get("action_params", {}),
pros=opt.get("pros", []),
cons=opt.get("cons", []),
confidence=opt.get("confidence", 0.5),
))
option_objects.append(
Option(
id=opt["id"],
description=opt.get("description", ""),
action_type=opt.get("action_type", "unknown"),
action_params=opt.get("action_params", {}),
pros=opt.get("pros", []),
cons=opt.get("cons", []),
confidence=opt.get("confidence", 0.5),
)
)
# Create decision
decision_id = f"dec_{len(run.decisions)}"
@@ -341,7 +346,10 @@ class StreamRuntime:
"""
run = self._runs.get(execution_id)
if run is None:
logger.warning(f"report_problem called but no run for execution {execution_id}: [{severity}] {description}")
logger.warning(
f"report_problem called but no run for execution {execution_id}: "
f"[{severity}] {description}"
)
return ""
return run.add_problem(
@@ -377,11 +385,13 @@ class StreamRuntime:
return self.decide(
execution_id=execution_id,
intent=intent,
options=[{
"id": "action",
"description": action,
"action_type": "execute",
}],
options=[
{
"id": "action",
"description": action,
"action_type": "execute",
}
],
chosen="action",
reasoning=reasoning,
node_id=node_id,
@@ -11,24 +11,24 @@ Tests:
"""
import asyncio
import pytest
import tempfile
from pathlib import Path
from framework.graph import Goal
from framework.graph.goal import SuccessCriterion, Constraint
from framework.graph.edge import GraphSpec, EdgeSpec, EdgeCondition, AsyncEntryPointSpec
from framework.graph.node import NodeSpec
from framework.runtime.agent_runtime import AgentRuntime, AgentRuntimeConfig, create_agent_runtime
from framework.runtime.execution_stream import EntryPointSpec
from framework.runtime.shared_state import SharedStateManager, IsolationLevel
from framework.runtime.event_bus import EventBus, EventType, AgentEvent
from framework.runtime.outcome_aggregator import OutcomeAggregator
from framework.runtime.stream_runtime import StreamRuntime
import pytest
from framework.graph import Goal
from framework.graph.edge import AsyncEntryPointSpec, EdgeCondition, EdgeSpec, GraphSpec
from framework.graph.goal import Constraint, SuccessCriterion
from framework.graph.node import NodeSpec
from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime
from framework.runtime.event_bus import AgentEvent, EventBus, EventType
from framework.runtime.execution_stream import EntryPointSpec
from framework.runtime.outcome_aggregator import OutcomeAggregator
from framework.runtime.shared_state import IsolationLevel, SharedStateManager
# === Test Fixtures ===
@pytest.fixture
def sample_goal():
"""Create a sample goal for testing."""
@@ -141,6 +141,7 @@ def temp_storage():
# === SharedStateManager Tests ===
class TestSharedStateManager:
"""Tests for SharedStateManager."""
@@ -175,8 +176,8 @@ class TestSharedStateManager:
"""Test shared state is visible across executions."""
manager = SharedStateManager()
mem1 = manager.create_memory("exec-1", "stream-1", IsolationLevel.SHARED)
mem2 = manager.create_memory("exec-2", "stream-1", IsolationLevel.SHARED)
manager.create_memory("exec-1", "stream-1", IsolationLevel.SHARED)
manager.create_memory("exec-2", "stream-1", IsolationLevel.SHARED)
# Write to global scope
await manager.write(
@@ -209,6 +210,7 @@ class TestSharedStateManager:
# === EventBus Tests ===
class TestEventBus:
"""Tests for EventBus pub/sub."""
@@ -226,12 +228,14 @@ class TestEventBus:
handler=handler,
)
await bus.publish(AgentEvent(
type=EventType.EXECUTION_STARTED,
stream_id="webhook",
execution_id="exec-1",
data={"test": "data"},
))
await bus.publish(
AgentEvent(
type=EventType.EXECUTION_STARTED,
stream_id="webhook",
execution_id="exec-1",
data={"test": "data"},
)
)
# Allow handler to run
await asyncio.sleep(0.1)
@@ -256,16 +260,20 @@ class TestEventBus:
)
# Publish to webhook stream (should be received)
await bus.publish(AgentEvent(
type=EventType.EXECUTION_STARTED,
stream_id="webhook",
))
await bus.publish(
AgentEvent(
type=EventType.EXECUTION_STARTED,
stream_id="webhook",
)
)
# Publish to api stream (should NOT be received)
await bus.publish(AgentEvent(
type=EventType.EXECUTION_STARTED,
stream_id="api",
))
await bus.publish(
AgentEvent(
type=EventType.EXECUTION_STARTED,
stream_id="api",
)
)
await asyncio.sleep(0.1)
@@ -308,11 +316,13 @@ class TestEventBus:
# Publish the event
await asyncio.sleep(0.1)
await bus.publish(AgentEvent(
type=EventType.EXECUTION_COMPLETED,
stream_id="webhook",
execution_id="exec-1",
))
await bus.publish(
AgentEvent(
type=EventType.EXECUTION_COMPLETED,
stream_id="webhook",
execution_id="exec-1",
)
)
event = await wait_task
@@ -322,6 +332,7 @@ class TestEventBus:
# === OutcomeAggregator Tests ===
class TestOutcomeAggregator:
"""Tests for OutcomeAggregator."""
@@ -376,6 +387,7 @@ class TestOutcomeAggregator:
# === AgentRuntime Tests ===
class TestAgentRuntime:
"""Tests for AgentRuntime orchestration."""
@@ -491,6 +503,7 @@ class TestAgentRuntime:
# === GraphSpec Validation Tests ===
class TestGraphSpecValidation:
"""Tests for GraphSpec with async_entry_points."""
@@ -595,6 +608,7 @@ class TestGraphSpecValidation:
# === Integration Tests ===
class TestCreateAgentRuntime:
"""Tests for the create_agent_runtime factory."""
+2 -2
View File
@@ -1,7 +1,7 @@
"""Schema definitions for runtime data."""
from framework.schemas.decision import Decision, Option, Outcome, DecisionEvaluation
from framework.schemas.run import Run, RunSummary, Problem
from framework.schemas.decision import Decision, DecisionEvaluation, Option, Outcome
from framework.schemas.run import Problem, Run, RunSummary
__all__ = [
"Decision",
+18 -13
View File
@@ -10,22 +10,23 @@ This is MORE important than actions because:
"""
from datetime import datetime
from typing import Any
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field, computed_field
class DecisionType(str, Enum):
"""Types of decisions an agent can make."""
TOOL_SELECTION = "tool_selection" # Which tool to use
TOOL_SELECTION = "tool_selection" # Which tool to use
PARAMETER_CHOICE = "parameter_choice" # What parameters to pass
PATH_CHOICE = "path_choice" # Which branch to take
OUTPUT_FORMAT = "output_format" # How to format output
RETRY_STRATEGY = "retry_strategy" # How to handle failure
DELEGATION = "delegation" # Whether to delegate to another node
TERMINATION = "termination" # Whether to stop or continue
CUSTOM = "custom" # User-defined decision type
PATH_CHOICE = "path_choice" # Which branch to take
OUTPUT_FORMAT = "output_format" # How to format output
RETRY_STRATEGY = "retry_strategy" # How to handle failure
DELEGATION = "delegation" # Whether to delegate to another node
TERMINATION = "termination" # Whether to stop or continue
CUSTOM = "custom" # User-defined decision type
class Option(BaseModel):
@@ -35,9 +36,10 @@ class Option(BaseModel):
Capturing options is crucial - it shows what the agent considered
and enables us to evaluate whether the right choice was made.
"""
id: str
description: str # Human-readable: "Call search API"
action_type: str # "tool_call", "generate", "delegate"
description: str # Human-readable: "Call search API"
action_type: str # "tool_call", "generate", "delegate"
action_params: dict[str, Any] = Field(default_factory=dict)
# Why might this be good or bad?
@@ -57,9 +59,10 @@ class Outcome(BaseModel):
This is filled in AFTER the action completes, allowing us to
correlate decisions with their results.
"""
success: bool
result: Any = None # The actual output
error: str | None = None # Error message if failed
result: Any = None # The actual output
error: str | None = None # Error message if failed
# Side effects
state_changes: dict[str, Any] = Field(default_factory=dict)
@@ -67,7 +70,7 @@ class Outcome(BaseModel):
latency_ms: int = 0
# Natural language summary (crucial for Builder)
summary: str = "" # "Found 3 contacts matching query"
summary: str = "" # "Found 3 contacts matching query"
timestamp: datetime = Field(default_factory=datetime.now)
@@ -81,6 +84,7 @@ class DecisionEvaluation(BaseModel):
This is computed AFTER the run completes, allowing us to
judge decisions in light of their eventual outcomes.
"""
# Did it move toward the goal?
goal_aligned: bool = True
alignment_score: float = Field(default=1.0, ge=0.0, le=1.0)
@@ -109,6 +113,7 @@ class Decision(BaseModel):
Every significant choice the agent makes is captured here.
This is the core data structure for understanding and improving agents.
"""
id: str
timestamp: datetime = Field(default_factory=datetime.now)
node_id: str
+7 -2
View File
@@ -6,8 +6,8 @@ summaries and metrics that Builder needs to understand what happened.
"""
from datetime import datetime
from typing import Any
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field, computed_field
@@ -16,10 +16,11 @@ from framework.schemas.decision import Decision, Outcome
class RunStatus(str, Enum):
"""Status of a run."""
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
STUCK = "stuck" # Making no progress
STUCK = "stuck" # Making no progress
CANCELLED = "cancelled"
@@ -29,6 +30,7 @@ class Problem(BaseModel):
Problems are surfaced explicitly so Builder can focus on what needs fixing.
"""
id: str
severity: str = Field(description="critical, warning, or minor")
description: str
@@ -42,6 +44,7 @@ class Problem(BaseModel):
class RunMetrics(BaseModel):
"""Quantitative metrics about a run."""
total_decisions: int = 0
successful_decisions: int = 0
failed_decisions: int = 0
@@ -68,6 +71,7 @@ class Run(BaseModel):
Contains all decisions, problems, and metrics from a single run.
"""
id: str
goal_id: str
started_at: datetime = Field(default_factory=datetime.now)
@@ -191,6 +195,7 @@ class RunSummary(BaseModel):
This is what I (Builder) want to see first when analyzing runs.
"""
run_id: str
goal_id: str
status: RunStatus
+40 -3
View File
@@ -8,7 +8,7 @@ Uses Pydantic's built-in serialization.
import json
from pathlib import Path
from framework.schemas.run import Run, RunSummary, RunStatus
from framework.schemas.run import Run, RunStatus, RunSummary
class FileStorage:
@@ -46,6 +46,40 @@ class FileStorage:
for d in dirs:
d.mkdir(parents=True, exist_ok=True)
def _validate_key(self, key: str) -> None:
"""
Validate key to prevent path traversal attacks.
Args:
key: The key to validate
Raises:
ValueError: If key contains path traversal or dangerous patterns
"""
if not key or key.strip() == "":
raise ValueError("Key cannot be empty")
# Block path separators
if "/" in key or "\\" in key:
raise ValueError(f"Invalid key format: path separators not allowed in '{key}'")
# Block parent directory references
if ".." in key or key.startswith("."):
raise ValueError(f"Invalid key format: path traversal detected in '{key}'")
# Block absolute paths
if key.startswith("/") or (len(key) > 1 and key[1] == ":"):
raise ValueError(f"Invalid key format: absolute paths not allowed in '{key}'")
# Block null bytes (Unix path injection)
if "\x00" in key:
raise ValueError("Invalid key format: null bytes not allowed")
# Block other dangerous special characters
dangerous_chars = {"<", ">", "|", "&", "$", "`", "'", '"'}
if any(char in key for char in dangerous_chars):
raise ValueError(f"Invalid key format: contains dangerous characters in '{key}'")
# === RUN OPERATIONS ===
def save_run(self, run: Run) -> None:
@@ -140,6 +174,7 @@ class FileStorage:
def _get_index(self, index_type: str, key: str) -> list[str]:
"""Get values from an index."""
self._validate_key(key) # Prevent path traversal
index_path = self.base_path / "indexes" / index_type / f"{key}.json"
if not index_path.exists():
return []
@@ -148,8 +183,9 @@ class FileStorage:
def _add_to_index(self, index_type: str, key: str, value: str) -> None:
"""Add a value to an index."""
self._validate_key(key) # Prevent path traversal
index_path = self.base_path / "indexes" / index_type / f"{key}.json"
values = self._get_index(index_type, key)
values = self._get_index(index_type, key) # Already validated in _get_index
if value not in values:
values.append(value)
with open(index_path, "w") as f:
@@ -157,8 +193,9 @@ class FileStorage:
def _remove_from_index(self, index_type: str, key: str, value: str) -> None:
"""Remove a value from an index."""
self._validate_key(key) # Prevent path traversal
index_path = self.base_path / "indexes" / index_type / f"{key}.json"
values = self._get_index(index_type, key)
values = self._get_index(index_type, key) # Already validated in _get_index
if value in values:
values.remove(value)
with open(index_path, "w") as f:

Some files were not shown because too many files have changed in this diff Show More