Merge remote-tracking branch 'origin/main' into conductorchicago
Resolved conflict in tools/pyproject.toml by keeping the expanded format with sql dependency from main. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"hooks": {
|
||||
"PostToolUse": [
|
||||
{
|
||||
"matcher": "Edit|Write|NotebookEdit",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "ruff check --fix \"$CLAUDE_FILE_PATH\" 2>/dev/null; ruff format \"$CLAUDE_FILE_PATH\" 2>/dev/null; true"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
+10
-3
@@ -108,8 +108,10 @@ async def _interactive_shell(verbose=False):
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
topic = await asyncio.get_event_loop().run_in_executor(None, input, "Topic> ")
|
||||
if topic.lower() in ['quit', 'exit', 'q']:
|
||||
topic = await asyncio.get_event_loop().run_in_executor(
|
||||
None, input, "Topic> "
|
||||
)
|
||||
if topic.lower() in ["quit", "exit", "q"]:
|
||||
click.echo("Goodbye!")
|
||||
break
|
||||
|
||||
@@ -130,7 +132,11 @@ async def _interactive_shell(verbose=False):
|
||||
click.echo(f"\nReport saved to: {output['file_path']}\n")
|
||||
if "final_report" in output:
|
||||
click.echo("\n--- Report Preview ---\n")
|
||||
preview = output["final_report"][:500] + "..." if len(output.get("final_report", "")) > 500 else output.get("final_report", "")
|
||||
preview = (
|
||||
output["final_report"][:500] + "..."
|
||||
if len(output.get("final_report", "")) > 500
|
||||
else output.get("final_report", "")
|
||||
)
|
||||
click.echo(preview)
|
||||
click.echo("\n")
|
||||
else:
|
||||
@@ -142,6 +148,7 @@ async def _interactive_shell(verbose=False):
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {e}", err=True)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
await agent.stop()
|
||||
|
||||
+39
-24
@@ -1,4 +1,5 @@
|
||||
"""Agent graph construction for Online Research Agent."""
|
||||
|
||||
from framework.graph import EdgeSpec, EdgeCondition, Goal, SuccessCriterion, Constraint
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.executor import ExecutionResult
|
||||
@@ -8,6 +9,16 @@ from framework.llm import LiteLLMProvider
|
||||
from framework.runner.tool_registry import ToolRegistry
|
||||
|
||||
from .config import default_config, metadata
|
||||
from .nodes import (
|
||||
parse_query_node,
|
||||
search_sources_node,
|
||||
fetch_content_node,
|
||||
evaluate_sources_node,
|
||||
synthesize_findings_node,
|
||||
write_report_node,
|
||||
quality_check_node,
|
||||
save_report_node,
|
||||
)
|
||||
|
||||
# Goal definition
|
||||
goal = Goal(
|
||||
@@ -78,17 +89,6 @@ goal = Goal(
|
||||
),
|
||||
],
|
||||
)
|
||||
# Import nodes
|
||||
from .nodes import (
|
||||
parse_query_node,
|
||||
search_sources_node,
|
||||
fetch_content_node,
|
||||
evaluate_sources_node,
|
||||
synthesize_findings_node,
|
||||
write_report_node,
|
||||
quality_check_node,
|
||||
save_report_node,
|
||||
)
|
||||
|
||||
# Node list
|
||||
nodes = [
|
||||
@@ -195,13 +195,15 @@ class OnlineResearchAgent:
|
||||
trigger_type = "manual"
|
||||
name = ep_id.replace("-", " ").title()
|
||||
|
||||
specs.append(EntryPointSpec(
|
||||
id=ep_id,
|
||||
name=name,
|
||||
entry_node=node_id,
|
||||
trigger_type=trigger_type,
|
||||
isolation_level="shared",
|
||||
))
|
||||
specs.append(
|
||||
EntryPointSpec(
|
||||
id=ep_id,
|
||||
name=name,
|
||||
entry_node=node_id,
|
||||
trigger_type=trigger_type,
|
||||
isolation_level="shared",
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
def _create_runtime(self, mock_mode=False) -> AgentRuntime:
|
||||
@@ -226,7 +228,10 @@ class OnlineResearchAgent:
|
||||
for server_name, server_config in mcp_servers.items():
|
||||
server_config["name"] = server_name
|
||||
# Resolve relative cwd paths
|
||||
if "cwd" in server_config and not Path(server_config["cwd"]).is_absolute():
|
||||
if (
|
||||
"cwd" in server_config
|
||||
and not Path(server_config["cwd"]).is_absolute()
|
||||
):
|
||||
server_config["cwd"] = str(agent_dir / server_config["cwd"])
|
||||
tool_registry.register_mcp_server(server_config)
|
||||
|
||||
@@ -298,7 +303,9 @@ class OnlineResearchAgent:
|
||||
"""
|
||||
if self._runtime is None or not self._runtime.is_running:
|
||||
raise RuntimeError("Agent runtime not started. Call start() first.")
|
||||
return await self._runtime.trigger(entry_point, input_data, correlation_id, session_state=session_state)
|
||||
return await self._runtime.trigger(
|
||||
entry_point, input_data, correlation_id, session_state=session_state
|
||||
)
|
||||
|
||||
async def trigger_and_wait(
|
||||
self,
|
||||
@@ -321,9 +328,13 @@ class OnlineResearchAgent:
|
||||
"""
|
||||
if self._runtime is None or not self._runtime.is_running:
|
||||
raise RuntimeError("Agent runtime not started. Call start() first.")
|
||||
return await self._runtime.trigger_and_wait(entry_point, input_data, timeout, session_state=session_state)
|
||||
return await self._runtime.trigger_and_wait(
|
||||
entry_point, input_data, timeout, session_state=session_state
|
||||
)
|
||||
|
||||
async def run(self, context: dict, mock_mode=False, session_state=None) -> ExecutionResult:
|
||||
async def run(
|
||||
self, context: dict, mock_mode=False, session_state=None
|
||||
) -> ExecutionResult:
|
||||
"""
|
||||
Run the agent (convenience method for simple single execution).
|
||||
|
||||
@@ -342,7 +353,9 @@ class OnlineResearchAgent:
|
||||
else:
|
||||
entry_point = "start"
|
||||
|
||||
result = await self.trigger_and_wait(entry_point, context, session_state=session_state)
|
||||
result = await self.trigger_and_wait(
|
||||
entry_point, context, session_state=session_state
|
||||
)
|
||||
return result or ExecutionResult(success=False, error="Execution timeout")
|
||||
finally:
|
||||
await self.stop()
|
||||
@@ -404,7 +417,9 @@ class OnlineResearchAgent:
|
||||
# Validate entry points
|
||||
for ep_id, node_id in self.entry_points.items():
|
||||
if node_id not in node_ids:
|
||||
errors.append(f"Entry point '{ep_id}' references unknown node '{node_id}'")
|
||||
errors.append(
|
||||
f"Entry point '{ep_id}' references unknown node '{node_id}'"
|
||||
)
|
||||
|
||||
return {
|
||||
"valid": len(errors) == 0,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Runtime configuration."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@@ -13,6 +14,7 @@ class RuntimeConfig:
|
||||
|
||||
default_config = RuntimeConfig()
|
||||
|
||||
|
||||
# Agent metadata
|
||||
@dataclass
|
||||
class AgentMetadata:
|
||||
|
||||
+103
-20
@@ -1,4 +1,5 @@
|
||||
"""Node definitions for Online Research Agent."""
|
||||
|
||||
from framework.graph import NodeSpec
|
||||
|
||||
# Node 1: Parse Query
|
||||
@@ -10,9 +11,21 @@ parse_query_node = NodeSpec(
|
||||
input_keys=["topic"],
|
||||
output_keys=["search_queries", "research_focus", "key_aspects"],
|
||||
output_schema={
|
||||
"research_focus": {"type": "string", "required": True, "description": "Brief statement of what we're researching"},
|
||||
"key_aspects": {"type": "array", "required": True, "description": "List of 3-5 key aspects to investigate"},
|
||||
"search_queries": {"type": "array", "required": True, "description": "List of 3-5 search queries"},
|
||||
"research_focus": {
|
||||
"type": "string",
|
||||
"required": True,
|
||||
"description": "Brief statement of what we're researching",
|
||||
},
|
||||
"key_aspects": {
|
||||
"type": "array",
|
||||
"required": True,
|
||||
"description": "List of 3-5 key aspects to investigate",
|
||||
},
|
||||
"search_queries": {
|
||||
"type": "array",
|
||||
"required": True,
|
||||
"description": "List of 3-5 search queries",
|
||||
},
|
||||
},
|
||||
system_prompt="""\
|
||||
You are a research query strategist. Given a research topic, analyze it and generate search queries.
|
||||
@@ -50,8 +63,16 @@ search_sources_node = NodeSpec(
|
||||
input_keys=["search_queries", "research_focus"],
|
||||
output_keys=["source_urls", "search_results_summary"],
|
||||
output_schema={
|
||||
"source_urls": {"type": "array", "required": True, "description": "List of source URLs found"},
|
||||
"search_results_summary": {"type": "string", "required": True, "description": "Brief summary of what was found"},
|
||||
"source_urls": {
|
||||
"type": "array",
|
||||
"required": True,
|
||||
"description": "List of source URLs found",
|
||||
},
|
||||
"search_results_summary": {
|
||||
"type": "string",
|
||||
"required": True,
|
||||
"description": "Brief summary of what was found",
|
||||
},
|
||||
},
|
||||
system_prompt="""\
|
||||
You are a research assistant executing web searches. Use the web_search tool to find sources.
|
||||
@@ -80,8 +101,16 @@ fetch_content_node = NodeSpec(
|
||||
input_keys=["source_urls", "research_focus"],
|
||||
output_keys=["fetched_sources", "fetch_errors"],
|
||||
output_schema={
|
||||
"fetched_sources": {"type": "array", "required": True, "description": "List of fetched source objects with url, title, content"},
|
||||
"fetch_errors": {"type": "array", "required": True, "description": "List of URLs that failed to fetch"},
|
||||
"fetched_sources": {
|
||||
"type": "array",
|
||||
"required": True,
|
||||
"description": "List of fetched source objects with url, title, content",
|
||||
},
|
||||
"fetch_errors": {
|
||||
"type": "array",
|
||||
"required": True,
|
||||
"description": "List of URLs that failed to fetch",
|
||||
},
|
||||
},
|
||||
system_prompt="""\
|
||||
You are a content fetcher. Use web_scrape tool to retrieve content from URLs.
|
||||
@@ -113,8 +142,16 @@ evaluate_sources_node = NodeSpec(
|
||||
input_keys=["fetched_sources", "research_focus", "key_aspects"],
|
||||
output_keys=["ranked_sources", "source_analysis"],
|
||||
output_schema={
|
||||
"ranked_sources": {"type": "array", "required": True, "description": "List of ranked sources with scores"},
|
||||
"source_analysis": {"type": "string", "required": True, "description": "Overview of source quality and coverage"},
|
||||
"ranked_sources": {
|
||||
"type": "array",
|
||||
"required": True,
|
||||
"description": "List of ranked sources with scores",
|
||||
},
|
||||
"source_analysis": {
|
||||
"type": "string",
|
||||
"required": True,
|
||||
"description": "Overview of source quality and coverage",
|
||||
},
|
||||
},
|
||||
system_prompt="""\
|
||||
You are a source evaluator. Assess each source for quality and relevance.
|
||||
@@ -153,9 +190,21 @@ synthesize_findings_node = NodeSpec(
|
||||
input_keys=["ranked_sources", "research_focus", "key_aspects"],
|
||||
output_keys=["key_findings", "themes", "source_citations"],
|
||||
output_schema={
|
||||
"key_findings": {"type": "array", "required": True, "description": "List of key findings with sources and confidence"},
|
||||
"themes": {"type": "array", "required": True, "description": "List of themes with descriptions and supporting sources"},
|
||||
"source_citations": {"type": "object", "required": True, "description": "Map of facts to supporting URLs"},
|
||||
"key_findings": {
|
||||
"type": "array",
|
||||
"required": True,
|
||||
"description": "List of key findings with sources and confidence",
|
||||
},
|
||||
"themes": {
|
||||
"type": "array",
|
||||
"required": True,
|
||||
"description": "List of themes with descriptions and supporting sources",
|
||||
},
|
||||
"source_citations": {
|
||||
"type": "object",
|
||||
"required": True,
|
||||
"description": "Map of facts to supporting URLs",
|
||||
},
|
||||
},
|
||||
system_prompt="""\
|
||||
You are a research synthesizer. Analyze multiple sources to extract insights.
|
||||
@@ -192,11 +241,25 @@ write_report_node = NodeSpec(
|
||||
name="Write Report",
|
||||
description="Generate a narrative report with proper citations",
|
||||
node_type="llm_generate",
|
||||
input_keys=["key_findings", "themes", "source_citations", "research_focus", "ranked_sources"],
|
||||
input_keys=[
|
||||
"key_findings",
|
||||
"themes",
|
||||
"source_citations",
|
||||
"research_focus",
|
||||
"ranked_sources",
|
||||
],
|
||||
output_keys=["report_content", "references"],
|
||||
output_schema={
|
||||
"report_content": {"type": "string", "required": True, "description": "Full markdown report text with citations"},
|
||||
"references": {"type": "array", "required": True, "description": "List of reference objects with number, url, title"},
|
||||
"report_content": {
|
||||
"type": "string",
|
||||
"required": True,
|
||||
"description": "Full markdown report text with citations",
|
||||
},
|
||||
"references": {
|
||||
"type": "array",
|
||||
"required": True,
|
||||
"description": "List of reference objects with number, url, title",
|
||||
},
|
||||
},
|
||||
system_prompt="""\
|
||||
You are a research report writer. Create a well-structured narrative report.
|
||||
@@ -239,9 +302,21 @@ quality_check_node = NodeSpec(
|
||||
input_keys=["report_content", "references", "source_citations"],
|
||||
output_keys=["quality_score", "issues", "final_report"],
|
||||
output_schema={
|
||||
"quality_score": {"type": "number", "required": True, "description": "Quality score 0-1"},
|
||||
"issues": {"type": "array", "required": True, "description": "List of issues found and fixed"},
|
||||
"final_report": {"type": "string", "required": True, "description": "Corrected full report"},
|
||||
"quality_score": {
|
||||
"type": "number",
|
||||
"required": True,
|
||||
"description": "Quality score 0-1",
|
||||
},
|
||||
"issues": {
|
||||
"type": "array",
|
||||
"required": True,
|
||||
"description": "List of issues found and fixed",
|
||||
},
|
||||
"final_report": {
|
||||
"type": "string",
|
||||
"required": True,
|
||||
"description": "Corrected full report",
|
||||
},
|
||||
},
|
||||
system_prompt="""\
|
||||
You are a quality assurance reviewer. Check the research report for issues.
|
||||
@@ -278,8 +353,16 @@ save_report_node = NodeSpec(
|
||||
input_keys=["final_report", "references", "research_focus"],
|
||||
output_keys=["file_path", "save_status"],
|
||||
output_schema={
|
||||
"file_path": {"type": "string", "required": True, "description": "Path where report was saved"},
|
||||
"save_status": {"type": "string", "required": True, "description": "Status of save operation"},
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"required": True,
|
||||
"description": "Path where report was saved",
|
||||
},
|
||||
"save_status": {
|
||||
"type": "string",
|
||||
"required": True,
|
||||
"description": "Status of save operation",
|
||||
},
|
||||
},
|
||||
system_prompt="""\
|
||||
You are a file manager. Save the research report to disk.
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
# Triage Issue Skill
|
||||
|
||||
Analyze a GitHub issue, verify claims against the codebase, and close invalid issues with a technical response.
|
||||
|
||||
## Trigger
|
||||
|
||||
User provides a GitHub issue URL or number, e.g.:
|
||||
- `/triage-issue 1970`
|
||||
- `/triage-issue https://github.com/adenhq/hive/issues/1970`
|
||||
|
||||
## Workflow
|
||||
|
||||
### Step 1: Fetch Issue Details
|
||||
|
||||
```bash
|
||||
gh issue view <number> --repo adenhq/hive --json title,body,state,labels,author
|
||||
```
|
||||
|
||||
Extract:
|
||||
- Title
|
||||
- Body (the claim/bug report)
|
||||
- Current state
|
||||
- Labels
|
||||
- Author
|
||||
|
||||
If issue is already closed, inform user and stop.
|
||||
|
||||
### Step 2: Analyze the Claim
|
||||
|
||||
Read the issue body and identify:
|
||||
1. **The core claim** - What is the user asserting?
|
||||
2. **Technical specifics** - File paths, function names, code snippets mentioned
|
||||
3. **Expected behavior** - What do they think should happen?
|
||||
4. **Severity claimed** - Security issue? Bug? Feature request?
|
||||
|
||||
### Step 3: Investigate the Codebase
|
||||
|
||||
For each technical claim:
|
||||
1. Find the referenced code using Grep/Glob/Read
|
||||
2. Understand the actual implementation
|
||||
3. Check if the claim accurately describes the behavior
|
||||
4. Look for related tests, documentation, or design decisions
|
||||
|
||||
### Step 4: Evaluate Validity
|
||||
|
||||
Categorize the issue as one of:
|
||||
|
||||
| Category | Action |
|
||||
|----------|--------|
|
||||
| **Valid Bug** | Do NOT close. Inform user this is a real issue. |
|
||||
| **Valid Feature Request** | Do NOT close. Suggest labeling appropriately. |
|
||||
| **Misunderstanding** | Prepare technical explanation for why behavior is correct. |
|
||||
| **Fundamentally Flawed** | Prepare critique explaining the technical impossibility or design rationale. |
|
||||
| **Duplicate** | Find the original issue and prepare duplicate notice. |
|
||||
| **Incomplete** | Prepare request for more information. |
|
||||
|
||||
### Step 5: Draft Response
|
||||
|
||||
For issues to be closed, draft a response that:
|
||||
|
||||
1. **Acknowledges the concern** - Don't be dismissive
|
||||
2. **Explains the actual behavior** - With code references
|
||||
3. **Provides technical rationale** - Why it works this way
|
||||
4. **References industry standards** - If applicable
|
||||
5. **Offers alternatives** - If there's a better approach for the user
|
||||
|
||||
Use this template:
|
||||
|
||||
```markdown
|
||||
## Analysis
|
||||
|
||||
[Brief summary of what was investigated]
|
||||
|
||||
## Technical Details
|
||||
|
||||
[Explanation with code references]
|
||||
|
||||
## Why This Is Working As Designed
|
||||
|
||||
[Rationale]
|
||||
|
||||
## Recommendation
|
||||
|
||||
[What the user should do instead, if applicable]
|
||||
|
||||
---
|
||||
*This issue was reviewed and closed by the maintainers.*
|
||||
```
|
||||
|
||||
### Step 6: User Review
|
||||
|
||||
Present the draft to the user with:
|
||||
|
||||
```
|
||||
## Issue #<number>: <title>
|
||||
|
||||
**Claim:** <summary of claim>
|
||||
|
||||
**Finding:** <valid/invalid/misunderstanding/etc>
|
||||
|
||||
**Draft Response:**
|
||||
<the markdown response>
|
||||
|
||||
---
|
||||
Do you want me to post this comment and close the issue?
|
||||
```
|
||||
|
||||
Use AskUserQuestion with options:
|
||||
- "Post and close" - Post comment, close issue
|
||||
- "Edit response" - Let user modify the response
|
||||
- "Skip" - Don't take action
|
||||
|
||||
### Step 7: Execute Action
|
||||
|
||||
If user approves:
|
||||
|
||||
```bash
|
||||
# Post comment
|
||||
gh issue comment <number> --repo adenhq/hive --body "<response>"
|
||||
|
||||
# Close issue
|
||||
gh issue close <number> --repo adenhq/hive --reason "not planned"
|
||||
```
|
||||
|
||||
Report success with link to the issue.
|
||||
|
||||
## Important Guidelines
|
||||
|
||||
1. **Never close valid issues** - If there's any merit to the claim, don't close it
|
||||
2. **Be respectful** - The reporter took time to file the issue
|
||||
3. **Be technical** - Provide code references and evidence
|
||||
4. **Be educational** - Help them understand, don't just dismiss
|
||||
5. **Check twice** - Make sure you understand the code before declaring something invalid
|
||||
6. **Consider edge cases** - Maybe their environment reveals a real issue
|
||||
|
||||
## Example Critiques
|
||||
|
||||
### Security Misunderstanding
|
||||
> "The claim that secrets are exposed in plaintext misunderstands the encryption architecture. While `SecretStr` is used for logging protection, actual encryption is provided by Fernet (AES-128-CBC) at the storage layer. The code path is: serialize → encrypt → write. Only encrypted bytes touch disk."
|
||||
|
||||
### Impossible Request
|
||||
> "The requested feature would require [X] which violates [fundamental constraint]. This is not a limitation of our implementation but a fundamental property of [technology/protocol]."
|
||||
|
||||
### Already Handled
|
||||
> "This scenario is already handled by [code reference]. The reporter may be using an older version or misconfigured environment."
|
||||
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"agent-builder": {
|
||||
"command": "python",
|
||||
"args": ["-m", "framework.mcp.agent_builder_server"],
|
||||
"cwd": "core",
|
||||
"env": {
|
||||
"PYTHONPATH": "../tools/src"
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"command": "python",
|
||||
"args": ["mcp_server.py", "--stdio"],
|
||||
"cwd": "tools",
|
||||
"env": {
|
||||
"PYTHONPATH": "src"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/agent-workflow
|
||||
@@ -0,0 +1 @@
|
||||
../../.claude/skills/building-agents-construction
|
||||
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/building-agents-core
|
||||
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/building-agents-patterns
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../.claude/skills/testing-agent
|
||||
@@ -0,0 +1,18 @@
|
||||
This project uses ruff for Python linting and formatting.
|
||||
|
||||
Rules:
|
||||
- Line length: 100 characters
|
||||
- Python target: 3.11+
|
||||
- Use double quotes for strings
|
||||
- Sort imports with isort (ruff I rules): stdlib, third-party, first-party (framework), local
|
||||
- Combine as-imports
|
||||
- Use type hints on all function signatures
|
||||
- Use `from __future__ import annotations` for modern type syntax
|
||||
- Raise exceptions with `from` in except blocks (B904)
|
||||
- No unused imports (F401), no unused variables (F841)
|
||||
- Prefer list/dict/set comprehensions over map/filter (C4)
|
||||
|
||||
Run `make lint` to auto-fix, `make check` to verify without modifying files.
|
||||
Run `make format` to apply ruff formatting.
|
||||
|
||||
The ruff config lives in core/pyproject.toml under [tool.ruff].
|
||||
@@ -11,6 +11,9 @@ indent_size = 2
|
||||
insert_final_newline = true
|
||||
trim_trailing_whitespace = true
|
||||
|
||||
[*.py]
|
||||
indent_size = 4
|
||||
|
||||
[*.md]
|
||||
trim_trailing_whitespace = false
|
||||
|
||||
|
||||
+124
@@ -0,0 +1,124 @@
|
||||
# Normalize line endings for all text files
|
||||
* text=auto
|
||||
|
||||
# Source code
|
||||
*.py text diff=python
|
||||
*.js text
|
||||
*.ts text
|
||||
*.jsx text
|
||||
*.tsx text
|
||||
*.json text
|
||||
*.yaml text
|
||||
*.yml text
|
||||
*.toml text
|
||||
*.ini text
|
||||
*.cfg text
|
||||
|
||||
# Shell scripts (must use LF)
|
||||
*.sh text eol=lf
|
||||
quickstart.sh text eol=lf
|
||||
|
||||
# PowerShell scripts (Windows-friendly)
|
||||
*.ps1 text eol=lf
|
||||
*.psm1 text eol=lf
|
||||
|
||||
# Windows batch files (must use CRLF)
|
||||
*.bat text eol=crlf
|
||||
*.cmd text eol=crlf
|
||||
|
||||
# Documentation
|
||||
*.md text
|
||||
*.txt text
|
||||
*.rst text
|
||||
*.tex text
|
||||
|
||||
# Configuration files
|
||||
.gitignore text
|
||||
.gitattributes text
|
||||
.editorconfig text
|
||||
Dockerfile text
|
||||
docker-compose.yml text
|
||||
requirements*.txt text
|
||||
pyproject.toml text
|
||||
setup.py text
|
||||
setup.cfg text
|
||||
MANIFEST.in text
|
||||
LICENSE text
|
||||
README* text
|
||||
CHANGELOG* text
|
||||
CONTRIBUTING* text
|
||||
CODE_OF_CONDUCT* text
|
||||
|
||||
# Web files
|
||||
*.html text
|
||||
*.css text
|
||||
*.scss text
|
||||
*.sass text
|
||||
|
||||
# Data files
|
||||
*.xml text
|
||||
*.csv text
|
||||
*.sql text
|
||||
|
||||
# Graphics (binary)
|
||||
*.png binary
|
||||
*.jpg binary
|
||||
*.jpeg binary
|
||||
*.gif binary
|
||||
*.ico binary
|
||||
*.svg binary
|
||||
*.eps binary
|
||||
*.bmp binary
|
||||
*.tif binary
|
||||
*.tiff binary
|
||||
|
||||
# Archives (binary)
|
||||
*.zip binary
|
||||
*.tar binary
|
||||
*.gz binary
|
||||
*.bz2 binary
|
||||
*.7z binary
|
||||
*.rar binary
|
||||
|
||||
# Python compiled (binary)
|
||||
*.pyc binary
|
||||
*.pyo binary
|
||||
*.pyd binary
|
||||
*.whl binary
|
||||
*.egg binary
|
||||
|
||||
# System libraries (binary)
|
||||
*.so binary
|
||||
*.dll binary
|
||||
*.dylib binary
|
||||
*.lib binary
|
||||
*.a binary
|
||||
|
||||
# Documents (binary)
|
||||
*.pdf binary
|
||||
*.doc binary
|
||||
*.docx binary
|
||||
*.ppt binary
|
||||
*.pptx binary
|
||||
*.xls binary
|
||||
*.xlsx binary
|
||||
|
||||
# Fonts (binary)
|
||||
*.ttf binary
|
||||
*.otf binary
|
||||
*.woff binary
|
||||
*.woff2 binary
|
||||
*.eot binary
|
||||
|
||||
# Audio/Video (binary)
|
||||
*.mp3 binary
|
||||
*.mp4 binary
|
||||
*.wav binary
|
||||
*.avi binary
|
||||
*.mov binary
|
||||
*.flv binary
|
||||
|
||||
# Database files (binary)
|
||||
*.db binary
|
||||
*.sqlite binary
|
||||
*.sqlite3 binary
|
||||
@@ -8,7 +8,6 @@
|
||||
/hive/ @adenhq/maintainers
|
||||
|
||||
# Infrastructure
|
||||
/docker-compose*.yml @adenhq/maintainers
|
||||
/.github/ @adenhq/maintainers
|
||||
|
||||
# Documentation
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
name: Auto-close duplicate issues
|
||||
description: Auto-closes issues that are duplicates of existing issues
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 */6 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
auto-close-duplicates:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
|
||||
- name: Auto-close duplicate issues
|
||||
run: bun run scripts/auto-close-duplicates.ts
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
|
||||
GITHUB_REPOSITORY_NAME: ${{ github.event.repository.name }}
|
||||
STATSIG_API_KEY: ${{ secrets.STATSIG_API_KEY }}
|
||||
@@ -29,10 +29,15 @@ jobs:
|
||||
pip install -e .
|
||||
pip install -r requirements-dev.txt
|
||||
|
||||
- name: Run ruff
|
||||
- name: Ruff lint
|
||||
run: |
|
||||
cd core
|
||||
ruff check .
|
||||
ruff check core/
|
||||
ruff check tools/
|
||||
|
||||
- name: Ruff format
|
||||
run: |
|
||||
ruff format --check core/
|
||||
ruff format --check tools/
|
||||
|
||||
test:
|
||||
name: Test Python Framework
|
||||
@@ -79,9 +84,31 @@ jobs:
|
||||
- name: Validate exported agents
|
||||
run: |
|
||||
# Check that agent exports have valid structure
|
||||
for agent_dir in exports/*/; do
|
||||
if [ ! -d "exports" ]; then
|
||||
echo "No exports/ directory found, skipping validation"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
shopt -s nullglob
|
||||
agent_dirs=(exports/*/)
|
||||
shopt -u nullglob
|
||||
|
||||
if [ ${#agent_dirs[@]} -eq 0 ]; then
|
||||
echo "No agent directories in exports/, skipping validation"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
validated=0
|
||||
for agent_dir in "${agent_dirs[@]}"; do
|
||||
if [ -f "$agent_dir/agent.json" ]; then
|
||||
echo "Validating $agent_dir"
|
||||
python -c "import json; json.load(open('$agent_dir/agent.json'))"
|
||||
validated=$((validated + 1))
|
||||
fi
|
||||
done
|
||||
|
||||
if [ "$validated" -eq 0 ]; then
|
||||
echo "No agent.json files found in exports/, skipping validation"
|
||||
else
|
||||
echo "Validated $validated agent(s)"
|
||||
fi
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
name: Claude Issue Triage
|
||||
name: Issue Triage
|
||||
|
||||
on:
|
||||
issues:
|
||||
@@ -7,9 +7,11 @@ on:
|
||||
jobs:
|
||||
triage:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
id-token: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -17,52 +19,79 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
- name: Run Claude Issue Triage
|
||||
id: claude-triage
|
||||
- name: Triage and check for duplicates
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
allowed_non_write_users: '*'
|
||||
allowed_non_write_users: "*"
|
||||
prompt: |
|
||||
REPO: ${{ github.repository }}
|
||||
ISSUE NUMBER: ${{ github.event.issue.number }}
|
||||
TITLE: ${{ github.event.issue.title }}
|
||||
BODY: ${{ github.event.issue.body }}
|
||||
AUTHOR: ${{ github.event.issue.user.login }}
|
||||
Analyze this new issue and perform triage tasks.
|
||||
|
||||
Analyze this new issue and perform the following:
|
||||
Issue: #${{ github.event.issue.number }}
|
||||
Repository: ${{ github.repository }}
|
||||
|
||||
1. **Categorize the issue type using ONLY these labels:**
|
||||
- bug: Something isn't working
|
||||
- enhancement: New feature or request
|
||||
- question: Further information is requested
|
||||
- documentation: Improvements or additions to documentation
|
||||
- good first issue: Good for newcomers (if issue is well-defined and small scope)
|
||||
- help wanted: Extra attention is needed (if issue needs community input)
|
||||
- backlog: Tracked for the future, but not currently planned or prioritized
|
||||
## Your Tasks:
|
||||
|
||||
2. **Check for duplicates:**
|
||||
Search for similar existing issues using:
|
||||
`gh issue list --state all --search "<key terms from title/body>"`
|
||||
### 1. Get issue details
|
||||
Use mcp__github__get_issue to get the full details of issue #${{ github.event.issue.number }}
|
||||
|
||||
If a duplicate exists:
|
||||
- Add the "duplicate" label
|
||||
- Comment mentioning the original issue number
|
||||
### 2. Check for duplicates
|
||||
Search for similar existing issues using mcp__github__search_issues with relevant keywords from the issue title and body.
|
||||
|
||||
3. **Check for invalid issues:**
|
||||
If the issue lacks sufficient information, is spam, or doesn't make sense:
|
||||
- Add the "invalid" label
|
||||
- Comment asking for clarification or explaining why it's invalid
|
||||
Criteria for duplicates:
|
||||
- Same bug or error being reported
|
||||
- Same feature request (even if worded differently)
|
||||
- Same question being asked
|
||||
- Issues describing the same root problem
|
||||
|
||||
4. **Apply labels:**
|
||||
Based on your analysis, add appropriate labels using:
|
||||
`gh issue edit ${{ github.event.issue.number }} --add-label "label1,label2"`
|
||||
If you find a duplicate:
|
||||
- Add a comment using EXACTLY this format (required for auto-close to work):
|
||||
"Found a possible duplicate of #<issue_number>: <brief explanation of why it's a duplicate>"
|
||||
- Do NOT apply the "duplicate" label yet (the auto-close script will add it after 12 hours if no objections)
|
||||
- Suggest the user react with a thumbs-down if they disagree
|
||||
|
||||
You may apply multiple labels if appropriate (e.g., "bug,help wanted").
|
||||
### 3. Check for Low-Quality / AI Spam
|
||||
Analyze the issue quality. We are receiving many low-effort, AI-generated spam issues.
|
||||
Flag the issue as INVALID if it matches these criteria:
|
||||
- **Vague/Generic**: Title is "Fix bug" or "Error" without specific context.
|
||||
- **Hallucinated**: Refers to files or features that do not exist in this repo.
|
||||
- **Template Filler**: Body contains "Insert description here" or unrelated gibberish.
|
||||
- **Low Effort**: No reproduction steps, no logs, only 1-2 sentences.
|
||||
|
||||
5. **Add a brief comment** summarizing your triage decision to help maintainers.
|
||||
If identified as spam/low-quality:
|
||||
- Add the "invalid" label.
|
||||
- Add a comment:
|
||||
"This issue has been automatically flagged as low-quality or potentially AI-generated spam. It lacks specific details (logs, reproduction steps, file references) required for us to help. Please open a new issue following the template exactly if this is a legitimate request."
|
||||
- Do NOT proceed to other steps.
|
||||
|
||||
### 4. Check for invalid issues (General)
|
||||
If the issue is not spam but still lacks information:
|
||||
- Add the "invalid" label
|
||||
- Comment asking for clarification
|
||||
|
||||
### 5. Categorize with labels (if NOT a duplicate or spam)
|
||||
Apply appropriate labels based on the issue content. Use ONLY these labels:
|
||||
- bug: Something isn't working
|
||||
- enhancement: New feature or request
|
||||
- question: Further information is requested
|
||||
- documentation: Improvements or additions to documentation
|
||||
- good first issue: Good for newcomers (if issue is well-defined and small scope)
|
||||
- help wanted: Extra attention is needed (if issue needs community input)
|
||||
- backlog: Tracked for the future, but not currently planned or prioritized
|
||||
|
||||
You may apply multiple labels if appropriate (e.g., "bug" and "help wanted").
|
||||
|
||||
## Tools Available:
|
||||
- mcp__github__get_issue: Get issue details
|
||||
- mcp__github__search_issues: Search for similar issues
|
||||
- mcp__github__list_issues: List recent issues if needed
|
||||
- mcp__github__add_issue_comment: Add a comment
|
||||
- mcp__github__update_issue: Add labels
|
||||
- mcp__github__get_issue_comments: Get existing comments
|
||||
|
||||
Be thorough but efficient. Focus on accurate categorization and finding true duplicates.
|
||||
|
||||
claude_args: |
|
||||
--model claude-3-5-haiku-20241022
|
||||
--allowedTools "Bash(gh issue:*),Bash(gh search:*)"
|
||||
--model claude-haiku-4-5-20251001
|
||||
--allowedTools "mcp__github__get_issue,mcp__github__search_issues,mcp__github__list_issues,mcp__github__add_issue_comment,mcp__github__update_issue,mcp__github__get_issue_comments"
|
||||
|
||||
@@ -0,0 +1,204 @@
|
||||
name: PR Check Command
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
jobs:
|
||||
check-pr:
|
||||
# Only run on PR comments that start with /check
|
||||
if: github.event.issue.pull_request && startsWith(github.event.comment.body, '/check')
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pull-requests: write
|
||||
issues: write
|
||||
checks: write
|
||||
statuses: write
|
||||
|
||||
steps:
|
||||
- name: Check PR requirements
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const prNumber = context.payload.issue.number;
|
||||
console.log(`Triggered by /check comment on PR #${prNumber}`);
|
||||
|
||||
// Fetch PR data
|
||||
const { data: pr } = await github.rest.pulls.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber,
|
||||
});
|
||||
|
||||
const prBody = pr.body || '';
|
||||
const prTitle = pr.title || '';
|
||||
const prAuthor = pr.user.login;
|
||||
const headSha = pr.head.sha;
|
||||
|
||||
// Create a check run in progress
|
||||
const { data: checkRun } = await github.rest.checks.create({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
name: 'check-requirements',
|
||||
head_sha: headSha,
|
||||
status: 'in_progress',
|
||||
started_at: new Date().toISOString(),
|
||||
});
|
||||
|
||||
// Extract issue numbers
|
||||
const issuePattern = /(?:close[sd]?|fix(?:e[sd])?|resolve[sd]?)?\s*#(\d+)/gi;
|
||||
const allText = `${prTitle} ${prBody}`;
|
||||
const matches = [...allText.matchAll(issuePattern)];
|
||||
const issueNumbers = [...new Set(matches.map(m => parseInt(m[1], 10)))];
|
||||
|
||||
console.log(`PR #${prNumber}:`);
|
||||
console.log(` Author: ${prAuthor}`);
|
||||
console.log(` Found issue references: ${issueNumbers.length > 0 ? issueNumbers.join(', ') : 'none'}`);
|
||||
|
||||
if (issueNumbers.length === 0) {
|
||||
const message = `## PR Closed - Requirements Not Met
|
||||
|
||||
This PR has been automatically closed because it doesn't meet the requirements.
|
||||
|
||||
**Missing:** No linked issue found.
|
||||
|
||||
**To fix:**
|
||||
1. Create or find an existing issue for this work
|
||||
2. Assign yourself to the issue
|
||||
3. Re-open this PR and add \`Fixes #123\` in the description
|
||||
|
||||
**Why is this required?** See #472 for details.`;
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: message,
|
||||
});
|
||||
|
||||
await github.rest.pulls.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber,
|
||||
state: 'closed',
|
||||
});
|
||||
|
||||
// Update check run to failure
|
||||
await github.rest.checks.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
check_run_id: checkRun.id,
|
||||
status: 'completed',
|
||||
conclusion: 'failure',
|
||||
completed_at: new Date().toISOString(),
|
||||
output: {
|
||||
title: 'Missing linked issue',
|
||||
summary: 'PR must reference an issue (e.g., `Fixes #123`)',
|
||||
},
|
||||
});
|
||||
|
||||
core.setFailed('PR must reference an issue');
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if PR author is assigned to any linked issue
|
||||
let issueWithAuthorAssigned = null;
|
||||
let issuesWithoutAuthor = [];
|
||||
|
||||
for (const issueNum of issueNumbers) {
|
||||
try {
|
||||
const { data: issue } = await github.rest.issues.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: issueNum,
|
||||
});
|
||||
|
||||
const assigneeLogins = (issue.assignees || []).map(a => a.login);
|
||||
if (assigneeLogins.includes(prAuthor)) {
|
||||
issueWithAuthorAssigned = issueNum;
|
||||
console.log(` Issue #${issueNum} has PR author ${prAuthor} as assignee`);
|
||||
break;
|
||||
} else {
|
||||
issuesWithoutAuthor.push({
|
||||
number: issueNum,
|
||||
assignees: assigneeLogins
|
||||
});
|
||||
console.log(` Issue #${issueNum} assignees: ${assigneeLogins.length > 0 ? assigneeLogins.join(', ') : 'none'}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.log(` Issue #${issueNum} not found`);
|
||||
}
|
||||
}
|
||||
|
||||
if (!issueWithAuthorAssigned) {
|
||||
const issueList = issuesWithoutAuthor.map(i =>
|
||||
`#${i.number} (assignees: ${i.assignees.length > 0 ? i.assignees.join(', ') : 'none'})`
|
||||
).join(', ');
|
||||
|
||||
const message = `## PR Closed - Requirements Not Met
|
||||
|
||||
This PR has been automatically closed because it doesn't meet the requirements.
|
||||
|
||||
**PR Author:** @${prAuthor}
|
||||
**Found issues:** ${issueList}
|
||||
**Problem:** The PR author must be assigned to the linked issue.
|
||||
|
||||
**To fix:**
|
||||
1. Assign yourself (@${prAuthor}) to one of the linked issues
|
||||
2. Re-open this PR
|
||||
|
||||
**Why is this required?** See #472 for details.`;
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: message,
|
||||
});
|
||||
|
||||
await github.rest.pulls.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber,
|
||||
state: 'closed',
|
||||
});
|
||||
|
||||
// Update check run to failure
|
||||
await github.rest.checks.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
check_run_id: checkRun.id,
|
||||
status: 'completed',
|
||||
conclusion: 'failure',
|
||||
completed_at: new Date().toISOString(),
|
||||
output: {
|
||||
title: 'PR author not assigned to issue',
|
||||
summary: `PR author @${prAuthor} must be assigned to one of the linked issues: ${issueList}`,
|
||||
},
|
||||
});
|
||||
|
||||
core.setFailed('PR author must be assigned to the linked issue');
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: `✅ PR requirements met! Issue #${issueWithAuthorAssigned} has @${prAuthor} as assignee.`,
|
||||
});
|
||||
|
||||
// Update check run to success
|
||||
await github.rest.checks.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
check_run_id: checkRun.id,
|
||||
status: 'completed',
|
||||
conclusion: 'success',
|
||||
completed_at: new Date().toISOString(),
|
||||
output: {
|
||||
title: 'Requirements met',
|
||||
summary: `Issue #${issueWithAuthorAssigned} has @${prAuthor} as assignee.`,
|
||||
},
|
||||
});
|
||||
|
||||
console.log(`PR requirements met!`);
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
name: PR Requirements Backfill
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
check-all-open-prs:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pull-requests: write
|
||||
issues: write
|
||||
|
||||
steps:
|
||||
- name: Check all open PRs
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const { data: pullRequests } = await github.rest.pulls.list({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
state: 'open',
|
||||
per_page: 100,
|
||||
});
|
||||
|
||||
console.log(`Found ${pullRequests.length} open PRs`);
|
||||
|
||||
for (const pr of pullRequests) {
|
||||
const prNumber = pr.number;
|
||||
const prBody = pr.body || '';
|
||||
const prTitle = pr.title || '';
|
||||
const prAuthor = pr.user.login;
|
||||
|
||||
console.log(`\nChecking PR #${prNumber}: ${prTitle}`);
|
||||
|
||||
// Extract issue numbers from body and title
|
||||
const issuePattern = /(?:close[sd]?|fix(?:e[sd])?|resolve[sd]?)?\s*#(\d+)/gi;
|
||||
const allText = `${prTitle} ${prBody}`;
|
||||
const matches = [...allText.matchAll(issuePattern)];
|
||||
const issueNumbers = [...new Set(matches.map(m => parseInt(m[1], 10)))];
|
||||
|
||||
console.log(` Found issue references: ${issueNumbers.length > 0 ? issueNumbers.join(', ') : 'none'}`);
|
||||
|
||||
if (issueNumbers.length === 0) {
|
||||
console.log(` ❌ No linked issue - closing PR`);
|
||||
|
||||
const message = `## PR Closed - Requirements Not Met
|
||||
|
||||
This PR has been automatically closed because it doesn't meet the requirements.
|
||||
|
||||
**Missing:** No linked issue found.
|
||||
|
||||
**To fix:**
|
||||
1. Create or find an existing issue for this work
|
||||
2. Assign yourself to the issue
|
||||
3. Re-open this PR and add \`Fixes #123\` in the description`;
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: message,
|
||||
});
|
||||
|
||||
await github.rest.pulls.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber,
|
||||
state: 'closed',
|
||||
});
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if any linked issue has the PR author as assignee
|
||||
let issueWithAuthorAssigned = null;
|
||||
let issuesWithoutAuthor = [];
|
||||
|
||||
for (const issueNum of issueNumbers) {
|
||||
try {
|
||||
const { data: issue } = await github.rest.issues.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: issueNum,
|
||||
});
|
||||
|
||||
const assigneeLogins = (issue.assignees || []).map(a => a.login);
|
||||
if (assigneeLogins.includes(prAuthor)) {
|
||||
issueWithAuthorAssigned = issueNum;
|
||||
break;
|
||||
} else {
|
||||
issuesWithoutAuthor.push({
|
||||
number: issueNum,
|
||||
assignees: assigneeLogins
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
console.log(` Issue #${issueNum} not found or inaccessible`);
|
||||
}
|
||||
}
|
||||
|
||||
if (!issueWithAuthorAssigned) {
|
||||
const issueList = issuesWithoutAuthor.map(i =>
|
||||
`#${i.number} (assignees: ${i.assignees.length > 0 ? i.assignees.join(', ') : 'none'})`
|
||||
).join(', ');
|
||||
|
||||
console.log(` ❌ PR author not assigned to any linked issue - closing PR`);
|
||||
|
||||
const message = `## PR Closed - Requirements Not Met
|
||||
|
||||
This PR has been automatically closed because it doesn't meet the requirements.
|
||||
|
||||
**PR Author:** @${prAuthor}
|
||||
**Found issues:** ${issueList}
|
||||
**Problem:** The PR author must be assigned to the linked issue.
|
||||
|
||||
**To fix:**
|
||||
1. Assign yourself (@${prAuthor}) to one of the linked issues
|
||||
2. Re-open this PR`;
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: message,
|
||||
});
|
||||
|
||||
await github.rest.pulls.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber,
|
||||
state: 'closed',
|
||||
});
|
||||
} else {
|
||||
console.log(` ✅ PR requirements met! Issue #${issueWithAuthorAssigned} has ${prAuthor} as assignee.`);
|
||||
}
|
||||
}
|
||||
|
||||
console.log('\nBackfill complete!');
|
||||
@@ -0,0 +1,189 @@
|
||||
name: PR Requirements Check
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [opened, reopened, edited, synchronize]
|
||||
|
||||
jobs:
|
||||
check-requirements:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pull-requests: write
|
||||
issues: write
|
||||
|
||||
steps:
|
||||
- name: Check PR has linked issue with assignee
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const pr = context.payload.pull_request;
|
||||
const prNumber = pr.number;
|
||||
const prBody = pr.body || '';
|
||||
const prTitle = pr.title || '';
|
||||
const prLabels = (pr.labels || []).map(l => l.name);
|
||||
|
||||
// Allow micro-fix and documentation PRs without a linked issue
|
||||
const isMicroFix = prLabels.includes('micro-fix') || /micro-fix/i.test(prTitle);
|
||||
const isDocumentation = prLabels.includes('documentation') || /\bdocs?\b/i.test(prTitle);
|
||||
if (isMicroFix || isDocumentation) {
|
||||
const reason = isMicroFix ? 'micro-fix' : 'documentation';
|
||||
console.log(`PR #${prNumber} is a ${reason}, skipping issue requirement.`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Extract issue numbers from body and title
|
||||
// Matches: fixes #123, closes #123, resolves #123, or plain #123
|
||||
const issuePattern = /(?:close[sd]?|fix(?:e[sd])?|resolve[sd]?)?\s*#(\d+)/gi;
|
||||
|
||||
const allText = `${prTitle} ${prBody}`;
|
||||
const matches = [...allText.matchAll(issuePattern)];
|
||||
const issueNumbers = [...new Set(matches.map(m => parseInt(m[1], 10)))];
|
||||
|
||||
console.log(`PR #${prNumber}:`);
|
||||
console.log(` Found issue references: ${issueNumbers.length > 0 ? issueNumbers.join(', ') : 'none'}`);
|
||||
|
||||
if (issueNumbers.length === 0) {
|
||||
const message = `## PR Closed - Requirements Not Met
|
||||
|
||||
This PR has been automatically closed because it doesn't meet the requirements.
|
||||
|
||||
**Missing:** No linked issue found.
|
||||
|
||||
**To fix:**
|
||||
1. Create or find an existing issue for this work
|
||||
2. Assign yourself to the issue
|
||||
3. Re-open this PR and add \`Fixes #123\` in the description
|
||||
|
||||
**Exception:** To bypass this requirement, you can:
|
||||
- Add the \`micro-fix\` label or include \`micro-fix\` in your PR title for trivial fixes
|
||||
- Add the \`documentation\` label or include \`doc\`/\`docs\` in your PR title for documentation changes
|
||||
|
||||
**Micro-fix requirements** (must meet ALL):
|
||||
| Qualifies | Disqualifies |
|
||||
|-----------|--------------|
|
||||
| < 20 lines changed | Any functional bug fix |
|
||||
| Typos & Documentation & Linting | Refactoring for "clean code" |
|
||||
| No logic/API/DB changes | New features (even tiny ones) |
|
||||
|
||||
**Why is this required?** See #472 for details.`;
|
||||
|
||||
const comments = await github.rest.issues.listComments({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
});
|
||||
|
||||
const botComment = comments.data.find(
|
||||
(c) => c.user.type === 'Bot' && c.body.includes('PR Closed - Requirements Not Met')
|
||||
);
|
||||
|
||||
if (!botComment) {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: message,
|
||||
});
|
||||
}
|
||||
|
||||
await github.rest.pulls.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber,
|
||||
state: 'closed',
|
||||
});
|
||||
|
||||
core.setFailed('PR must reference an issue');
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if any linked issue has the PR author as assignee
|
||||
const prAuthor = pr.user.login;
|
||||
let issueWithAuthorAssigned = null;
|
||||
let issuesWithoutAuthor = [];
|
||||
|
||||
for (const issueNum of issueNumbers) {
|
||||
try {
|
||||
const { data: issue } = await github.rest.issues.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: issueNum,
|
||||
});
|
||||
|
||||
const assigneeLogins = (issue.assignees || []).map(a => a.login);
|
||||
if (assigneeLogins.includes(prAuthor)) {
|
||||
issueWithAuthorAssigned = issueNum;
|
||||
console.log(` Issue #${issueNum} has PR author ${prAuthor} as assignee`);
|
||||
break;
|
||||
} else {
|
||||
issuesWithoutAuthor.push({
|
||||
number: issueNum,
|
||||
assignees: assigneeLogins
|
||||
});
|
||||
console.log(` Issue #${issueNum} assignees: ${assigneeLogins.length > 0 ? assigneeLogins.join(', ') : 'none'} (PR author: ${prAuthor})`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.log(` Issue #${issueNum} not found or inaccessible`);
|
||||
}
|
||||
}
|
||||
|
||||
if (!issueWithAuthorAssigned) {
|
||||
const issueList = issuesWithoutAuthor.map(i =>
|
||||
`#${i.number} (assignees: ${i.assignees.length > 0 ? i.assignees.join(', ') : 'none'})`
|
||||
).join(', ');
|
||||
|
||||
const message = `## PR Closed - Requirements Not Met
|
||||
|
||||
This PR has been automatically closed because it doesn't meet the requirements.
|
||||
|
||||
**PR Author:** @${prAuthor}
|
||||
**Found issues:** ${issueList}
|
||||
**Problem:** The PR author must be assigned to the linked issue.
|
||||
|
||||
**To fix:**
|
||||
1. Assign yourself (@${prAuthor}) to one of the linked issues
|
||||
2. Re-open this PR
|
||||
|
||||
**Exception:** To bypass this requirement, you can:
|
||||
- Add the \`micro-fix\` label or include \`micro-fix\` in your PR title for trivial fixes
|
||||
- Add the \`documentation\` label or include \`doc\`/\`docs\` in your PR title for documentation changes
|
||||
|
||||
**Micro-fix requirements** (must meet ALL):
|
||||
| Qualifies | Disqualifies |
|
||||
|-----------|--------------|
|
||||
| < 20 lines changed | Any functional bug fix |
|
||||
| Typos & Documentation & Linting | Refactoring for "clean code" |
|
||||
| No logic/API/DB changes | New features (even tiny ones) |
|
||||
|
||||
**Why is this required?** See #472 for details.`;
|
||||
|
||||
const comments = await github.rest.issues.listComments({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
});
|
||||
|
||||
const botComment = comments.data.find(
|
||||
(c) => c.user.type === 'Bot' && c.body.includes('PR Closed - Requirements Not Met')
|
||||
);
|
||||
|
||||
if (!botComment) {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: message,
|
||||
});
|
||||
}
|
||||
|
||||
await github.rest.pulls.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber,
|
||||
state: 'closed',
|
||||
});
|
||||
|
||||
core.setFailed('PR author must be assigned to the linked issue');
|
||||
} else {
|
||||
console.log(`PR requirements met! Issue #${issueWithAuthorAssigned} has ${prAuthor} as assignee.`);
|
||||
}
|
||||
+1
-1
@@ -69,4 +69,4 @@ exports/*
|
||||
|
||||
.agent-builder-sessions/*
|
||||
|
||||
.venv
|
||||
.venv
|
||||
@@ -0,0 +1,18 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.8.6
|
||||
hooks:
|
||||
- id: ruff
|
||||
name: ruff lint (core)
|
||||
args: [--fix]
|
||||
files: ^core/
|
||||
- id: ruff
|
||||
name: ruff lint (tools)
|
||||
args: [--fix]
|
||||
files: ^tools/
|
||||
- id: ruff-format
|
||||
name: ruff format (core)
|
||||
files: ^core/
|
||||
- id: ruff-format
|
||||
name: ruff format (tools)
|
||||
files: ^tools/
|
||||
Vendored
+7
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"recommendations": [
|
||||
"charliermarsh.ruff",
|
||||
"editorconfig.editorconfig",
|
||||
"ms-python.python"
|
||||
]
|
||||
}
|
||||
+2
-1
@@ -25,8 +25,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
### Removed
|
||||
- N/A
|
||||
|
||||
|
||||
### Fixed
|
||||
- N/A
|
||||
- tools: Fixed web_scrape tool attempting to parse non-HTML content (PDF, JSON) as HTML (#487)
|
||||
|
||||
### Security
|
||||
- N/A
|
||||
|
||||
+54
-6
@@ -6,6 +6,41 @@ Thank you for your interest in contributing to the Aden Agent Framework! This do
|
||||
|
||||
By participating in this project, you agree to abide by our [Code of Conduct](CODE_OF_CONDUCT.md).
|
||||
|
||||
## Contributor License Agreement
|
||||
|
||||
By submitting a Pull Request, you agree that your contributions will be licensed under the Aden Agent Framework license.
|
||||
|
||||
## Issue Assignment Policy
|
||||
|
||||
To prevent duplicate work and respect contributors' time, we require issue assignment before submitting PRs.
|
||||
|
||||
### How to Claim an Issue
|
||||
|
||||
1. **Find an Issue:** Browse existing issues or create a new one
|
||||
2. **Claim It:** Leave a comment (e.g., *"I'd like to work on this!"*)
|
||||
3. **Wait for Assignment:** A maintainer will assign you within 24 hours
|
||||
4. **Submit Your PR:** Once assigned, you're ready to contribute
|
||||
|
||||
> **Note:** PRs for unassigned issues may be delayed or closed if someone else was already assigned.
|
||||
|
||||
### The 5-Day Momentum Rule
|
||||
|
||||
To keep the project moving, issues with **no activity for 5 days** (no PR or status update) will be unassigned. If you need more time, just drop a quick comment!
|
||||
|
||||
### Exceptions (No Assignment Needed)
|
||||
|
||||
You may submit PRs without prior assignment for:
|
||||
- **Documentation:** Fixing typos or clarifying instructions — add the `documentation` label or include `doc`/`docs` in your PR title to bypass the linked issue requirement
|
||||
- **Micro-fixes:** Add the `micro-fix` label or include `micro-fix` in your PR title to bypass the linked issue requirement. Micro-fixes must meet **all** qualification criteria:
|
||||
|
||||
| Qualifies | Disqualifies |
|
||||
|-----------|--------------|
|
||||
| < 20 lines changed | Any functional bug fix |
|
||||
| Typos & Documentation & Linting | Refactoring for "clean code" |
|
||||
| No logic/API/DB changes | New features (even tiny ones) |
|
||||
|
||||
If a high-quality PR is submitted for a "stale" assigned issue (no activity for 7+ days), we may proceed with the submitted code.
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. Fork the repository
|
||||
@@ -29,6 +64,12 @@ python -c "import framework; import aden_tools; print('✓ Setup complete')"
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
> **Windows Users:**
|
||||
> If you are on native Windows, it is recommended to use **WSL (Windows Subsystem for Linux)**.
|
||||
> Alternatively, make sure to run PowerShell or Git Bash with Python 3.11+ installed, and disable "App Execution Aliases" in Windows settings.
|
||||
|
||||
> **Tip:** Installing Claude Code skills is optional for running existing agents, but required if you plan to **build new agents**.
|
||||
|
||||
## Commit Convention
|
||||
|
||||
We follow [Conventional Commits](https://www.conventionalcommits.org/):
|
||||
@@ -59,11 +100,12 @@ docs(readme): update installation instructions
|
||||
|
||||
## Pull Request Process
|
||||
|
||||
1. Update documentation if needed
|
||||
2. Add tests for new functionality
|
||||
3. Ensure all tests pass
|
||||
4. Update the CHANGELOG.md if applicable
|
||||
5. Request review from maintainers
|
||||
1. **Get assigned to the issue first** (see [Issue Assignment Policy](#issue-assignment-policy))
|
||||
2. Update documentation if needed
|
||||
3. Add tests for new functionality
|
||||
4. Ensure all tests pass
|
||||
5. Update the CHANGELOG.md if applicable
|
||||
6. Request review from maintainers
|
||||
|
||||
### PR Title Format
|
||||
|
||||
@@ -92,6 +134,12 @@ feat(component): add new feature description
|
||||
|
||||
## Testing
|
||||
|
||||
> **Note:** When testing agents in `exports/`, always set PYTHONPATH:
|
||||
>
|
||||
> ```bash
|
||||
> PYTHONPATH=core:exports python -m agent_name test
|
||||
> ```
|
||||
|
||||
```bash
|
||||
# Run all tests for the framework
|
||||
cd core && python -m pytest
|
||||
@@ -107,4 +155,4 @@ PYTHONPATH=core:exports python -m agent_name test
|
||||
|
||||
Feel free to open an issue for questions or join our [Discord community](https://discord.com/invite/MXE49hrKDk).
|
||||
|
||||
Thank you for contributing!
|
||||
Thank you for contributing!
|
||||
+24
-65
@@ -99,10 +99,13 @@ Get API keys:
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
This installs:
|
||||
This installs agent-related Claude Code skills:
|
||||
|
||||
- `/building-agents` - Build new goal-driven agents
|
||||
- `/testing-agent` - Test agents with evaluation framework
|
||||
- `/building-agents-core` - Fundamental agent concepts
|
||||
- `/building-agents-construction` - Step-by-step agent building
|
||||
- `/building-agents-patterns` - Best practices and design patterns
|
||||
- `/testing-agent` - Test and validate agents
|
||||
- `/agent-workflow` - End-to-end guided workflow
|
||||
|
||||
### Verify Setup
|
||||
|
||||
@@ -132,15 +135,22 @@ hive/ # Repository root
|
||||
│ └── CODEOWNERS # Auto-assign reviewers
|
||||
│
|
||||
├── .claude/ # Claude Code Skills
|
||||
│ └── skills/
|
||||
│ ├── building-agents/ # Skills for building agents
|
||||
│ │ ├── SKILL.md # Main skill definition
|
||||
│ │ ├── building-agents-core/
|
||||
│ │ ├── building-agents-patterns/
|
||||
│ │ └── building-agents-construction/
|
||||
│ └── skills/ # Skills for building
|
||||
│ ├── building-agents-core/
|
||||
| | ├── SKILL.md # Main skill definition
|
||||
│ | └── examples
|
||||
│ ├── building-agents-patterns/
|
||||
| | ├── SKILL.md
|
||||
│ | └── examples
|
||||
│ ├── building-agents-construction/
|
||||
| | ├── SKILL.md
|
||||
│ | └── examples
|
||||
│ ├── testing-agent/ # Skills for testing agents
|
||||
│ │ └── SKILL.md
|
||||
│ └── agent-workflow/ # Complete workflow orchestration
|
||||
│ │ ├── SKILL.md
|
||||
│ | └── examples
|
||||
│ └── agent-workflow/ # Complete workflow
|
||||
| ├── SKILL.md
|
||||
│ └── examples
|
||||
│
|
||||
├── core/ # CORE FRAMEWORK PACKAGE
|
||||
│ ├── framework/ # Main package code
|
||||
@@ -213,7 +223,7 @@ The fastest way to build agents is using the Claude Code skills:
|
||||
./quickstart.sh
|
||||
|
||||
# Build a new agent
|
||||
claude> /building-agents
|
||||
claude> /building-agents-construction
|
||||
|
||||
# Test the agent
|
||||
claude> /testing-agent
|
||||
@@ -224,7 +234,7 @@ claude> /testing-agent
|
||||
1. **Define Your Goal**
|
||||
|
||||
```
|
||||
claude> /building-agents
|
||||
claude> /building-agents-construction
|
||||
Enter goal: "Build an agent that processes customer support tickets"
|
||||
```
|
||||
|
||||
@@ -511,30 +521,7 @@ chore(deps): update React to 18.2.0
|
||||
|
||||
## Debugging
|
||||
|
||||
### Frontend Debugging
|
||||
|
||||
**React Developer Tools:**
|
||||
|
||||
1. Install the [React DevTools browser extension](https://react.dev/learn/react-developer-tools)
|
||||
2. Open browser DevTools → React tab
|
||||
3. Inspect component tree, props, state, and hooks
|
||||
|
||||
**VS Code Debugging:**
|
||||
|
||||
1. Add Chrome debug configuration to `.vscode/launch.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "chrome",
|
||||
"request": "launch",
|
||||
"name": "Debug Frontend",
|
||||
"url": "http://localhost:3000",
|
||||
"webRoot": "${workspaceFolder}/honeycomb/src"
|
||||
}
|
||||
```
|
||||
|
||||
2. Start the dev server: `npm run dev -w honeycomb`
|
||||
3. Press F5 in VS Code
|
||||
|
||||
### Backend Debugging
|
||||
|
||||
@@ -594,7 +581,7 @@ pip install -e .
|
||||
|
||||
```bash
|
||||
# Option 1: Use Claude Code skill (recommended)
|
||||
claude> /building-agents
|
||||
claude> /building-agents-construction
|
||||
|
||||
# Option 2: Create manually
|
||||
# Note: exports/ is initially empty (gitignored). Create your agent directory:
|
||||
@@ -709,14 +696,6 @@ kill -9 <PID>
|
||||
# Or change ports in config.yaml and regenerate
|
||||
```
|
||||
|
||||
### Node Modules Issues
|
||||
|
||||
```bash
|
||||
# Clean everything and reinstall
|
||||
npm run clean
|
||||
rm -rf node_modules package-lock.json
|
||||
npm install
|
||||
```
|
||||
|
||||
### Docker Issues
|
||||
|
||||
@@ -728,15 +707,7 @@ docker compose build --no-cache
|
||||
docker compose up
|
||||
```
|
||||
|
||||
### TypeScript Errors After Pull
|
||||
|
||||
```bash
|
||||
# Rebuild TypeScript
|
||||
npm run build
|
||||
|
||||
# Or restart TS server in VS Code
|
||||
# Cmd/Ctrl + Shift + P → "TypeScript: Restart TS Server"
|
||||
```
|
||||
|
||||
### Environment Variables Not Loading
|
||||
|
||||
@@ -746,24 +717,12 @@ npm run generate:env
|
||||
|
||||
# Verify files exist
|
||||
cat .env
|
||||
cat honeycomb/.env
|
||||
cat hive/.env
|
||||
|
||||
# Restart dev servers after changing env
|
||||
```
|
||||
|
||||
### Tests Failing
|
||||
|
||||
```bash
|
||||
# Run with verbose output
|
||||
npm run test -w honeycomb -- --reporter=verbose
|
||||
|
||||
# Run single test file
|
||||
npm run test -w honeycomb -- src/components/Button.test.tsx
|
||||
|
||||
# Clear test cache
|
||||
npm run test -w honeycomb -- --clearCache
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
|
||||
+122
-13
@@ -9,6 +9,10 @@ Complete setup guide for building and running goal-driven agents with the Aden A
|
||||
./scripts/setup-python.sh
|
||||
```
|
||||
|
||||
> **Note for Windows Users:**
|
||||
> Running the setup script on native Windows shells (PowerShell / Git Bash) may sometimes fail due to Python App Execution Aliases.
|
||||
> It is **strongly recommended to use WSL (Windows Subsystem for Linux)** for a smoother setup experience.
|
||||
|
||||
This will:
|
||||
|
||||
- Check Python version (requires 3.11+)
|
||||
@@ -17,6 +21,26 @@ This will:
|
||||
- Fix package compatibility issues (openai + litellm)
|
||||
- Verify all installations
|
||||
|
||||
## Alpine Linux Setup
|
||||
|
||||
If you are using Alpine Linux (e.g., inside a Docker container), you must install system dependencies and use a virtual environment before running the setup script:
|
||||
|
||||
1. Install System Dependencies:
|
||||
```bash
|
||||
apk update
|
||||
apk add bash git python3 py3-pip nodejs npm curl build-base python3-dev linux-headers libffi-dev
|
||||
```
|
||||
2. Set up Virtual Environment (Required for Python 3.12+):
|
||||
```
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install --upgrade pip setuptools wheel
|
||||
```
|
||||
3. Run the Quickstart Script:
|
||||
```
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
## Manual Setup (Alternative)
|
||||
|
||||
If you prefer to set up manually or the script fails:
|
||||
@@ -50,6 +74,9 @@ python -c "import aden_tools; print('✓ aden_tools OK')"
|
||||
python -c "import litellm; print('✓ litellm OK')"
|
||||
```
|
||||
|
||||
> **Windows Tip:**
|
||||
> On Windows, if the verification commands fail, ensure you are running them in **WSL** or after **disabling Python App Execution Aliases** in Windows Settings → Apps → App Execution Aliases.
|
||||
|
||||
## Requirements
|
||||
|
||||
### Python Version
|
||||
@@ -63,6 +90,7 @@ python -c "import litellm; print('✓ litellm OK')"
|
||||
- pip (latest version)
|
||||
- 2GB+ RAM
|
||||
- Internet connection (for LLM API calls)
|
||||
- For Windows users: WSL 2 is recommended for full compatibility.
|
||||
|
||||
### API Keys (Optional)
|
||||
|
||||
@@ -114,44 +142,125 @@ PYTHONPATH=core:exports python -m outbound_sales_agent validate
|
||||
PYTHONPATH=core:exports python -m personal_assistant_agent run --input '{...}'
|
||||
```
|
||||
|
||||
## Building New Agents
|
||||
## Building New Agents and Run Flow
|
||||
|
||||
Use Claude Code CLI with the agent building skills:
|
||||
Build and run an agent using Claude Code CLI with the agent building skills:
|
||||
|
||||
### 1. Install Skills (One-time)
|
||||
### 1. Install Claude Skills (One-time)
|
||||
|
||||
```bash
|
||||
./quickstart.sh
|
||||
```
|
||||
|
||||
This installs:
|
||||
This installs agent-related Claude Code skills:
|
||||
|
||||
- `/building-agents` - Build new agents
|
||||
- `/testing-agent` - Test agents
|
||||
- `/building-agents-construction` - Step-by-step build guide
|
||||
- `/building-agents-core` - Fundamental concepts
|
||||
- `/building-agents-patterns` - Best practices
|
||||
- `/testing-agent` - Test and validate agents
|
||||
- `/agent-workflow` - Complete workflow
|
||||
|
||||
### 2. Build an Agent
|
||||
|
||||
```
|
||||
claude> /building-agents
|
||||
claude> /building-agents-construction
|
||||
```
|
||||
|
||||
Follow the prompts to:
|
||||
|
||||
1. Define your agent's goal
|
||||
2. Design the workflow nodes
|
||||
3. Connect edges
|
||||
4. Generate the agent package
|
||||
3. Connect nodes with edges
|
||||
4. Generate the agent package under `exports/`
|
||||
|
||||
### 3. Test Your Agent
|
||||
This step creates the initial agent structure required for further development.
|
||||
|
||||
### 3. Define Agent Logic
|
||||
|
||||
```
|
||||
claude> /building-agents-core
|
||||
```
|
||||
|
||||
Follow the prompts to:
|
||||
|
||||
1. Understand the agent architecture and file structure
|
||||
2. Define the agent's goal, success criteria, and constraints
|
||||
3. Learn node types (LLM, tool-use, router, function)
|
||||
4. Discover and validate available tools before use
|
||||
|
||||
This step establishes the core concepts and rules needed before building an agent.
|
||||
|
||||
### 4. Apply Agent Patterns
|
||||
|
||||
```
|
||||
claude> /building-agents-patterns
|
||||
```
|
||||
|
||||
Follow the prompts to:
|
||||
|
||||
1. Apply best-practice agent design patterns
|
||||
2. Add pause/resume flows for multi-turn interactions
|
||||
3. Improve robustness with routing, fallbacks, and retries
|
||||
4. Avoid common anti-patterns during agent construction
|
||||
|
||||
This step helps optimize agent design before final testing.
|
||||
|
||||
### 5. Test Your Agent
|
||||
|
||||
```
|
||||
claude> /testing-agent
|
||||
```
|
||||
Follow the prompts to:
|
||||
|
||||
Creates comprehensive test suites for your agent.
|
||||
1. Generate test guidelines for constraints and success criteria
|
||||
2. Write agent tests directly under `exports/{agent}/tests/`
|
||||
3. Run goal-based evaluation tests
|
||||
4. Debug failing tests and iterate on agent improvements
|
||||
|
||||
This step verifies that the agent meets its goals before production use.
|
||||
|
||||
### 6. Agent Development Workflow (End-to-End)
|
||||
|
||||
```
|
||||
claude> /agent-workflow
|
||||
```
|
||||
|
||||
Follow the guided flow to:
|
||||
|
||||
1. Understand core agent concepts (optional)
|
||||
2. Build the agent structure step by step
|
||||
3. Apply best-practice design patterns (optional)
|
||||
4. Test and validate the agent against its goals
|
||||
|
||||
This workflow orchestrates all agent-building skills to take you from idea → production-ready agent.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "externally-managed-environment" error (PEP 668)
|
||||
|
||||
**Cause:** Python 3.12+ on macOS/Homebrew, WSL, or some Linux distros prevents system-wide pip installs.
|
||||
|
||||
**Solution:** Create and use a virtual environment:
|
||||
|
||||
```bash
|
||||
# Create virtual environment
|
||||
python3 -m venv .venv
|
||||
|
||||
# Activate it
|
||||
source .venv/bin/activate # macOS/Linux
|
||||
# .venv\Scripts\activate # Windows
|
||||
|
||||
# Then run setup
|
||||
./scripts/setup-python.sh
|
||||
```
|
||||
|
||||
Always activate the venv before running agents:
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
PYTHONPATH=core:exports python -m your_agent_name demo
|
||||
```
|
||||
|
||||
### "ModuleNotFoundError: No module named 'framework'"
|
||||
|
||||
**Solution:** Install the core package:
|
||||
@@ -188,7 +297,7 @@ pip install --upgrade "openai>=1.0.0"
|
||||
|
||||
**Cause:** Not running from project root or missing PYTHONPATH
|
||||
|
||||
**Solution:** Ensure you're in `/home/timothy/oss/hive/` and use:
|
||||
**Solution:** Ensure you're in the project root directory and use:
|
||||
|
||||
```bash
|
||||
PYTHONPATH=core:exports python -m support_ticket_agent validate
|
||||
@@ -256,7 +365,7 @@ This design allows agents in `exports/` to be:
|
||||
### 2. Build Agent (Claude Code)
|
||||
|
||||
```
|
||||
claude> /building-agents
|
||||
claude> /building-agents-construction
|
||||
Enter goal: "Build an agent that processes customer support tickets"
|
||||
```
|
||||
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
.PHONY: lint format check test install-hooks help
|
||||
|
||||
help: ## Show this help
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | \
|
||||
awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-15s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
lint: ## Run ruff linter (with auto-fix)
|
||||
cd core && ruff check --fix .
|
||||
cd tools && ruff check --fix .
|
||||
|
||||
format: ## Run ruff formatter
|
||||
cd core && ruff format .
|
||||
cd tools && ruff format .
|
||||
|
||||
check: ## Run all checks without modifying files (CI-safe)
|
||||
cd core && ruff check .
|
||||
cd tools && ruff check .
|
||||
cd core && ruff format --check .
|
||||
cd tools && ruff format --check .
|
||||
|
||||
test: ## Run all tests
|
||||
cd core && python -m pytest tests/ -v
|
||||
|
||||
install-hooks: ## Install pre-commit hooks
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
@@ -4,12 +4,12 @@
|
||||
|
||||
<p align="center">
|
||||
<a href="README.md">English</a> |
|
||||
<a href="README.zh-CN.md">简体中文</a> |
|
||||
<a href="README.es.md">Español</a> |
|
||||
<a href="README.pt.md">Português</a> |
|
||||
<a href="README.ja.md">日本語</a> |
|
||||
<a href="README.ru.md">Русский</a> |
|
||||
<a href="README.ko.md">한국어</a>
|
||||
<a href="docs/i18n/zh-CN.md">简体中文</a> |
|
||||
<a href="docs/i18n/es.md">Español</a> |
|
||||
<a href="docs/i18n/pt.md">Português</a> |
|
||||
<a href="docs/i18n/ja.md">日本語</a> |
|
||||
<a href="docs/i18n/ru.md">Русский</a> |
|
||||
<a href="docs/i18n/ko.md">한국어</a>
|
||||
</p>
|
||||
|
||||
[](https://github.com/adenhq/hive/blob/main/LICENSE)
|
||||
@@ -91,7 +91,7 @@ This installs:
|
||||
./quickstart.sh
|
||||
|
||||
# Build an agent using Claude Code
|
||||
claude> /building-agents
|
||||
claude> /building-agents-construction
|
||||
|
||||
# Test your agent
|
||||
claude> /testing-agent
|
||||
@@ -102,6 +102,15 @@ PYTHONPATH=core:exports python -m your_agent_name run --input '{...}'
|
||||
|
||||
**[📖 Complete Setup Guide](ENVIRONMENT_SETUP.md)** - Detailed instructions for agent development
|
||||
|
||||
### Cursor IDE Support
|
||||
|
||||
Skills are also available in Cursor. To enable:
|
||||
|
||||
1. Open Command Palette (`Cmd+Shift+P` / `Ctrl+Shift+P`)
|
||||
2. Run `MCP: Enable` to enable MCP servers
|
||||
3. Restart Cursor to load the MCP servers from `.cursor/mcp.json`
|
||||
4. Type `/` in Agent chat and search for skills (e.g., `/building-agents-construction`)
|
||||
|
||||
## Features
|
||||
|
||||
- **Goal-Driven Development** - Define objectives in natural language; the coding agent generates the agent graph and connection code to achieve them
|
||||
@@ -226,6 +235,7 @@ hive/
|
||||
├── docs/ # Documentation and guides
|
||||
├── scripts/ # Build and utility scripts
|
||||
├── .claude/ # Claude Code skills for building agents
|
||||
├── .cursor/ # Cursor IDE skills (symlinks to .claude/skills)
|
||||
├── ENVIRONMENT_SETUP.md # Python setup guide for agent development
|
||||
├── DEVELOPER.md # Developer guide
|
||||
├── CONTRIBUTING.md # Contribution guidelines
|
||||
@@ -248,7 +258,7 @@ For building and running goal-driven agents with the framework:
|
||||
# - All dependencies
|
||||
|
||||
# Build new agents using Claude Code skills
|
||||
claude> /building-agents
|
||||
claude> /building-agents-construction
|
||||
|
||||
# Test agents
|
||||
claude> /testing-agent
|
||||
@@ -264,7 +274,7 @@ See [ENVIRONMENT_SETUP.md](ENVIRONMENT_SETUP.md) for complete setup instructions
|
||||
- **[Developer Guide](DEVELOPER.md)** - Comprehensive guide for developers
|
||||
- [Getting Started](docs/getting-started.md) - Quick setup instructions
|
||||
- [Configuration Guide](docs/configuration.md) - All configuration options
|
||||
- [Architecture Overview](docs/architecture.md) - System design and structure
|
||||
- [Architecture Overview](docs/architecture/README.md) - System design and structure
|
||||
|
||||
## Roadmap
|
||||
|
||||
@@ -300,11 +310,14 @@ We use [Discord](https://discord.com/invite/MXE49hrKDk) for support, feature req
|
||||
|
||||
We welcome contributions! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
|
||||
|
||||
1. Fork the repository
|
||||
2. Create your feature branch (`git checkout -b feature/amazing-feature`)
|
||||
3. Commit your changes (`git commit -m 'Add amazing feature'`)
|
||||
4. Push to the branch (`git push origin feature/amazing-feature`)
|
||||
5. Open a Pull Request
|
||||
**Important:** Please get assigned to an issue before submitting a PR. Comment on an issue to claim it, and a maintainer will assign you within 24 hours. This helps prevent duplicate work.
|
||||
|
||||
1. Find or create an issue and get assigned
|
||||
2. Fork the repository
|
||||
3. Create your feature branch (`git checkout -b feature/amazing-feature`)
|
||||
4. Commit your changes (`git commit -m 'Add amazing feature'`)
|
||||
5. Push to the branch (`git push origin feature/amazing-feature`)
|
||||
6. Open a Pull Request
|
||||
|
||||
## Join Our Team
|
||||
|
||||
@@ -328,7 +341,7 @@ No. Aden is built from the ground up with no dependencies on LangChain, CrewAI,
|
||||
|
||||
**Q: What LLM providers does Aden support?**
|
||||
|
||||
Aden supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, Mistral, Groq, and many more. Simply set the appropriate API key environment variable and specify the model name.
|
||||
Aden supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, DeepSeek, Mistral, Groq, and many more. Simply set the appropriate API key environment variable and specify the model name.
|
||||
|
||||
**Q: Can I use Aden with local AI models like Ollama?**
|
||||
|
||||
@@ -348,7 +361,7 @@ Aden collects telemetry data for monitoring and observability purposes, includin
|
||||
|
||||
**Q: What deployment options does Aden support?**
|
||||
|
||||
Aden supports Docker Compose deployment out of the box, with both production and development configurations. Self-hosted deployments work on any infrastructure supporting Docker. Cloud deployment options and Kubernetes-ready configurations are on the roadmap.
|
||||
Aden supports self-hosted deployments via Python packages. See the [Environment Setup Guide](ENVIRONMENT_SETUP.md) for installation instructions. Cloud deployment options and Kubernetes-ready configurations are on the roadmap.
|
||||
|
||||
**Q: Can Aden handle complex, production-scale use cases?**
|
||||
|
||||
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
Product Roadmap
|
||||
# Product Roadmap
|
||||
|
||||
Aden Agent Framework aims to help developers build outcome oriented, self-adaptive agents. Please find our roadmap here
|
||||
|
||||
|
||||
+1
-1
@@ -145,7 +145,7 @@ python -m framework test-debug <agent_path> <test_name>
|
||||
python -m framework test-list <goal_id>
|
||||
```
|
||||
|
||||
For detailed testing workflows, see the [testing-agent skill](.claude/skills/testing-agent/SKILL.md).
|
||||
For detailed testing workflows, see the [testing-agent skill](../.claude/skills/testing-agent/SKILL.md).
|
||||
|
||||
### Analyzing Agent Behavior with Builder
|
||||
|
||||
|
||||
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Minimal Manual Agent Example
|
||||
----------------------------
|
||||
This example demonstrates how to build and run an agent programmatically
|
||||
without using the Claude Code CLI or external LLM APIs.
|
||||
|
||||
It uses 'function' nodes to define logic in pure Python, making it perfect
|
||||
for understanding the core runtime loop:
|
||||
Setup -> Graph definition -> Execution -> Result
|
||||
|
||||
Run with:
|
||||
PYTHONPATH=core python core/examples/manual_agent.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from framework.graph import EdgeCondition, EdgeSpec, Goal, GraphSpec, NodeSpec
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
|
||||
# 1. Define Node Logic (Pure Python Functions)
|
||||
def greet(name: str) -> str:
|
||||
"""Generate a simple greeting."""
|
||||
return f"Hello, {name}!"
|
||||
|
||||
|
||||
def uppercase(greeting: str) -> str:
|
||||
"""Convert text to uppercase."""
|
||||
return greeting.upper()
|
||||
|
||||
|
||||
async def main():
|
||||
print("🚀 Setting up Manual Agent...")
|
||||
|
||||
# 2. Define the Goal
|
||||
# Every agent needs a goal with success criteria
|
||||
goal = Goal(
|
||||
id="greet-user",
|
||||
name="Greet User",
|
||||
description="Generate a friendly uppercase greeting",
|
||||
success_criteria=[
|
||||
{
|
||||
"id": "greeting_generated",
|
||||
"description": "Greeting produced",
|
||||
"metric": "custom",
|
||||
"target": "any",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
# 3. Define Nodes
|
||||
# Nodes describe steps in the process
|
||||
node1 = NodeSpec(
|
||||
id="greeter",
|
||||
name="Greeter",
|
||||
description="Generates a simple greeting",
|
||||
node_type="function",
|
||||
function="greet", # Matches the registered function name
|
||||
input_keys=["name"],
|
||||
output_keys=["greeting"],
|
||||
)
|
||||
|
||||
node2 = NodeSpec(
|
||||
id="uppercaser",
|
||||
name="Uppercaser",
|
||||
description="Converts greeting to uppercase",
|
||||
node_type="function",
|
||||
function="uppercase",
|
||||
input_keys=["greeting"],
|
||||
output_keys=["final_greeting"],
|
||||
)
|
||||
|
||||
# 4. Define Edges
|
||||
# Edges define the flow between nodes
|
||||
edge1 = EdgeSpec(
|
||||
id="greet-to-upper",
|
||||
source="greeter",
|
||||
target="uppercaser",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
)
|
||||
|
||||
# 5. Create Graph
|
||||
# The graph works like a blueprint connecting nodes and edges
|
||||
graph = GraphSpec(
|
||||
id="greeting-agent",
|
||||
goal_id="greet-user",
|
||||
entry_node="greeter",
|
||||
terminal_nodes=["uppercaser"],
|
||||
nodes=[node1, node2],
|
||||
edges=[edge1],
|
||||
)
|
||||
|
||||
# 6. Initialize Runtime & Executor
|
||||
# Runtime handles state/memory; Executor runs the graph
|
||||
from pathlib import Path
|
||||
|
||||
runtime = Runtime(storage_path=Path("./agent_logs"))
|
||||
executor = GraphExecutor(runtime=runtime)
|
||||
|
||||
# 7. Register Function Implementations
|
||||
# Connect string names in NodeSpecs to actual Python functions
|
||||
executor.register_function("greeter", greet)
|
||||
executor.register_function("uppercaser", uppercase)
|
||||
|
||||
# 8. Execute Agent
|
||||
print("▶ Executing agent with input: name='Alice'...")
|
||||
|
||||
result = await executor.execute(graph=graph, goal=goal, input_data={"name": "Alice"})
|
||||
|
||||
# 9. Verify Results
|
||||
if result.success:
|
||||
print("\n✅ Success!")
|
||||
print(f"Path taken: {' -> '.join(result.path)}")
|
||||
print(f"Final output: {result.output.get('final_greeting')}")
|
||||
else:
|
||||
print(f"\n❌ Failed: {result.error}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Optional: Enable logging to see internal decision flow
|
||||
# logging.basicConfig(level=logging.INFO)
|
||||
asyncio.run(main())
|
||||
@@ -37,9 +37,9 @@ async def example_1_programmatic_registration():
|
||||
print(f"\nAvailable tools: {list(tools.keys())}")
|
||||
|
||||
# Run the agent with MCP tools available
|
||||
result = await runner.run({
|
||||
"objective": "Search for 'Claude AI' and summarize the top 3 results"
|
||||
})
|
||||
result = await runner.run(
|
||||
{"objective": "Search for 'Claude AI' and summarize the top 3 results"}
|
||||
)
|
||||
|
||||
print(f"\nAgent result: {result}")
|
||||
|
||||
@@ -78,10 +78,8 @@ async def example_3_config_file():
|
||||
|
||||
# Copy example config (in practice, you'd place this in your agent folder)
|
||||
import shutil
|
||||
shutil.copy(
|
||||
"examples/mcp_servers.json",
|
||||
test_agent_path / "mcp_servers.json"
|
||||
)
|
||||
|
||||
shutil.copy("examples/mcp_servers.json", test_agent_path / "mcp_servers.json")
|
||||
|
||||
# Load agent - MCP servers will be auto-discovered
|
||||
runner = AgentRunner.load(test_agent_path)
|
||||
@@ -110,18 +108,14 @@ async def example_4_custom_agent_with_mcp_tools():
|
||||
builder.set_goal(
|
||||
goal_id="web-researcher",
|
||||
name="Web Research Agent",
|
||||
description="Search the web and summarize findings"
|
||||
description="Search the web and summarize findings",
|
||||
)
|
||||
|
||||
# Add success criteria
|
||||
builder.add_success_criterion(
|
||||
"search-results",
|
||||
"Successfully retrieve at least 3 web search results"
|
||||
)
|
||||
builder.add_success_criterion(
|
||||
"summary",
|
||||
"Provide a clear, concise summary of the findings"
|
||||
"search-results", "Successfully retrieve at least 3 web search results"
|
||||
)
|
||||
builder.add_success_criterion("summary", "Provide a clear, concise summary of the findings")
|
||||
|
||||
# Add nodes that will use MCP tools
|
||||
builder.add_node(
|
||||
@@ -192,6 +186,7 @@ async def main():
|
||||
except Exception as e:
|
||||
print(f"\nError running example: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
|
||||
@@ -22,22 +22,22 @@ The framework includes a Goal-Based Testing system (Goal → Agent → Eval):
|
||||
See `framework.testing` for details.
|
||||
"""
|
||||
|
||||
from framework.schemas.decision import Decision, Option, Outcome, DecisionEvaluation
|
||||
from framework.schemas.run import Run, RunSummary, Problem
|
||||
from framework.runtime.core import Runtime
|
||||
from framework.builder.query import BuilderQuery
|
||||
from framework.llm import LLMProvider, AnthropicProvider
|
||||
from framework.runner import AgentRunner, AgentOrchestrator
|
||||
from framework.llm import AnthropicProvider, LLMProvider
|
||||
from framework.runner import AgentOrchestrator, AgentRunner
|
||||
from framework.runtime.core import Runtime
|
||||
from framework.schemas.decision import Decision, DecisionEvaluation, Option, Outcome
|
||||
from framework.schemas.run import Problem, Run, RunSummary
|
||||
|
||||
# Testing framework
|
||||
from framework.testing import (
|
||||
ApprovalStatus,
|
||||
DebugTool,
|
||||
ErrorCategory,
|
||||
Test,
|
||||
TestResult,
|
||||
TestSuiteResult,
|
||||
TestStorage,
|
||||
ApprovalStatus,
|
||||
ErrorCategory,
|
||||
DebugTool,
|
||||
TestSuiteResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
|
||||
from framework.builder.query import BuilderQuery
|
||||
from framework.builder.workflow import (
|
||||
GraphBuilder,
|
||||
BuildSession,
|
||||
BuildPhase,
|
||||
ValidationResult,
|
||||
BuildSession,
|
||||
GraphBuilder,
|
||||
TestCase,
|
||||
TestResult,
|
||||
ValidationResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -8,12 +8,12 @@ This is designed around the questions I need to answer:
|
||||
4. What should we change? (suggestions)
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.schemas.decision import Decision
|
||||
from framework.schemas.run import Run, RunSummary, RunStatus
|
||||
from framework.schemas.run import Run, RunStatus, RunSummary
|
||||
from framework.storage.backend import FileStorage
|
||||
|
||||
|
||||
@@ -196,10 +196,7 @@ class BuilderQuery:
|
||||
break
|
||||
|
||||
# Extract problems
|
||||
problems = [
|
||||
f"[{p.severity}] {p.description}"
|
||||
for p in run.problems
|
||||
]
|
||||
problems = [f"[{p.severity}] {p.description}" for p in run.problems]
|
||||
|
||||
# Generate suggestions based on the failure
|
||||
suggestions = self._generate_suggestions(run, failed_decisions)
|
||||
@@ -253,11 +250,7 @@ class BuilderQuery:
|
||||
error = decision.outcome.error or "Unknown error"
|
||||
failure_counts[error] += 1
|
||||
|
||||
common_failures = sorted(
|
||||
failure_counts.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True
|
||||
)[:5]
|
||||
common_failures = sorted(failure_counts.items(), key=lambda x: x[1], reverse=True)[:5]
|
||||
|
||||
# Find problematic nodes
|
||||
node_stats: dict[str, dict[str, int]] = defaultdict(lambda: {"total": 0, "failed": 0})
|
||||
@@ -328,34 +321,45 @@ class BuilderQuery:
|
||||
|
||||
# Suggestion: Fix problematic nodes
|
||||
for node_id, failure_rate in patterns.problematic_nodes:
|
||||
suggestions.append({
|
||||
"type": "node_improvement",
|
||||
"target": node_id,
|
||||
"reason": f"Node has {failure_rate:.1%} failure rate",
|
||||
"recommendation": f"Review and improve node '{node_id}' - high failure rate suggests prompt or tool issues",
|
||||
"priority": "high" if failure_rate > 0.3 else "medium",
|
||||
})
|
||||
suggestions.append(
|
||||
{
|
||||
"type": "node_improvement",
|
||||
"target": node_id,
|
||||
"reason": f"Node has {failure_rate:.1%} failure rate",
|
||||
"recommendation": (
|
||||
f"Review and improve node '{node_id}' - "
|
||||
"high failure rate suggests prompt or tool issues"
|
||||
),
|
||||
"priority": "high" if failure_rate > 0.3 else "medium",
|
||||
}
|
||||
)
|
||||
|
||||
# Suggestion: Address common failures
|
||||
for failure, count in patterns.common_failures:
|
||||
if count >= 2:
|
||||
suggestions.append({
|
||||
"type": "error_handling",
|
||||
"target": failure,
|
||||
"reason": f"Error occurred {count} times",
|
||||
"recommendation": f"Add handling for: {failure}",
|
||||
"priority": "high" if count >= 5 else "medium",
|
||||
})
|
||||
suggestions.append(
|
||||
{
|
||||
"type": "error_handling",
|
||||
"target": failure,
|
||||
"reason": f"Error occurred {count} times",
|
||||
"recommendation": f"Add handling for: {failure}",
|
||||
"priority": "high" if count >= 5 else "medium",
|
||||
}
|
||||
)
|
||||
|
||||
# Suggestion: Overall success rate
|
||||
if patterns.success_rate < 0.8:
|
||||
suggestions.append({
|
||||
"type": "architecture",
|
||||
"target": goal_id,
|
||||
"reason": f"Goal success rate is only {patterns.success_rate:.1%}",
|
||||
"recommendation": "Consider restructuring the agent graph or improving goal definition",
|
||||
"priority": "high",
|
||||
})
|
||||
suggestions.append(
|
||||
{
|
||||
"type": "architecture",
|
||||
"target": goal_id,
|
||||
"reason": f"Goal success rate is only {patterns.success_rate:.1%}",
|
||||
"recommendation": (
|
||||
"Consider restructuring the agent graph or improving goal definition"
|
||||
),
|
||||
"priority": "high",
|
||||
}
|
||||
)
|
||||
|
||||
return suggestions
|
||||
|
||||
@@ -408,21 +412,22 @@ class BuilderQuery:
|
||||
alternatives = [o for o in decision.options if o.id != decision.chosen_option_id]
|
||||
if alternatives:
|
||||
alt_desc = alternatives[0].description
|
||||
chosen_desc = chosen.description if chosen else "unknown"
|
||||
suggestions.append(
|
||||
f"Consider alternative: '{alt_desc}' instead of '{chosen.description if chosen else 'unknown'}'"
|
||||
f"Consider alternative: '{alt_desc}' instead of '{chosen_desc}'"
|
||||
)
|
||||
|
||||
# Check for missing context
|
||||
if not decision.input_context:
|
||||
suggestions.append(
|
||||
f"Decision '{decision.intent}' had no input context - ensure relevant data is passed"
|
||||
f"Decision '{decision.intent}' had no input context - "
|
||||
"ensure relevant data is passed"
|
||||
)
|
||||
|
||||
# Check for constraint issues
|
||||
if decision.active_constraints:
|
||||
suggestions.append(
|
||||
f"Review constraints: {', '.join(decision.active_constraints)} - may be too restrictive"
|
||||
)
|
||||
constraints = ", ".join(decision.active_constraints)
|
||||
suggestions.append(f"Review constraints: {constraints} - may be too restrictive")
|
||||
|
||||
# Check for reported problems with suggestions
|
||||
for problem in run.problems:
|
||||
@@ -471,15 +476,14 @@ class BuilderQuery:
|
||||
|
||||
# Decision count difference
|
||||
if len(run1.decisions) != len(run2.decisions):
|
||||
differences.append(
|
||||
f"Decision count: {len(run1.decisions)} vs {len(run2.decisions)}"
|
||||
)
|
||||
differences.append(f"Decision count: {len(run1.decisions)} vs {len(run2.decisions)}")
|
||||
|
||||
# Find first divergence point
|
||||
for i, (d1, d2) in enumerate(zip(run1.decisions, run2.decisions)):
|
||||
for i, (d1, d2) in enumerate(zip(run1.decisions, run2.decisions, strict=False)):
|
||||
if d1.chosen_option_id != d2.chosen_option_id:
|
||||
differences.append(
|
||||
f"Diverged at decision {i}: chose '{d1.chosen_option_id}' vs '{d2.chosen_option_id}'"
|
||||
f"Diverged at decision {i}: "
|
||||
f"chose '{d1.chosen_option_id}' vs '{d2.chosen_option_id}'"
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
@@ -13,32 +13,35 @@ Each step requires validation and human approval before proceeding.
|
||||
You cannot skip steps or bypass validation.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.node import NodeSpec
|
||||
from framework.graph.edge import EdgeSpec, EdgeCondition, GraphSpec
|
||||
|
||||
|
||||
class BuildPhase(str, Enum):
|
||||
"""Current phase of the build process."""
|
||||
INIT = "init" # Just started
|
||||
GOAL_DRAFT = "goal_draft" # Drafting goal
|
||||
|
||||
INIT = "init" # Just started
|
||||
GOAL_DRAFT = "goal_draft" # Drafting goal
|
||||
GOAL_APPROVED = "goal_approved" # Goal approved
|
||||
ADDING_NODES = "adding_nodes" # Adding nodes
|
||||
ADDING_EDGES = "adding_edges" # Adding edges
|
||||
TESTING = "testing" # Running tests
|
||||
APPROVED = "approved" # Fully approved
|
||||
EXPORTED = "exported" # Exported to file
|
||||
ADDING_NODES = "adding_nodes" # Adding nodes
|
||||
ADDING_EDGES = "adding_edges" # Adding edges
|
||||
TESTING = "testing" # Running tests
|
||||
APPROVED = "approved" # Fully approved
|
||||
EXPORTED = "exported" # Exported to file
|
||||
|
||||
|
||||
class ValidationResult(BaseModel):
|
||||
"""Result of a validation check."""
|
||||
|
||||
valid: bool
|
||||
errors: list[str] = Field(default_factory=list)
|
||||
warnings: list[str] = Field(default_factory=list)
|
||||
@@ -47,6 +50,7 @@ class ValidationResult(BaseModel):
|
||||
|
||||
class TestCase(BaseModel):
|
||||
"""A test case for validating agent behavior."""
|
||||
|
||||
id: str
|
||||
description: str
|
||||
input: dict[str, Any]
|
||||
@@ -56,6 +60,7 @@ class TestCase(BaseModel):
|
||||
|
||||
class TestResult(BaseModel):
|
||||
"""Result of running a test case."""
|
||||
|
||||
test_id: str
|
||||
passed: bool
|
||||
actual_output: Any = None
|
||||
@@ -69,6 +74,7 @@ class BuildSession(BaseModel):
|
||||
|
||||
Saved after each approved step so you can resume later.
|
||||
"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
phase: BuildPhase = BuildPhase.INIT
|
||||
@@ -457,11 +463,14 @@ class GraphBuilder:
|
||||
|
||||
# Run the test
|
||||
import asyncio
|
||||
result = asyncio.run(executor.execute(
|
||||
graph=graph,
|
||||
goal=self.session.goal,
|
||||
input_data=test.input,
|
||||
))
|
||||
|
||||
result = asyncio.run(
|
||||
executor.execute(
|
||||
graph=graph,
|
||||
goal=self.session.goal,
|
||||
input_data=test.input,
|
||||
)
|
||||
)
|
||||
|
||||
# Check result
|
||||
passed = result.success
|
||||
@@ -515,12 +524,14 @@ class GraphBuilder:
|
||||
if not self._pending_validation.valid:
|
||||
return False
|
||||
|
||||
self.session.approvals.append({
|
||||
"phase": self.session.phase.value,
|
||||
"comment": comment,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"validation": self._pending_validation.model_dump(),
|
||||
})
|
||||
self.session.approvals.append(
|
||||
{
|
||||
"phase": self.session.phase.value,
|
||||
"comment": comment,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"validation": self._pending_validation.model_dump(),
|
||||
}
|
||||
)
|
||||
|
||||
# Advance phase if appropriate
|
||||
if self.session.phase == BuildPhase.GOAL_DRAFT:
|
||||
@@ -554,11 +565,13 @@ class GraphBuilder:
|
||||
return False
|
||||
|
||||
self.session.phase = BuildPhase.APPROVED
|
||||
self.session.approvals.append({
|
||||
"phase": "final",
|
||||
"comment": comment,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
})
|
||||
self.session.approvals.append(
|
||||
{
|
||||
"phase": "final",
|
||||
"comment": comment,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
self._save_session()
|
||||
return True
|
||||
@@ -630,69 +643,75 @@ class GraphBuilder:
|
||||
"""Generate Python code for the graph."""
|
||||
lines = [
|
||||
'"""',
|
||||
f'Generated agent: {self.session.name}',
|
||||
f'Generated at: {datetime.now().isoformat()}',
|
||||
f"Generated agent: {self.session.name}",
|
||||
f"Generated at: {datetime.now().isoformat()}",
|
||||
'"""',
|
||||
'',
|
||||
'from framework.graph import (',
|
||||
' Goal, SuccessCriterion, Constraint,',
|
||||
' NodeSpec, EdgeSpec, EdgeCondition,',
|
||||
')',
|
||||
'from framework.graph.edge import GraphSpec',
|
||||
'from framework.graph.goal import GoalStatus',
|
||||
'',
|
||||
'',
|
||||
'# Goal',
|
||||
"",
|
||||
"from framework.graph import (",
|
||||
" Goal, SuccessCriterion, Constraint,",
|
||||
" NodeSpec, EdgeSpec, EdgeCondition,",
|
||||
")",
|
||||
"from framework.graph.edge import GraphSpec",
|
||||
"from framework.graph.goal import GoalStatus",
|
||||
"",
|
||||
"",
|
||||
"# Goal",
|
||||
]
|
||||
|
||||
if self.session.goal:
|
||||
goal_json = self.session.goal.model_dump_json(indent=4)
|
||||
lines.append('GOAL = Goal.model_validate_json(\'\'\'')
|
||||
lines.append("GOAL = Goal.model_validate_json('''")
|
||||
lines.append(goal_json)
|
||||
lines.append("''')")
|
||||
else:
|
||||
lines.append('GOAL = None')
|
||||
lines.append("GOAL = None")
|
||||
|
||||
lines.extend([
|
||||
'',
|
||||
'',
|
||||
'# Nodes',
|
||||
'NODES = [',
|
||||
])
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
"",
|
||||
"# Nodes",
|
||||
"NODES = [",
|
||||
]
|
||||
)
|
||||
|
||||
for node in self.session.nodes:
|
||||
node_json = node.model_dump_json(indent=4)
|
||||
lines.append(' NodeSpec.model_validate_json(\'\'\'')
|
||||
lines.append(" NodeSpec.model_validate_json('''")
|
||||
lines.append(node_json)
|
||||
lines.append(" '''),")
|
||||
|
||||
lines.extend([
|
||||
']',
|
||||
'',
|
||||
'',
|
||||
'# Edges',
|
||||
'EDGES = [',
|
||||
])
|
||||
lines.extend(
|
||||
[
|
||||
"]",
|
||||
"",
|
||||
"",
|
||||
"# Edges",
|
||||
"EDGES = [",
|
||||
]
|
||||
)
|
||||
|
||||
for edge in self.session.edges:
|
||||
edge_json = edge.model_dump_json(indent=4)
|
||||
lines.append(' EdgeSpec.model_validate_json(\'\'\'')
|
||||
lines.append(" EdgeSpec.model_validate_json('''")
|
||||
lines.append(edge_json)
|
||||
lines.append(" '''),")
|
||||
|
||||
lines.extend([
|
||||
']',
|
||||
'',
|
||||
'',
|
||||
'# Graph',
|
||||
])
|
||||
lines.extend(
|
||||
[
|
||||
"]",
|
||||
"",
|
||||
"",
|
||||
"# Graph",
|
||||
]
|
||||
)
|
||||
|
||||
graph_json = graph.model_dump_json(indent=4)
|
||||
lines.append('GRAPH = GraphSpec.model_validate_json(\'\'\'')
|
||||
lines.append("GRAPH = GraphSpec.model_validate_json('''")
|
||||
lines.append(graph_json)
|
||||
lines.append("''')")
|
||||
|
||||
return '\n'.join(lines)
|
||||
return "\n".join(lines)
|
||||
|
||||
# =========================================================================
|
||||
# SESSION MANAGEMENT
|
||||
@@ -743,7 +762,9 @@ class GraphBuilder:
|
||||
"tests": len(self.session.test_cases),
|
||||
"tests_passed": sum(1 for t in self.session.test_results if t.passed),
|
||||
"approvals": len(self.session.approvals),
|
||||
"pending_validation": self._pending_validation.model_dump() if self._pending_validation else None,
|
||||
"pending_validation": self._pending_validation.model_dump()
|
||||
if self._pending_validation
|
||||
else None,
|
||||
}
|
||||
|
||||
def show(self) -> str:
|
||||
@@ -755,11 +776,13 @@ class GraphBuilder:
|
||||
]
|
||||
|
||||
if self.session.goal:
|
||||
lines.extend([
|
||||
f"Goal: {self.session.goal.name}",
|
||||
f" {self.session.goal.description}",
|
||||
"",
|
||||
])
|
||||
lines.extend(
|
||||
[
|
||||
f"Goal: {self.session.goal.name}",
|
||||
f" {self.session.goal.description}",
|
||||
"",
|
||||
]
|
||||
)
|
||||
|
||||
if self.session.nodes:
|
||||
lines.append("Nodes:")
|
||||
|
||||
@@ -21,9 +21,7 @@ import sys
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Goal Agent - Build and run goal-driven agents"
|
||||
)
|
||||
parser = argparse.ArgumentParser(description="Goal Agent - Build and run goal-driven agents")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="claude-haiku-4-5-20251001",
|
||||
@@ -34,10 +32,12 @@ def main():
|
||||
|
||||
# Register runner commands (run, info, validate, list, dispatch, shell)
|
||||
from framework.runner.cli import register_commands
|
||||
|
||||
register_commands(subparsers)
|
||||
|
||||
# Register testing commands (test-run, test-debug, test-list, test-stats)
|
||||
from framework.testing.cli import register_testing_commands
|
||||
|
||||
register_testing_commands(subparsers)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Credential Store - Production-ready credential management for Hive.
|
||||
|
||||
This module provides secure credential storage with:
|
||||
- Key-vault structure: Credentials as objects with multiple keys
|
||||
- Template-based usage: {{cred.key}} patterns for injection
|
||||
- Bipartisan model: Store stores values, tools define usage
|
||||
- Provider system: Extensible lifecycle management (refresh, validate)
|
||||
- Multiple backends: Encrypted files, env vars, HashiCorp Vault
|
||||
|
||||
Quick Start:
|
||||
from core.framework.credentials import CredentialStore, CredentialObject
|
||||
|
||||
# Create store with encrypted storage
|
||||
store = CredentialStore.with_encrypted_storage("/var/hive/credentials")
|
||||
|
||||
# Get a credential
|
||||
api_key = store.get("brave_search")
|
||||
|
||||
# Resolve templates in headers
|
||||
headers = store.resolve_headers({
|
||||
"Authorization": "Bearer {{github_oauth.access_token}}"
|
||||
})
|
||||
|
||||
# Save a new credential
|
||||
store.save_credential(CredentialObject(
|
||||
id="my_api",
|
||||
keys={"api_key": CredentialKey(name="api_key", value=SecretStr("xxx"))}
|
||||
))
|
||||
|
||||
For OAuth2 support:
|
||||
from core.framework.credentials.oauth2 import BaseOAuth2Provider, OAuth2Config
|
||||
|
||||
For Aden server sync:
|
||||
from core.framework.credentials.aden import (
|
||||
AdenCredentialClient,
|
||||
AdenClientConfig,
|
||||
AdenSyncProvider,
|
||||
)
|
||||
|
||||
For Vault integration:
|
||||
from core.framework.credentials.vault import HashiCorpVaultStorage
|
||||
"""
|
||||
|
||||
from .models import (
|
||||
CredentialDecryptionError,
|
||||
CredentialError,
|
||||
CredentialKey,
|
||||
CredentialKeyNotFoundError,
|
||||
CredentialNotFoundError,
|
||||
CredentialObject,
|
||||
CredentialRefreshError,
|
||||
CredentialType,
|
||||
CredentialUsageSpec,
|
||||
CredentialValidationError,
|
||||
)
|
||||
from .provider import (
|
||||
BearerTokenProvider,
|
||||
CredentialProvider,
|
||||
StaticProvider,
|
||||
)
|
||||
from .storage import (
|
||||
CompositeStorage,
|
||||
CredentialStorage,
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
InMemoryStorage,
|
||||
)
|
||||
from .store import CredentialStore
|
||||
from .template import TemplateResolver
|
||||
|
||||
# Aden sync components (lazy import to avoid httpx dependency when not needed)
|
||||
# Usage: from core.framework.credentials.aden import AdenSyncProvider
|
||||
# Or: from core.framework.credentials import AdenSyncProvider
|
||||
try:
|
||||
from .aden import (
|
||||
AdenCachedStorage,
|
||||
AdenClientConfig,
|
||||
AdenCredentialClient,
|
||||
AdenSyncProvider,
|
||||
)
|
||||
|
||||
_ADEN_AVAILABLE = True
|
||||
except ImportError:
|
||||
_ADEN_AVAILABLE = False
|
||||
|
||||
__all__ = [
|
||||
# Main store
|
||||
"CredentialStore",
|
||||
# Models
|
||||
"CredentialObject",
|
||||
"CredentialKey",
|
||||
"CredentialType",
|
||||
"CredentialUsageSpec",
|
||||
# Providers
|
||||
"CredentialProvider",
|
||||
"StaticProvider",
|
||||
"BearerTokenProvider",
|
||||
# Storage backends
|
||||
"CredentialStorage",
|
||||
"EncryptedFileStorage",
|
||||
"EnvVarStorage",
|
||||
"InMemoryStorage",
|
||||
"CompositeStorage",
|
||||
# Template resolution
|
||||
"TemplateResolver",
|
||||
# Exceptions
|
||||
"CredentialError",
|
||||
"CredentialNotFoundError",
|
||||
"CredentialKeyNotFoundError",
|
||||
"CredentialRefreshError",
|
||||
"CredentialValidationError",
|
||||
"CredentialDecryptionError",
|
||||
# Aden sync (optional - requires httpx)
|
||||
"AdenSyncProvider",
|
||||
"AdenCredentialClient",
|
||||
"AdenClientConfig",
|
||||
"AdenCachedStorage",
|
||||
]
|
||||
|
||||
# Track Aden availability for runtime checks
|
||||
ADEN_AVAILABLE = _ADEN_AVAILABLE
|
||||
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
Aden Credential Sync.
|
||||
|
||||
Components for synchronizing credentials with the Aden authentication server.
|
||||
|
||||
The Aden server handles OAuth2 authorization flows and maintains refresh tokens.
|
||||
These components fetch and cache access tokens locally while delegating
|
||||
lifecycle management to Aden.
|
||||
|
||||
Components:
|
||||
- AdenCredentialClient: HTTP client for Aden API
|
||||
- AdenSyncProvider: CredentialProvider that syncs with Aden
|
||||
- AdenCachedStorage: Storage with local cache + Aden fallback
|
||||
|
||||
Quick Start:
|
||||
from core.framework.credentials import CredentialStore
|
||||
from core.framework.credentials.storage import EncryptedFileStorage
|
||||
from core.framework.credentials.aden import (
|
||||
AdenCredentialClient,
|
||||
AdenClientConfig,
|
||||
AdenSyncProvider,
|
||||
)
|
||||
|
||||
# Configure (API key loaded from ADEN_API_KEY env var)
|
||||
client = AdenCredentialClient(AdenClientConfig(
|
||||
base_url=os.environ["ADEN_API_URL"],
|
||||
))
|
||||
|
||||
provider = AdenSyncProvider(client=client)
|
||||
|
||||
store = CredentialStore(
|
||||
storage=EncryptedFileStorage(),
|
||||
providers=[provider],
|
||||
auto_refresh=True,
|
||||
)
|
||||
|
||||
# Initial sync
|
||||
provider.sync_all(store)
|
||||
|
||||
# Use normally
|
||||
token = store.get_key("hubspot", "access_token")
|
||||
|
||||
See docs/aden-credential-sync.md for detailed documentation.
|
||||
"""
|
||||
|
||||
from .client import (
|
||||
AdenAuthenticationError,
|
||||
AdenClientConfig,
|
||||
AdenClientError,
|
||||
AdenCredentialClient,
|
||||
AdenCredentialResponse,
|
||||
AdenIntegrationInfo,
|
||||
AdenNotFoundError,
|
||||
AdenRateLimitError,
|
||||
AdenRefreshError,
|
||||
)
|
||||
from .provider import AdenSyncProvider
|
||||
from .storage import AdenCachedStorage
|
||||
|
||||
__all__ = [
|
||||
# Client
|
||||
"AdenCredentialClient",
|
||||
"AdenClientConfig",
|
||||
"AdenCredentialResponse",
|
||||
"AdenIntegrationInfo",
|
||||
# Client errors
|
||||
"AdenClientError",
|
||||
"AdenAuthenticationError",
|
||||
"AdenNotFoundError",
|
||||
"AdenRateLimitError",
|
||||
"AdenRefreshError",
|
||||
# Provider
|
||||
"AdenSyncProvider",
|
||||
# Storage
|
||||
"AdenCachedStorage",
|
||||
]
|
||||
@@ -0,0 +1,467 @@
|
||||
"""
|
||||
Aden Credential Client.
|
||||
|
||||
HTTP client for communicating with the Aden authentication server.
|
||||
The Aden server handles OAuth2 authorization flows and token management.
|
||||
This client fetches tokens and delegates refresh operations to Aden.
|
||||
|
||||
Usage:
|
||||
# API key loaded from ADEN_API_KEY environment variable by default
|
||||
client = AdenCredentialClient(AdenClientConfig(
|
||||
base_url="https://hive.adenhq.com",
|
||||
))
|
||||
|
||||
# Or explicitly provide the API key
|
||||
client = AdenCredentialClient(AdenClientConfig(
|
||||
base_url="https://hive.adenhq.com",
|
||||
api_key="your-api-key",
|
||||
))
|
||||
|
||||
# Fetch a credential
|
||||
response = client.get_credential("hubspot")
|
||||
if response:
|
||||
print(f"Token expires at: {response.expires_at}")
|
||||
|
||||
# Request a refresh
|
||||
refreshed = client.request_refresh("hubspot")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdenClientError(Exception):
|
||||
"""Base exception for Aden client errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AdenAuthenticationError(AdenClientError):
|
||||
"""Raised when API key is invalid or revoked."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AdenNotFoundError(AdenClientError):
|
||||
"""Raised when integration is not found."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AdenRefreshError(AdenClientError):
|
||||
"""Raised when token refresh fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
requires_reauthorization: bool = False,
|
||||
reauthorization_url: str | None = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.requires_reauthorization = requires_reauthorization
|
||||
self.reauthorization_url = reauthorization_url
|
||||
|
||||
|
||||
class AdenRateLimitError(AdenClientError):
|
||||
"""Raised when rate limited."""
|
||||
|
||||
def __init__(self, message: str, retry_after: int = 60):
|
||||
super().__init__(message)
|
||||
self.retry_after = retry_after
|
||||
|
||||
|
||||
@dataclass
class AdenClientConfig:
    """Configuration for the Aden API client.

    The API key may be omitted, in which case it is read from the
    ADEN_API_KEY environment variable when the config is constructed.
    """

    # Base URL of the Aden server (e.g. 'https://hive.adenhq.com').
    base_url: str
    # Agent API key; falls back to the ADEN_API_KEY env var when None.
    api_key: str | None = None
    # Optional tenant ID for multi-tenant deployments.
    tenant_id: str | None = None
    # Request timeout in seconds.
    timeout: float = 30.0
    # Number of retry attempts for transient failures.
    retry_attempts: int = 3
    # Base delay between retries in seconds (doubled each attempt).
    retry_delay: float = 1.0

    def __post_init__(self) -> None:
        """Fall back to the ADEN_API_KEY environment variable for the key."""
        if self.api_key is None:
            self.api_key = os.environ.get("ADEN_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Aden API key not provided. Either pass api_key to AdenClientConfig "
                "or set the ADEN_API_KEY environment variable."
            )
|
||||
|
||||
|
||||
@dataclass
class AdenCredentialResponse:
    """A credential payload returned by the Aden server."""

    # Unique identifier for the integration (e.g. 'hubspot').
    integration_id: str
    # Type of integration (e.g. 'hubspot', 'github', 'slack').
    integration_type: str
    # The access token used for API calls.
    access_token: str
    # Token scheme, normally 'Bearer'.
    token_type: str = "Bearer"
    # UTC expiry of the access token, if the server reported one.
    expires_at: datetime | None = None
    # OAuth2 scopes granted to this token.
    scopes: list[str] = field(default_factory=list)
    # Integration-specific extras passed through verbatim.
    metadata: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> AdenCredentialResponse:
        """Build a response object from a decoded JSON payload."""
        raw_expiry = data.get("expires_at")
        # The server emits a trailing 'Z'; normalize for fromisoformat.
        parsed_expiry = (
            datetime.fromisoformat(raw_expiry.replace("Z", "+00:00"))
            if raw_expiry
            else None
        )
        return cls(
            integration_id=data["integration_id"],
            integration_type=data["integration_type"],
            access_token=data["access_token"],
            token_type=data.get("token_type", "Bearer"),
            expires_at=parsed_expiry,
            scopes=data.get("scopes", []),
            metadata=data.get("metadata", {}),
        )
|
||||
|
||||
|
||||
@dataclass
class AdenIntegrationInfo:
    """Summary of one integration available to this agent/tenant."""

    integration_id: str
    integration_type: str
    # One of "active", "requires_reauth", "expired" (per the Aden API).
    status: str
    expires_at: datetime | None = None

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> AdenIntegrationInfo:
        """Build an info object from a decoded JSON payload."""
        raw_expiry = data.get("expires_at")
        parsed_expiry = (
            datetime.fromisoformat(raw_expiry.replace("Z", "+00:00"))
            if raw_expiry
            else None
        )
        return cls(
            integration_id=data["integration_id"],
            # The API reports the type under "provider"; fall back to the id.
            integration_type=data.get("provider", data["integration_id"]),
            status=data.get("status", "unknown"),
            expires_at=parsed_expiry,
        )
|
||||
|
||||
|
||||
class AdenCredentialClient:
    """
    HTTP client for the Aden credential server.

    Handles communication with the Aden authentication server, including
    fetching credentials, requesting refreshes, and reporting usage
    statistics.

    The client automatically handles:
    - Retries with exponential backoff for transient failures
    - Proper error classification (auth, not found, rate limit, etc.)
    - Request headers for authentication and tenant isolation

    Usage:
        # API key loaded from ADEN_API_KEY environment variable
        config = AdenClientConfig(
            base_url="https://hive.adenhq.com",
        )

        client = AdenCredentialClient(config)

        # Fetch a credential
        cred = client.get_credential("hubspot")
        if cred:
            headers = {"Authorization": f"Bearer {cred.access_token}"}

        # List all integrations
        integrations = client.list_integrations()
        for info in integrations:
            print(f"{info.integration_id}: {info.status}")

        # Clean up
        client.close()
    """

    def __init__(self, config: AdenClientConfig):
        """
        Initialize the Aden client.

        Args:
            config: Client configuration including base URL and API key.
        """
        self.config = config
        # Created lazily on first request so construction never does I/O.
        self._client: httpx.Client | None = None

    def _get_client(self) -> httpx.Client:
        """Get or create the underlying HTTP client."""
        if self._client is None:
            headers = {
                "Authorization": f"Bearer {self.config.api_key}",
                "Content-Type": "application/json",
                "User-Agent": "hive-credential-store/1.0",
            }

            if self.config.tenant_id:
                headers["X-Tenant-ID"] = self.config.tenant_id

            self._client = httpx.Client(
                base_url=self.config.base_url,
                timeout=self.config.timeout,
                headers=headers,
            )

        return self._client

    def _request_with_retry(
        self,
        method: str,
        path: str,
        **kwargs: Any,
    ) -> httpx.Response:
        """
        Make a request with retry logic.

        Connection/timeout errors are retried with exponential backoff up to
        ``config.retry_attempts`` times; classified API errors (auth,
        not-found, rate-limit, refresh) are raised immediately without retry.

        Note: earlier debug ``print`` calls here dumped the request headers
        (including the Authorization bearer API key) to stdout; they have
        been removed to avoid leaking the secret.
        """
        client = self._get_client()
        last_error: Exception | None = None

        for attempt in range(self.config.retry_attempts):
            try:
                response = client.request(method, path, **kwargs)

                # Handle specific error codes
                if response.status_code == 401:
                    raise AdenAuthenticationError("Agent API key is invalid or revoked")

                if response.status_code == 404:
                    raise AdenNotFoundError(f"Integration not found: {path}")

                if response.status_code == 429:
                    retry_after = int(response.headers.get("Retry-After", 60))
                    raise AdenRateLimitError(
                        "Rate limited by Aden server",
                        retry_after=retry_after,
                    )

                if response.status_code == 400:
                    data = response.json()
                    if data.get("error") == "refresh_failed":
                        raise AdenRefreshError(
                            data.get("message", "Token refresh failed"),
                            requires_reauthorization=data.get(
                                "requires_reauthorization", False
                            ),
                            reauthorization_url=data.get("reauthorization_url"),
                        )

                # Success or other error
                response.raise_for_status()
                return response

            except (httpx.ConnectError, httpx.TimeoutException) as e:
                last_error = e
                if attempt < self.config.retry_attempts - 1:
                    delay = self.config.retry_delay * (2**attempt)
                    logger.warning(
                        f"Aden request failed (attempt {attempt + 1}), retrying in {delay}s: {e}"
                    )
                    time.sleep(delay)
                else:
                    raise AdenClientError(f"Failed to connect to Aden server: {e}") from e

            except (
                AdenAuthenticationError,
                AdenNotFoundError,
                AdenRefreshError,
                AdenRateLimitError,
            ):
                # Don't retry these errors
                raise

        # Should not reach here, but just in case
        raise AdenClientError(
            f"Request failed after {self.config.retry_attempts} attempts"
        ) from last_error

    def get_credential(self, integration_id: str) -> AdenCredentialResponse | None:
        """
        Fetch the current credential for an integration.

        The Aden server may refresh the token internally if it's expired
        before returning it.

        Args:
            integration_id: The integration identifier (e.g., 'hubspot').

        Returns:
            Credential response with access token, or None if not found.

        Raises:
            AdenAuthenticationError: If API key is invalid.
            AdenClientError: For connection failures.
        """
        try:
            response = self._request_with_retry("GET", f"/v1/credentials/{integration_id}")
            data = response.json()
            return AdenCredentialResponse.from_dict(data)
        except AdenNotFoundError:
            return None

    def request_refresh(self, integration_id: str) -> AdenCredentialResponse:
        """
        Request the Aden server to refresh the token.

        Use this when the local store detects an expired or near-expiry token.
        The Aden server handles the actual OAuth2 refresh token flow.

        Args:
            integration_id: The integration identifier.

        Returns:
            Credential response with new access token.

        Raises:
            AdenRefreshError: If refresh fails (may require re-authorization).
            AdenNotFoundError: If integration not found.
            AdenAuthenticationError: If API key is invalid.
            AdenRateLimitError: If rate limited.
        """
        response = self._request_with_retry(
            "POST", f"/v1/credentials/{integration_id}/refresh"
        )
        data = response.json()
        return AdenCredentialResponse.from_dict(data)

    def list_integrations(self) -> list[AdenIntegrationInfo]:
        """
        List all integrations available for this agent/tenant.

        Returns:
            List of integration info objects.

        Raises:
            AdenAuthenticationError: If API key is invalid.
            AdenClientError: For connection failures.
        """
        response = self._request_with_retry("GET", "/v1/credentials")
        data = response.json()
        return [AdenIntegrationInfo.from_dict(item) for item in data.get("integrations", [])]

    def validate_token(self, integration_id: str) -> dict[str, Any]:
        """
        Check if a token is still valid without fetching it.

        Args:
            integration_id: The integration identifier.

        Returns:
            Dict with 'valid' bool and optional 'expires_at', 'reason',
            'requires_reauthorization', 'reauthorization_url'.

        Raises:
            AdenNotFoundError: If integration not found.
            AdenAuthenticationError: If API key is invalid.
        """
        response = self._request_with_retry(
            "GET", f"/v1/credentials/{integration_id}/validate"
        )
        return response.json()

    def report_usage(
        self,
        integration_id: str,
        operation: str,
        status: str = "success",
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """
        Report credential usage statistics to Aden.

        This is optional and used for analytics/billing.

        Args:
            integration_id: The integration identifier.
            operation: Operation name (e.g., 'api_call').
            status: Operation status ('success', 'error').
            metadata: Additional operation metadata.
        """
        # Local import: this module only imports `datetime` from datetime.
        from datetime import timezone

        try:
            self._request_with_retry(
                "POST",
                f"/v1/credentials/{integration_id}/usage",
                json={
                    "operation": operation,
                    "status": status,
                    # Timezone-aware replacement for the deprecated
                    # datetime.utcnow(); keeps the trailing-"Z" wire format.
                    "timestamp": datetime.now(timezone.utc)
                    .isoformat()
                    .replace("+00:00", "Z"),
                    "metadata": metadata or {},
                },
            )
        except Exception as e:
            # Usage reporting is best-effort, don't fail on errors
            logger.warning(f"Failed to report usage for '{integration_id}': {e}")

    def health_check(self) -> dict[str, Any]:
        """
        Check Aden server health and connectivity.

        Returns:
            Dict with 'status', 'version', 'timestamp', and optionally 'error'.
        """
        try:
            client = self._get_client()
            response = client.get("/health")
            if response.status_code == 200:
                data = response.json()
                data["latency_ms"] = response.elapsed.total_seconds() * 1000
                return data
            return {
                "status": "degraded",
                "error": f"Unexpected status code: {response.status_code}",
            }
        except Exception as e:
            # Health checks must never raise; report the failure instead.
            return {
                "status": "unhealthy",
                "error": str(e),
            }

    def close(self) -> None:
        """Close the HTTP client and release resources."""
        if self._client:
            self._client.close()
            self._client = None

    def __enter__(self) -> AdenCredentialClient:
        """Context manager entry."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Context manager exit — closes the HTTP client."""
        self.close()
|
||||
@@ -0,0 +1,415 @@
|
||||
"""
|
||||
Aden Sync Provider.
|
||||
|
||||
Provider that synchronizes credentials with the Aden authentication server.
|
||||
The Aden server is the authoritative source for OAuth2 tokens - this provider
|
||||
fetches and caches tokens locally while delegating refresh operations to Aden.
|
||||
|
||||
Usage:
|
||||
from core.framework.credentials import CredentialStore
|
||||
from core.framework.credentials.storage import EncryptedFileStorage
|
||||
from core.framework.credentials.aden import (
|
||||
AdenCredentialClient,
|
||||
AdenClientConfig,
|
||||
AdenSyncProvider,
|
||||
)
|
||||
|
||||
# Configure client (API key loaded from ADEN_API_KEY env var)
|
||||
client = AdenCredentialClient(AdenClientConfig(
|
||||
base_url=os.environ["ADEN_API_URL"],
|
||||
))
|
||||
|
||||
# Create provider
|
||||
provider = AdenSyncProvider(client=client)
|
||||
|
||||
# Create store
|
||||
store = CredentialStore(
|
||||
storage=EncryptedFileStorage(),
|
||||
providers=[provider],
|
||||
auto_refresh=True,
|
||||
)
|
||||
|
||||
# Initial sync from Aden
|
||||
provider.sync_all(store)
|
||||
|
||||
# Use normally - auto-refreshes via Aden when needed
|
||||
token = store.get_key("hubspot", "access_token")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from ..models import CredentialKey, CredentialObject, CredentialRefreshError, CredentialType
|
||||
from ..provider import CredentialProvider
|
||||
from .client import (
|
||||
AdenClientError,
|
||||
AdenCredentialClient,
|
||||
AdenCredentialResponse,
|
||||
AdenRefreshError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..store import CredentialStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdenSyncProvider(CredentialProvider):
|
||||
"""
|
||||
Provider that synchronizes credentials with the Aden server.
|
||||
|
||||
The Aden server handles OAuth2 authorization flows and maintains
|
||||
refresh tokens. This provider:
|
||||
|
||||
- Fetches access tokens from the Aden server
|
||||
- Delegates token refresh to the Aden server
|
||||
- Caches tokens locally in the credential store
|
||||
- Optionally reports usage statistics back to Aden
|
||||
|
||||
Key benefits:
|
||||
- Client secrets never leave the Aden server
|
||||
- Refresh token security (stored only on Aden)
|
||||
- Centralized audit logging
|
||||
- Multi-tenant support
|
||||
|
||||
Usage:
|
||||
client = AdenCredentialClient(AdenClientConfig(
|
||||
base_url="https://hive.adenhq.com",
|
||||
api_key=os.environ["ADEN_API_KEY"],
|
||||
))
|
||||
|
||||
provider = AdenSyncProvider(client=client)
|
||||
|
||||
store = CredentialStore(
|
||||
storage=EncryptedFileStorage(),
|
||||
providers=[provider],
|
||||
auto_refresh=True,
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: AdenCredentialClient,
|
||||
provider_id: str = "aden_sync",
|
||||
refresh_buffer_minutes: int = 5,
|
||||
report_usage: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the Aden sync provider.
|
||||
|
||||
Args:
|
||||
client: Configured Aden API client.
|
||||
provider_id: Unique identifier for this provider instance.
|
||||
Useful for multi-tenant scenarios (e.g., 'aden_tenant_123').
|
||||
refresh_buffer_minutes: Minutes before expiry to trigger refresh.
|
||||
Default is 5 minutes.
|
||||
report_usage: Whether to report usage statistics to Aden server.
|
||||
"""
|
||||
self._client = client
|
||||
self._provider_id = provider_id
|
||||
self._refresh_buffer = timedelta(minutes=refresh_buffer_minutes)
|
||||
self._report_usage = report_usage
|
||||
|
||||
@property
|
||||
def provider_id(self) -> str:
|
||||
"""Unique identifier for this provider."""
|
||||
return self._provider_id
|
||||
|
||||
@property
|
||||
def supported_types(self) -> list[CredentialType]:
|
||||
"""Credential types this provider can manage."""
|
||||
return [CredentialType.OAUTH2, CredentialType.BEARER_TOKEN]
|
||||
|
||||
def can_handle(self, credential: CredentialObject) -> bool:
|
||||
"""
|
||||
Check if this provider can handle a credential.
|
||||
|
||||
Returns True if:
|
||||
- Credential type is supported (OAUTH2 or BEARER_TOKEN)
|
||||
- Credential's provider_id matches this provider, OR
|
||||
- Credential has '_aden_managed' metadata flag
|
||||
"""
|
||||
if credential.credential_type not in self.supported_types:
|
||||
return False
|
||||
|
||||
# Check if credential is explicitly linked to this provider
|
||||
if credential.provider_id == self.provider_id:
|
||||
return True
|
||||
|
||||
# Check for Aden-managed flag in metadata
|
||||
aden_flag = credential.keys.get("_aden_managed")
|
||||
if aden_flag and aden_flag.value.get_secret_value() == "true":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def refresh(self, credential: CredentialObject) -> CredentialObject:
|
||||
"""
|
||||
Refresh credential by requesting new token from Aden server.
|
||||
|
||||
The Aden server handles the actual OAuth2 refresh token flow.
|
||||
This method simply fetches the result.
|
||||
|
||||
Args:
|
||||
credential: The credential to refresh.
|
||||
|
||||
Returns:
|
||||
Updated credential with new access token.
|
||||
|
||||
Raises:
|
||||
CredentialRefreshError: If refresh fails.
|
||||
"""
|
||||
try:
|
||||
# Request Aden to refresh the token
|
||||
aden_response = self._client.request_refresh(credential.id)
|
||||
|
||||
# Update credential with new values
|
||||
credential = self._update_credential_from_aden(credential, aden_response)
|
||||
|
||||
logger.info(f"Refreshed credential '{credential.id}' via Aden server")
|
||||
|
||||
# Report usage if enabled
|
||||
if self._report_usage:
|
||||
self._client.report_usage(
|
||||
integration_id=credential.id,
|
||||
operation="token_refresh",
|
||||
status="success",
|
||||
)
|
||||
|
||||
return credential
|
||||
|
||||
except AdenRefreshError as e:
|
||||
logger.error(f"Aden refresh failed for '{credential.id}': {e}")
|
||||
|
||||
if e.requires_reauthorization:
|
||||
raise CredentialRefreshError(
|
||||
f"Integration '{credential.id}' requires re-authorization. "
|
||||
f"Visit: {e.reauthorization_url or 'your Aden dashboard'}"
|
||||
) from e
|
||||
|
||||
raise CredentialRefreshError(
|
||||
f"Failed to refresh credential '{credential.id}': {e}"
|
||||
) from e
|
||||
|
||||
except AdenClientError as e:
|
||||
logger.error(f"Aden client error for '{credential.id}': {e}")
|
||||
|
||||
# Check if local token is still valid
|
||||
access_key = credential.keys.get("access_token")
|
||||
if access_key and access_key.expires_at:
|
||||
if datetime.now(UTC) < access_key.expires_at:
|
||||
logger.warning(f"Aden unavailable, using cached token for '{credential.id}'")
|
||||
return credential
|
||||
|
||||
raise CredentialRefreshError(
|
||||
f"Aden server unavailable and token expired for '{credential.id}'"
|
||||
) from e
|
||||
|
||||
def validate(self, credential: CredentialObject) -> bool:
|
||||
"""
|
||||
Validate credential via Aden server introspection.
|
||||
|
||||
Args:
|
||||
credential: The credential to validate.
|
||||
|
||||
Returns:
|
||||
True if credential is valid.
|
||||
"""
|
||||
try:
|
||||
result = self._client.validate_token(credential.id)
|
||||
return result.get("valid", False)
|
||||
except AdenClientError:
|
||||
# Fall back to local validation
|
||||
access_key = credential.keys.get("access_token")
|
||||
if access_key is None:
|
||||
return False
|
||||
|
||||
if access_key.expires_at is None:
|
||||
# No expiration - assume valid
|
||||
return True
|
||||
|
||||
return datetime.now(UTC) < access_key.expires_at
|
||||
|
||||
def should_refresh(self, credential: CredentialObject) -> bool:
|
||||
"""
|
||||
Check if credential should be refreshed.
|
||||
|
||||
Returns True if access_token is expired or within the refresh buffer.
|
||||
|
||||
Args:
|
||||
credential: The credential to check.
|
||||
|
||||
Returns:
|
||||
True if credential should be refreshed.
|
||||
"""
|
||||
access_key = credential.keys.get("access_token")
|
||||
if access_key is None:
|
||||
return False
|
||||
|
||||
if access_key.expires_at is None:
|
||||
return False
|
||||
|
||||
# Refresh if within buffer of expiration
|
||||
return datetime.now(UTC) >= (access_key.expires_at - self._refresh_buffer)
|
||||
|
||||
def fetch_from_aden(self, integration_id: str) -> CredentialObject | None:
|
||||
"""
|
||||
Fetch credential directly from Aden server.
|
||||
|
||||
Use this for initial population or when local cache is missing.
|
||||
|
||||
Args:
|
||||
integration_id: The integration identifier (e.g., 'hubspot').
|
||||
|
||||
Returns:
|
||||
CredentialObject if found, None otherwise.
|
||||
|
||||
Raises:
|
||||
AdenClientError: For connection failures.
|
||||
"""
|
||||
aden_response = self._client.get_credential(integration_id)
|
||||
if aden_response is None:
|
||||
return None
|
||||
|
||||
return self._aden_response_to_credential(aden_response)
|
||||
|
||||
def sync_all(self, store: CredentialStore) -> int:
    """
    Sync all credentials from Aden server to local store.

    Fetches the list of available integrations from Aden and
    populates the local credential store with current tokens.
    Non-active integrations are skipped, and a failure on one
    integration does not abort the sync of the others.

    Args:
        store: The credential store to populate.

    Returns:
        Number of credentials synced.
    """
    synced = 0

    # Keep the try body minimal: only list_integrations() is expected to
    # raise AdenClientError here; per-integration errors are handled below.
    try:
        integrations = self._client.list_integrations()
    except AdenClientError as e:
        # Lazy %-style args avoid eager string formatting when the
        # logger is disabled for this level.
        logger.error("Failed to list integrations from Aden: %s", e)
        return synced

    for info in integrations:
        if info.status != "active":
            logger.warning(
                "Skipping integration '%s': status=%s",
                info.integration_id,
                info.status,
            )
            continue

        try:
            cred = self.fetch_from_aden(info.integration_id)
            if cred:
                store.save_credential(cred)
                synced += 1
                logger.info("Synced credential '%s' from Aden", info.integration_id)
        except Exception as e:
            # Best-effort sync: log and continue with the next integration.
            logger.warning("Failed to sync '%s': %s", info.integration_id, e)

    return synced
|
||||
|
||||
def report_credential_usage(
    self,
    credential: CredentialObject,
    operation: str,
    status: str = "success",
    metadata: dict | None = None,
) -> None:
    """
    Report credential usage to the Aden server.

    A no-op when usage reporting is disabled for this provider.

    Args:
        credential: The credential that was used.
        operation: Operation name (e.g., 'api_call').
        status: Operation status ('success', 'error').
        metadata: Additional metadata.
    """
    if not self._report_usage:
        return

    self._client.report_usage(
        integration_id=credential.id,
        operation=operation,
        status=status,
        metadata=metadata or {},
    )
|
||||
|
||||
def _update_credential_from_aden(
    self,
    credential: CredentialObject,
    aden_response: AdenCredentialResponse,
) -> CredentialObject:
    """Update *credential* in place from an Aden response and return it."""
    keys = credential.keys

    # Fresh access token, carrying its expiry (if any).
    keys["access_token"] = CredentialKey(
        name="access_token",
        value=SecretStr(aden_response.access_token),
        expires_at=aden_response.expires_at,
    )

    # Bookkeeping entries: flag the credential as Aden-managed and
    # remember which integration type it belongs to.
    keys["_aden_managed"] = CredentialKey(
        name="_aden_managed",
        value=SecretStr("true"),
    )
    keys["_integration_type"] = CredentialKey(
        name="_integration_type",
        value=SecretStr(aden_response.integration_type),
    )

    # Scopes are optional; store them space-separated when present.
    if aden_response.scopes:
        keys["scope"] = CredentialKey(
            name="scope",
            value=SecretStr(" ".join(aden_response.scopes)),
        )

    # Record when and by whom the credential was refreshed.
    credential.last_refreshed = datetime.now(UTC)
    credential.provider_id = self.provider_id

    return credential
|
||||
|
||||
def _aden_response_to_credential(
    self,
    aden_response: AdenCredentialResponse,
) -> CredentialObject:
    """Build a new CredentialObject from an Aden response."""
    keys: dict[str, CredentialKey] = {}

    keys["access_token"] = CredentialKey(
        name="access_token",
        value=SecretStr(aden_response.access_token),
        expires_at=aden_response.expires_at,
    )
    keys["_aden_managed"] = CredentialKey(
        name="_aden_managed",
        value=SecretStr("true"),
    )
    keys["_integration_type"] = CredentialKey(
        name="_integration_type",
        value=SecretStr(aden_response.integration_type),
    )

    # Scopes are optional; record them space-separated when present.
    if aden_response.scopes:
        keys["scope"] = CredentialKey(
            name="scope",
            value=SecretStr(" ".join(aden_response.scopes)),
        )

    return CredentialObject(
        id=aden_response.integration_id,
        credential_type=CredentialType.OAUTH2,
        keys=keys,
        provider_id=self.provider_id,
        auto_refresh=True,
    )
|
||||
@@ -0,0 +1,307 @@
|
||||
"""
|
||||
Aden Cached Storage.
|
||||
|
||||
Storage backend that combines local cache with Aden server fallback.
|
||||
Provides offline resilience by caching credentials locally while
|
||||
keeping them synchronized with the Aden server.
|
||||
|
||||
Usage:
|
||||
from core.framework.credentials import CredentialStore
|
||||
from core.framework.credentials.storage import EncryptedFileStorage
|
||||
from core.framework.credentials.aden import (
|
||||
AdenCredentialClient,
|
||||
AdenClientConfig,
|
||||
AdenSyncProvider,
|
||||
AdenCachedStorage,
|
||||
)
|
||||
|
||||
# Configure
|
||||
client = AdenCredentialClient(AdenClientConfig(
|
||||
base_url=os.environ["ADEN_API_URL"],
|
||||
api_key=os.environ["ADEN_API_KEY"],
|
||||
))
|
||||
provider = AdenSyncProvider(client=client)
|
||||
|
||||
# Create cached storage
|
||||
storage = AdenCachedStorage(
|
||||
local_storage=EncryptedFileStorage(),
|
||||
aden_provider=provider,
|
||||
cache_ttl_seconds=300, # Re-check Aden every 5 minutes
|
||||
)
|
||||
|
||||
# Create store
|
||||
store = CredentialStore(
|
||||
storage=storage,
|
||||
providers=[provider],
|
||||
auto_refresh=True,
|
||||
)
|
||||
|
||||
# Credentials automatically fetched from Aden on first access
|
||||
# Cached locally for 5 minutes
|
||||
# Falls back to cache if Aden is unreachable
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..storage import CredentialStorage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import CredentialObject
|
||||
from .provider import AdenSyncProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdenCachedStorage(CredentialStorage):
    """
    Storage with local cache and Aden server fallback.

    This storage provides:
    - **Reads**: Try local cache first, fallback to Aden if stale/missing
    - **Writes**: Always write to local cache
    - **Offline resilience**: Uses cached credentials when Aden is unreachable

    The cache TTL determines how long to trust local credentials before
    checking with the Aden server for updates. This balances:
    - Performance (fewer network calls)
    - Freshness (tokens stay current)
    - Resilience (works during brief outages)

    Usage:
        storage = AdenCachedStorage(
            local_storage=EncryptedFileStorage(),
            aden_provider=provider,
            cache_ttl_seconds=300,  # 5 minutes
        )

        store = CredentialStore(
            storage=storage,
            providers=[provider],
        )

        # First access fetches from Aden
        # Subsequent accesses use cache until TTL expires
        token = store.get_key("hubspot", "access_token")
    """

    def __init__(
        self,
        local_storage: CredentialStorage,
        aden_provider: AdenSyncProvider,
        cache_ttl_seconds: int = 300,
        prefer_local: bool = True,
    ):
        """
        Initialize Aden-cached storage.

        Args:
            local_storage: Local storage backend for caching (e.g., EncryptedFileStorage).
            aden_provider: Provider for fetching from Aden server.
            cache_ttl_seconds: How long to trust local cache before checking Aden.
                Default is 300 seconds (5 minutes).
            prefer_local: If True, use local cache when available and fresh.
                If False, always check Aden first.
        """
        self._local = local_storage
        self._aden_provider = aden_provider
        self._cache_ttl = timedelta(seconds=cache_ttl_seconds)
        self._prefer_local = prefer_local
        # When each credential was last written locally; drives TTL checks.
        self._cache_timestamps: dict[str, datetime] = {}

    def save(self, credential: CredentialObject) -> None:
        """
        Save credential to local cache and stamp it as fresh.

        Args:
            credential: The credential to save.
        """
        self._local.save(credential)
        self._cache_timestamps[credential.id] = datetime.now(UTC)
        logger.debug("Cached credential '%s'", credential.id)

    def load(self, credential_id: str) -> CredentialObject | None:
        """
        Load credential from cache, with Aden fallback.

        The loading strategy depends on the `prefer_local` setting:

        If prefer_local=True (default):
            1. Check if local cache exists and is fresh (within TTL)
            2. If fresh, return cached credential
            3. If stale or missing, fetch from Aden
            4. Update local cache with Aden response
            5. If Aden fails, fall back to stale cache

        If prefer_local=False:
            1. Always try to fetch from Aden first
            2. Update local cache with response
            3. Fall back to local cache only if Aden fails

        Args:
            credential_id: The credential identifier.

        Returns:
            CredentialObject if found, None otherwise.
        """
        local_cred = self._local.load(credential_id)

        # Fresh local copy wins when prefer_local is set.
        if self._prefer_local and local_cred and self._is_cache_fresh(credential_id):
            logger.debug("Using cached credential '%s'", credential_id)
            return local_cred

        # Cache is stale/missing (or prefer_local is off): ask Aden.
        try:
            aden_cred = self._aden_provider.fetch_from_aden(credential_id)
            if aden_cred:
                # save() also refreshes the cache timestamp.
                self.save(aden_cred)
                logger.debug("Fetched credential '%s' from Aden", credential_id)
                return aden_cred
        except Exception as e:
            logger.warning("Failed to fetch '%s' from Aden: %s", credential_id, e)

        # Aden failed or returned nothing: fall back to the (possibly stale)
        # local copy, which is None when nothing is cached either.
        if local_cred:
            logger.info("Using stale cached credential '%s'", credential_id)
        return local_cred

    def delete(self, credential_id: str) -> bool:
        """
        Delete credential from local cache.

        Note: This does NOT delete the credential from the Aden server.
        It only removes the local cache entry.

        Args:
            credential_id: The credential identifier.

        Returns:
            True if credential existed and was deleted.
        """
        self._cache_timestamps.pop(credential_id, None)
        return self._local.delete(credential_id)

    def list_all(self) -> list[str]:
        """
        List credentials from local cache.

        Returns:
            List of credential IDs in local cache.
        """
        return self._local.list_all()

    def exists(self, credential_id: str) -> bool:
        """
        Check if credential exists in local cache.

        Args:
            credential_id: The credential identifier.

        Returns:
            True if credential exists locally.
        """
        return self._local.exists(credential_id)

    def _is_cache_fresh(self, credential_id: str) -> bool:
        """
        Check if local cache is still fresh (within TTL).

        Args:
            credential_id: The credential identifier.

        Returns:
            True if cache is fresh, False if stale or not cached.
        """
        cached_at = self._cache_timestamps.get(credential_id)
        if cached_at is None:
            return False
        return datetime.now(UTC) - cached_at < self._cache_ttl

    def invalidate_cache(self, credential_id: str) -> None:
        """
        Invalidate cache for a specific credential.

        The next load() call will fetch from Aden regardless of TTL.

        Args:
            credential_id: The credential identifier.
        """
        self._cache_timestamps.pop(credential_id, None)
        logger.debug("Invalidated cache for '%s'", credential_id)

    def invalidate_all(self) -> None:
        """Invalidate all cache entries."""
        self._cache_timestamps.clear()
        logger.debug("Invalidated all cache entries")

    def sync_all_from_aden(self) -> int:
        """
        Sync all credentials from Aden server to local cache.

        Fetches the list of available integrations from Aden and
        updates the local cache with current tokens. Failures on
        individual integrations are logged and skipped.

        Returns:
            Number of credentials synced.
        """
        synced = 0

        try:
            # NOTE(review): reaches into the provider's private `_client`;
            # consider exposing this via a public provider method.
            integrations = self._aden_provider._client.list_integrations()
        except Exception as e:
            logger.error("Failed to list integrations from Aden: %s", e)
            return synced

        for info in integrations:
            if info.status != "active":
                logger.warning(
                    "Skipping integration '%s': status=%s",
                    info.integration_id,
                    info.status,
                )
                continue

            try:
                cred = self._aden_provider.fetch_from_aden(info.integration_id)
                if cred:
                    self.save(cred)
                    synced += 1
                    logger.info("Synced credential '%s' from Aden", info.integration_id)
            except Exception as e:
                logger.warning("Failed to sync '%s': %s", info.integration_id, e)

        return synced

    def get_cache_info(self) -> dict[str, dict]:
        """
        Get cache status information for all credentials.

        Returns:
            Dict mapping credential_id to cache info (cached_at, is_fresh, ttl_remaining).
        """
        now = datetime.now(UTC)
        info = {}

        for cred_id in self.list_all():
            cached_at = self._cache_timestamps.get(cred_id)
            if cached_at:
                ttl_remaining = (cached_at + self._cache_ttl - now).total_seconds()
                info[cred_id] = {
                    "cached_at": cached_at.isoformat(),
                    "is_fresh": ttl_remaining > 0,
                    "ttl_remaining_seconds": max(0, ttl_remaining),
                }
            else:
                # Present locally but never stamped: treat as not fresh.
                info[cred_id] = {
                    "cached_at": None,
                    "is_fresh": False,
                    "ttl_remaining_seconds": 0,
                }

        return info
|
||||
@@ -0,0 +1 @@
|
||||
"""Tests for Aden credential sync components."""
|
||||
@@ -0,0 +1,670 @@
|
||||
"""
|
||||
Tests for Aden credential sync components.
|
||||
|
||||
Tests cover:
|
||||
- AdenCredentialClient: HTTP client for Aden API
|
||||
- AdenSyncProvider: Provider that syncs with Aden
|
||||
- AdenCachedStorage: Storage with local cache + Aden fallback
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from framework.credentials import (
|
||||
CredentialKey,
|
||||
CredentialObject,
|
||||
CredentialStore,
|
||||
CredentialType,
|
||||
InMemoryStorage,
|
||||
)
|
||||
from framework.credentials.aden import (
|
||||
AdenCachedStorage,
|
||||
AdenClientConfig,
|
||||
AdenClientError,
|
||||
AdenCredentialClient,
|
||||
AdenCredentialResponse,
|
||||
AdenIntegrationInfo,
|
||||
AdenRefreshError,
|
||||
AdenSyncProvider,
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
def aden_config():
    """Aden client config pointing at a fake server, with quick retries."""
    config = AdenClientConfig(
        base_url="https://api.test-aden.com",
        api_key="test-api-key",
        tenant_id="test-tenant",
        timeout=5.0,
        retry_attempts=2,
        retry_delay=0.1,
    )
    return config
|
||||
|
||||
|
||||
@pytest.fixture
def mock_client(aden_config):
    """A Mock standing in for AdenCredentialClient, carrying the test config."""
    fake_client = Mock(spec=AdenCredentialClient)
    fake_client.config = aden_config
    return fake_client
|
||||
|
||||
|
||||
@pytest.fixture
def aden_response():
    """A representative credential payload as returned by the Aden API."""
    expiry = datetime.now(UTC) + timedelta(hours=1)
    return AdenCredentialResponse(
        integration_id="hubspot",
        integration_type="hubspot",
        access_token="test-access-token",
        token_type="Bearer",
        expires_at=expiry,
        scopes=["crm.objects.contacts.read", "crm.objects.contacts.write"],
        metadata={"portal_id": "12345"},
    )
|
||||
|
||||
|
||||
@pytest.fixture
def provider(mock_client):
    """An AdenSyncProvider wired to the mocked client; usage reporting off."""
    sync_provider = AdenSyncProvider(
        client=mock_client,
        provider_id="test_aden",
        refresh_buffer_minutes=5,
        report_usage=False,
    )
    return sync_provider
|
||||
|
||||
|
||||
@pytest.fixture
def local_storage():
    """A fresh in-memory backend acting as the local cache."""
    return InMemoryStorage()
|
||||
|
||||
|
||||
@pytest.fixture
def cached_storage(local_storage, provider):
    """AdenCachedStorage over the in-memory backend (60s TTL, prefer local)."""
    storage = AdenCachedStorage(
        local_storage=local_storage,
        aden_provider=provider,
        cache_ttl_seconds=60,
        prefer_local=True,
    )
    return storage
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# AdenCredentialResponse Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAdenCredentialResponse:
    """Tests for AdenCredentialResponse dataclass."""

    def test_from_dict_basic(self):
        """Minimal payload: optional fields fall back to their defaults."""
        payload = {
            "integration_id": "github",
            "integration_type": "github",
            "access_token": "ghp_xxxxx",
        }

        parsed = AdenCredentialResponse.from_dict(payload)

        assert parsed.integration_id == "github"
        assert parsed.integration_type == "github"
        assert parsed.access_token == "ghp_xxxxx"
        # Fields absent from the payload get defaults.
        assert parsed.token_type == "Bearer"
        assert parsed.expires_at is None
        assert parsed.scopes == []

    def test_from_dict_full(self):
        """Fully-populated payload: every field is carried through."""
        payload = {
            "integration_id": "hubspot",
            "integration_type": "hubspot",
            "access_token": "token123",
            "token_type": "Bearer",
            "expires_at": "2026-01-28T15:30:00Z",
            "scopes": ["read", "write"],
            "metadata": {"key": "value"},
        }

        parsed = AdenCredentialResponse.from_dict(payload)

        assert parsed.integration_id == "hubspot"
        assert parsed.access_token == "token123"
        assert parsed.expires_at is not None
        assert parsed.scopes == ["read", "write"]
        assert parsed.metadata == {"key": "value"}
|
||||
|
||||
|
||||
class TestAdenIntegrationInfo:
    """Tests for AdenIntegrationInfo dataclass."""

    def test_from_dict(self):
        """A full payload round-trips into an AdenIntegrationInfo."""
        payload = {
            "integration_id": "slack",
            "integration_type": "slack",
            "status": "active",
            "expires_at": "2026-02-01T00:00:00Z",
        }

        parsed = AdenIntegrationInfo.from_dict(payload)

        assert parsed.integration_id == "slack"
        assert parsed.integration_type == "slack"
        assert parsed.status == "active"
        assert parsed.expires_at is not None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# AdenSyncProvider Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAdenSyncProvider:
    """Tests for AdenSyncProvider."""

    def test_provider_id(self, provider):
        """Test provider ID."""
        assert provider.provider_id == "test_aden"

    def test_supported_types(self, provider):
        """Test supported credential types."""
        assert CredentialType.OAUTH2 in provider.supported_types
        assert CredentialType.BEARER_TOKEN in provider.supported_types

    def test_can_handle_oauth2(self, provider):
        """Test can_handle returns True for OAUTH2 credentials with matching provider_id."""
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={},
            provider_id="test_aden",
        )

        assert provider.can_handle(cred) is True

    def test_can_handle_aden_managed(self, provider):
        """Test can_handle returns True for Aden-managed credentials."""
        # No provider_id here: the "_aden_managed" marker key alone
        # should be enough for the provider to claim the credential.
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={
                "_aden_managed": CredentialKey(
                    name="_aden_managed",
                    value=SecretStr("true"),
                )
            },
        )

        assert provider.can_handle(cred) is True

    def test_can_handle_wrong_type(self, provider):
        """Test can_handle returns False for unsupported types."""
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.API_KEY,
            keys={},
        )

        assert provider.can_handle(cred) is False

    def test_refresh_success(self, provider, mock_client, aden_response):
        """Test successful credential refresh."""
        mock_client.request_refresh.return_value = aden_response

        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("old-token"),
                )
            },
            provider_id="test_aden",
        )

        refreshed = provider.refresh(cred)

        # Old token replaced by the one from the (mocked) Aden response.
        assert refreshed.keys["access_token"].value.get_secret_value() == "test-access-token"
        assert refreshed.keys["_aden_managed"].value.get_secret_value() == "true"
        assert refreshed.last_refreshed is not None
        mock_client.request_refresh.assert_called_once_with("hubspot")

    def test_refresh_requires_reauth(self, provider, mock_client):
        """Test refresh that requires re-authorization."""
        mock_client.request_refresh.side_effect = AdenRefreshError(
            "Token revoked",
            requires_reauthorization=True,
            reauthorization_url="https://aden.com/reauth",
        )

        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={},
        )

        from framework.credentials import CredentialRefreshError

        # AdenRefreshError is expected to surface as CredentialRefreshError.
        with pytest.raises(CredentialRefreshError) as exc_info:
            provider.refresh(cred)

        assert "re-authorization" in str(exc_info.value).lower()

    def test_refresh_aden_unavailable_cached_valid(self, provider, mock_client):
        """Test refresh falls back to cache when Aden is unavailable and token is valid."""
        mock_client.request_refresh.side_effect = AdenClientError("Connection failed")

        # Token expires in 1 hour - still valid
        future = datetime.now(UTC) + timedelta(hours=1)
        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("cached-token"),
                    expires_at=future,
                )
            },
        )

        # Should return the cached credential instead of failing
        result = provider.refresh(cred)

        assert result.keys["access_token"].value.get_secret_value() == "cached-token"

    def test_should_refresh_expired(self, provider):
        """Test should_refresh returns True for expired token."""
        past = datetime.now(UTC) - timedelta(hours=1)
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("token"),
                    expires_at=past,
                )
            },
        )

        assert provider.should_refresh(cred) is True

    def test_should_refresh_within_buffer(self, provider):
        """Test should_refresh returns True when within buffer."""
        # Expires in 3 minutes (buffer is 5 minutes)
        soon = datetime.now(UTC) + timedelta(minutes=3)
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("token"),
                    expires_at=soon,
                )
            },
        )

        assert provider.should_refresh(cred) is True

    def test_should_refresh_still_valid(self, provider):
        """Test should_refresh returns False for valid token."""
        future = datetime.now(UTC) + timedelta(hours=1)
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("token"),
                    expires_at=future,
                )
            },
        )

        assert provider.should_refresh(cred) is False

    def test_fetch_from_aden(self, provider, mock_client, aden_response):
        """Test fetching credential from Aden."""
        mock_client.get_credential.return_value = aden_response

        cred = provider.fetch_from_aden("hubspot")

        assert cred is not None
        assert cred.id == "hubspot"
        assert cred.keys["access_token"].value.get_secret_value() == "test-access-token"
        assert cred.auto_refresh is True

    def test_fetch_from_aden_not_found(self, provider, mock_client):
        """Test fetch returns None when not found."""
        mock_client.get_credential.return_value = None

        cred = provider.fetch_from_aden("nonexistent")

        assert cred is None

    def test_sync_all(self, provider, mock_client, aden_response):
        """Test syncing all credentials."""
        mock_client.list_integrations.return_value = [
            AdenIntegrationInfo(
                integration_id="hubspot",
                integration_type="hubspot",
                status="active",
            ),
            AdenIntegrationInfo(
                integration_id="github",
                integration_type="github",
                status="requires_reauth",  # Should be skipped
            ),
        ]
        mock_client.get_credential.return_value = aden_response

        store = CredentialStore(storage=InMemoryStorage())
        synced = provider.sync_all(store)

        assert synced == 1  # Only active one was synced
        assert store.get_credential("hubspot") is not None

    def test_validate_via_aden(self, provider, mock_client):
        """Test validation via Aden introspection."""
        mock_client.validate_token.return_value = {"valid": True}

        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={},
        )

        assert provider.validate(cred) is True

    def test_validate_fallback_to_local(self, provider, mock_client):
        """Test validation falls back to local check when Aden fails."""
        mock_client.validate_token.side_effect = AdenClientError("Failed")

        # Unexpired token: the local expiry check should report valid.
        future = datetime.now(UTC) + timedelta(hours=1)
        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("token"),
                    expires_at=future,
                )
            },
        )

        assert provider.validate(cred) is True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# AdenCachedStorage Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAdenCachedStorage:
    """Tests for AdenCachedStorage."""

    def test_save_updates_cache_timestamp(self, cached_storage):
        """Test save updates cache timestamp."""
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("token"),
                )
            },
        )

        cached_storage.save(cred)

        # save() must both persist the credential and stamp it as fresh.
        assert "test" in cached_storage._cache_timestamps
        assert cached_storage.exists("test")

    def test_load_from_fresh_cache(self, cached_storage, local_storage):
        """Test load returns cached credential when fresh."""
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("cached-token"),
                )
            },
        )

        # Save to both local storage and update timestamp
        local_storage.save(cred)
        cached_storage._cache_timestamps["test"] = datetime.now(UTC)

        loaded = cached_storage.load("test")

        assert loaded is not None
        assert loaded.keys["access_token"].value.get_secret_value() == "cached-token"

    def test_load_from_aden_when_stale(
        self, cached_storage, local_storage, provider, mock_client, aden_response
    ):
        """Test load fetches from Aden when cache is stale."""
        # Create stale cached credential
        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("stale-token"),
                )
            },
        )
        local_storage.save(cred)

        # Set cache timestamp to be stale (2 minutes ago, TTL is 60 seconds)
        cached_storage._cache_timestamps["hubspot"] = datetime.now(UTC) - timedelta(minutes=2)

        # Mock Aden response
        mock_client.get_credential.return_value = aden_response

        loaded = cached_storage.load("hubspot")

        # The stale local token is replaced by the fresh one from Aden.
        assert loaded is not None
        assert loaded.keys["access_token"].value.get_secret_value() == "test-access-token"

    def test_load_falls_back_to_stale_when_aden_fails(
        self, cached_storage, local_storage, provider, mock_client
    ):
        """Test load falls back to stale cache when Aden fails."""
        # Create stale cached credential
        cred = CredentialObject(
            id="hubspot",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr("stale-token"),
                )
            },
        )
        local_storage.save(cred)
        cached_storage._cache_timestamps["hubspot"] = datetime.now(UTC) - timedelta(minutes=2)

        # Aden fails
        mock_client.get_credential.side_effect = AdenClientError("Connection failed")

        loaded = cached_storage.load("hubspot")

        # Offline resilience: the stale local copy is still returned.
        assert loaded is not None
        assert loaded.keys["access_token"].value.get_secret_value() == "stale-token"

    def test_delete_removes_cache_timestamp(self, cached_storage, local_storage):
        """Test delete removes cache timestamp."""
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={},
        )
        cached_storage.save(cred)

        assert "test" in cached_storage._cache_timestamps

        cached_storage.delete("test")

        assert "test" not in cached_storage._cache_timestamps
        assert not cached_storage.exists("test")

    def test_invalidate_cache(self, cached_storage, local_storage):
        """Test invalidate_cache removes timestamp."""
        cred = CredentialObject(
            id="test",
            credential_type=CredentialType.OAUTH2,
            keys={},
        )
        cached_storage.save(cred)

        cached_storage.invalidate_cache("test")

        assert "test" not in cached_storage._cache_timestamps
        # Credential still exists in local storage
        assert local_storage.exists("test")

    def test_invalidate_all(self, cached_storage):
        """Test invalidate_all clears all timestamps."""
        for i in range(3):
            cached_storage._cache_timestamps[f"test_{i}"] = datetime.now(UTC)

        cached_storage.invalidate_all()

        assert len(cached_storage._cache_timestamps) == 0

    def test_is_cache_fresh(self, cached_storage):
        """Test _is_cache_fresh logic."""
        # Fresh cache
        cached_storage._cache_timestamps["fresh"] = datetime.now(UTC)
        assert cached_storage._is_cache_fresh("fresh") is True

        # Stale cache (fixture TTL is 60 seconds)
        cached_storage._cache_timestamps["stale"] = datetime.now(UTC) - timedelta(minutes=5)
        assert cached_storage._is_cache_fresh("stale") is False

        # No cache
        assert cached_storage._is_cache_fresh("nonexistent") is False

    def test_get_cache_info(self, cached_storage, local_storage):
        """Test get_cache_info returns status for all credentials."""
        # Add some credentials
        for name in ["fresh", "stale"]:
            cred = CredentialObject(
                id=name,
                credential_type=CredentialType.OAUTH2,
                keys={},
            )
            local_storage.save(cred)

        cached_storage._cache_timestamps["fresh"] = datetime.now(UTC)
        cached_storage._cache_timestamps["stale"] = datetime.now(UTC) - timedelta(minutes=5)

        info = cached_storage.get_cache_info()

        assert "fresh" in info
        assert info["fresh"]["is_fresh"] is True
        assert info["fresh"]["ttl_remaining_seconds"] > 0

        assert "stale" in info
        assert info["stale"]["is_fresh"] is False
        assert info["stale"]["ttl_remaining_seconds"] == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAdenIntegration:
|
||||
"""Integration tests for Aden sync components."""
|
||||
|
||||
def test_full_workflow(self, mock_client, aden_response):
|
||||
"""Test full workflow: sync, get, refresh."""
|
||||
# Setup
|
||||
mock_client.list_integrations.return_value = [
|
||||
AdenIntegrationInfo(
|
||||
integration_id="hubspot",
|
||||
integration_type="hubspot",
|
||||
status="active",
|
||||
),
|
||||
]
|
||||
mock_client.get_credential.return_value = aden_response
|
||||
mock_client.request_refresh.return_value = AdenCredentialResponse(
|
||||
integration_id="hubspot",
|
||||
integration_type="hubspot",
|
||||
access_token="refreshed-token",
|
||||
expires_at=datetime.now(UTC) + timedelta(hours=2),
|
||||
scopes=[],
|
||||
)
|
||||
|
||||
provider = AdenSyncProvider(client=mock_client)
|
||||
storage = InMemoryStorage()
|
||||
store = CredentialStore(
|
||||
storage=storage,
|
||||
providers=[provider],
|
||||
auto_refresh=True,
|
||||
)
|
||||
|
||||
# Initial sync
|
||||
synced = provider.sync_all(store)
|
||||
assert synced == 1
|
||||
|
||||
# Get credential
|
||||
cred = store.get_credential("hubspot")
|
||||
assert cred is not None
|
||||
assert cred.keys["access_token"].value.get_secret_value() == "test-access-token"
|
||||
|
||||
# Simulate expiration
|
||||
cred.keys["access_token"] = CredentialKey(
|
||||
name="access_token",
|
||||
value=SecretStr("test-access-token"),
|
||||
expires_at=datetime.now(UTC) - timedelta(hours=1), # Expired
|
||||
)
|
||||
storage.save(cred)
|
||||
|
||||
# Refresh should be triggered
|
||||
refreshed = provider.refresh(cred)
|
||||
assert refreshed.keys["access_token"].value.get_secret_value() == "refreshed-token"
|
||||
|
||||
def test_cached_storage_with_store(self, mock_client, aden_response):
|
||||
"""Test AdenCachedStorage with CredentialStore."""
|
||||
mock_client.get_credential.return_value = aden_response
|
||||
|
||||
provider = AdenSyncProvider(client=mock_client)
|
||||
local_storage = InMemoryStorage()
|
||||
cached_storage = AdenCachedStorage(
|
||||
local_storage=local_storage,
|
||||
aden_provider=provider,
|
||||
cache_ttl_seconds=300,
|
||||
)
|
||||
|
||||
# First load fetches from Aden
|
||||
cred = cached_storage.load("hubspot")
|
||||
assert cred is not None
|
||||
mock_client.get_credential.assert_called_once()
|
||||
|
||||
# Second load uses cache
|
||||
mock_client.get_credential.reset_mock()
|
||||
cred2 = cached_storage.load("hubspot")
|
||||
assert cred2 is not None
|
||||
mock_client.get_credential.assert_not_called()
|
||||
@@ -0,0 +1,293 @@
|
||||
"""
|
||||
Core data models for the credential store.
|
||||
|
||||
This module defines the key-vault structure where credentials are objects
|
||||
containing one or more keys (e.g., api_key, access_token, refresh_token).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class CredentialType(str, Enum):
|
||||
"""Types of credentials the store can manage."""
|
||||
|
||||
API_KEY = "api_key"
|
||||
"""Simple API key (e.g., Brave Search, OpenAI)"""
|
||||
|
||||
OAUTH2 = "oauth2"
|
||||
"""OAuth2 with refresh token support"""
|
||||
|
||||
BASIC_AUTH = "basic_auth"
|
||||
"""Username/password pair"""
|
||||
|
||||
BEARER_TOKEN = "bearer_token"
|
||||
"""JWT or bearer token without refresh"""
|
||||
|
||||
CUSTOM = "custom"
|
||||
"""User-defined credential type"""
|
||||
|
||||
|
||||
class CredentialKey(BaseModel):
|
||||
"""
|
||||
A single key within a credential object.
|
||||
|
||||
Example: 'api_key' within a 'brave_search' credential
|
||||
|
||||
Attributes:
|
||||
name: Key name (e.g., 'api_key', 'access_token')
|
||||
value: Secret value (SecretStr prevents accidental logging)
|
||||
expires_at: Optional expiration time
|
||||
metadata: Additional key-specific metadata
|
||||
"""
|
||||
|
||||
name: str
|
||||
value: SecretStr
|
||||
expires_at: datetime | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if this key has expired."""
|
||||
if self.expires_at is None:
|
||||
return False
|
||||
return datetime.now(UTC) >= self.expires_at
|
||||
|
||||
def get_secret_value(self) -> str:
|
||||
"""Get the actual secret value (use sparingly)."""
|
||||
return self.value.get_secret_value()
|
||||
|
||||
|
||||
class CredentialObject(BaseModel):
|
||||
"""
|
||||
A credential object containing one or more keys.
|
||||
|
||||
This is the key-vault structure where each credential can have
|
||||
multiple keys (e.g., access_token, refresh_token, expires_at).
|
||||
|
||||
Example:
|
||||
CredentialObject(
|
||||
id="github_oauth",
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
keys={
|
||||
"access_token": CredentialKey(name="access_token", value=SecretStr("ghp_xxx")),
|
||||
"refresh_token": CredentialKey(name="refresh_token", value=SecretStr("ghr_xxx")),
|
||||
},
|
||||
provider_id="oauth2"
|
||||
)
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier (e.g., 'brave_search', 'github_oauth')
|
||||
credential_type: Type of credential (API_KEY, OAUTH2, etc.)
|
||||
keys: Dictionary of key name to CredentialKey
|
||||
provider_id: ID of provider responsible for lifecycle management
|
||||
auto_refresh: Whether to automatically refresh when expired
|
||||
"""
|
||||
|
||||
id: str = Field(description="Unique identifier (e.g., 'brave_search', 'github_oauth')")
|
||||
credential_type: CredentialType = CredentialType.API_KEY
|
||||
keys: dict[str, CredentialKey] = Field(default_factory=dict)
|
||||
|
||||
# Lifecycle management
|
||||
provider_id: str | None = Field(
|
||||
default=None,
|
||||
description="ID of provider responsible for lifecycle (e.g., 'oauth2', 'static')",
|
||||
)
|
||||
last_refreshed: datetime | None = None
|
||||
auto_refresh: bool = False
|
||||
|
||||
# Usage tracking
|
||||
last_used: datetime | None = None
|
||||
use_count: int = 0
|
||||
|
||||
# Metadata
|
||||
description: str = ""
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=_utc_now)
|
||||
updated_at: datetime = Field(default_factory=_utc_now)
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
def get_key(self, key_name: str) -> str | None:
|
||||
"""
|
||||
Get a specific key's value.
|
||||
|
||||
Args:
|
||||
key_name: Name of the key to retrieve
|
||||
|
||||
Returns:
|
||||
The key's secret value, or None if not found
|
||||
"""
|
||||
key = self.keys.get(key_name)
|
||||
if key is None:
|
||||
return None
|
||||
return key.get_secret_value()
|
||||
|
||||
def set_key(
|
||||
self,
|
||||
key_name: str,
|
||||
value: str,
|
||||
expires_at: datetime | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Set or update a key.
|
||||
|
||||
Args:
|
||||
key_name: Name of the key
|
||||
value: Secret value
|
||||
expires_at: Optional expiration time
|
||||
metadata: Optional key-specific metadata
|
||||
"""
|
||||
self.keys[key_name] = CredentialKey(
|
||||
name=key_name,
|
||||
value=SecretStr(value),
|
||||
expires_at=expires_at,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
self.updated_at = datetime.now(UTC)
|
||||
|
||||
def has_key(self, key_name: str) -> bool:
|
||||
"""Check if a key exists."""
|
||||
return key_name in self.keys
|
||||
|
||||
@property
|
||||
def needs_refresh(self) -> bool:
|
||||
"""Check if any key is expired or near expiration."""
|
||||
for key in self.keys.values():
|
||||
if key.is_expired:
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if credential has at least one non-expired key."""
|
||||
if not self.keys:
|
||||
return False
|
||||
return not all(key.is_expired for key in self.keys.values())
|
||||
|
||||
def record_usage(self) -> None:
|
||||
"""Record that this credential was used."""
|
||||
self.last_used = datetime.now(UTC)
|
||||
self.use_count += 1
|
||||
|
||||
def get_default_key(self) -> str | None:
|
||||
"""
|
||||
Get the default key value.
|
||||
|
||||
Priority: 'value' > 'api_key' > 'access_token' > first key
|
||||
|
||||
Returns:
|
||||
The default key's value, or None if no keys exist
|
||||
"""
|
||||
for key_name in ["value", "api_key", "access_token"]:
|
||||
if key_name in self.keys:
|
||||
return self.get_key(key_name)
|
||||
|
||||
if self.keys:
|
||||
first_key = next(iter(self.keys))
|
||||
return self.get_key(first_key)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class CredentialUsageSpec(BaseModel):
|
||||
"""
|
||||
Specification for how a tool uses credentials.
|
||||
|
||||
This implements the "bipartisan" model where the credential store
|
||||
just stores values, and tools define how those values are used
|
||||
in HTTP requests (headers, query params, body).
|
||||
|
||||
Example:
|
||||
CredentialUsageSpec(
|
||||
credential_id="brave_search",
|
||||
required_keys=["api_key"],
|
||||
headers={"X-Subscription-Token": "{{api_key}}"}
|
||||
)
|
||||
|
||||
CredentialUsageSpec(
|
||||
credential_id="github_oauth",
|
||||
required_keys=["access_token"],
|
||||
headers={"Authorization": "Bearer {{access_token}}"}
|
||||
)
|
||||
|
||||
Attributes:
|
||||
credential_id: ID of credential to use
|
||||
required_keys: Keys that must be present
|
||||
headers: Header templates with {{key}} placeholders
|
||||
query_params: Query parameter templates
|
||||
body_fields: Request body field templates
|
||||
"""
|
||||
|
||||
credential_id: str = Field(description="ID of credential to use (e.g., 'brave_search')")
|
||||
required_keys: list[str] = Field(default_factory=list, description="Keys that must be present")
|
||||
|
||||
# Injection templates (bipartisan model)
|
||||
headers: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Header templates (e.g., {'Authorization': 'Bearer {{access_token}}'})",
|
||||
)
|
||||
query_params: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Query param templates (e.g., {'api_key': '{{api_key}}'})",
|
||||
)
|
||||
body_fields: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Request body field templates",
|
||||
)
|
||||
|
||||
# Metadata
|
||||
required: bool = True
|
||||
description: str = ""
|
||||
help_url: str = ""
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class CredentialError(Exception):
|
||||
"""Base exception for credential-related errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CredentialNotFoundError(CredentialError):
|
||||
"""Raised when a referenced credential doesn't exist."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CredentialKeyNotFoundError(CredentialError):
|
||||
"""Raised when a referenced key doesn't exist in a credential."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CredentialRefreshError(CredentialError):
|
||||
"""Raised when credential refresh fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CredentialValidationError(CredentialError):
|
||||
"""Raised when credential validation fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CredentialDecryptionError(CredentialError):
|
||||
"""Raised when credential decryption fails."""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
OAuth2 support for the credential store.
|
||||
|
||||
This module provides OAuth2 credential management with:
|
||||
- Token types and configuration (OAuth2Token, OAuth2Config)
|
||||
- Generic OAuth2 provider (BaseOAuth2Provider)
|
||||
- Token lifecycle management (TokenLifecycleManager)
|
||||
|
||||
Quick Start:
|
||||
from core.framework.credentials import CredentialStore
|
||||
from core.framework.credentials.oauth2 import BaseOAuth2Provider, OAuth2Config
|
||||
|
||||
# Configure OAuth2 provider
|
||||
provider = BaseOAuth2Provider(OAuth2Config(
|
||||
token_url="https://oauth2.example.com/token",
|
||||
client_id="your-client-id",
|
||||
client_secret="your-client-secret",
|
||||
default_scopes=["read", "write"],
|
||||
))
|
||||
|
||||
# Create store with OAuth2 provider
|
||||
store = CredentialStore.with_encrypted_storage(
|
||||
"/var/hive/credentials",
|
||||
providers=[provider]
|
||||
)
|
||||
|
||||
# Get token using client credentials
|
||||
token = provider.client_credentials_grant()
|
||||
|
||||
# Save to store
|
||||
from core.framework.credentials import CredentialObject, CredentialKey, CredentialType
|
||||
from pydantic import SecretStr
|
||||
|
||||
store.save_credential(CredentialObject(
|
||||
id="my_api",
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
keys={
|
||||
"access_token": CredentialKey(
|
||||
name="access_token",
|
||||
value=SecretStr(token.access_token),
|
||||
expires_at=token.expires_at,
|
||||
),
|
||||
"refresh_token": CredentialKey(
|
||||
name="refresh_token",
|
||||
value=SecretStr(token.refresh_token),
|
||||
) if token.refresh_token else None,
|
||||
},
|
||||
provider_id="oauth2",
|
||||
auto_refresh=True,
|
||||
))
|
||||
|
||||
For advanced lifecycle management:
|
||||
from core.framework.credentials.oauth2 import TokenLifecycleManager
|
||||
|
||||
manager = TokenLifecycleManager(
|
||||
provider=provider,
|
||||
credential_id="my_api",
|
||||
store=store,
|
||||
)
|
||||
|
||||
# Get valid token (auto-refreshes if needed)
|
||||
token = manager.sync_get_valid_token()
|
||||
headers = manager.get_request_headers()
|
||||
"""
|
||||
|
||||
from .base_provider import BaseOAuth2Provider
|
||||
from .lifecycle import TokenLifecycleManager, TokenRefreshResult
|
||||
from .provider import (
|
||||
OAuth2Config,
|
||||
OAuth2Error,
|
||||
OAuth2Token,
|
||||
RefreshTokenInvalidError,
|
||||
TokenExpiredError,
|
||||
TokenPlacement,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Types
|
||||
"OAuth2Token",
|
||||
"OAuth2Config",
|
||||
"TokenPlacement",
|
||||
# Provider
|
||||
"BaseOAuth2Provider",
|
||||
# Lifecycle
|
||||
"TokenLifecycleManager",
|
||||
"TokenRefreshResult",
|
||||
# Errors
|
||||
"OAuth2Error",
|
||||
"TokenExpiredError",
|
||||
"RefreshTokenInvalidError",
|
||||
]
|
||||
@@ -0,0 +1,486 @@
|
||||
"""
|
||||
Base OAuth2 provider implementation.
|
||||
|
||||
This module provides a generic OAuth2 provider that works with standard
|
||||
OAuth2 servers. OSS users can extend this class for custom providers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from ..models import CredentialObject, CredentialRefreshError, CredentialType
|
||||
from ..provider import CredentialProvider
|
||||
from .provider import (
|
||||
OAuth2Config,
|
||||
OAuth2Error,
|
||||
OAuth2Token,
|
||||
TokenPlacement,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseOAuth2Provider(CredentialProvider):
|
||||
"""
|
||||
Generic OAuth2 provider implementation.
|
||||
|
||||
Works with standard OAuth2 servers (RFC 6749). Override methods for
|
||||
provider-specific behavior.
|
||||
|
||||
Supported grant types:
|
||||
- Client Credentials: For server-to-server authentication
|
||||
- Refresh Token: For refreshing expired access tokens
|
||||
- Authorization Code: For user-authorized access (requires callback handling)
|
||||
|
||||
OSS users can extend this class for custom providers:
|
||||
|
||||
class GitHubOAuth2Provider(BaseOAuth2Provider):
|
||||
def __init__(self, client_id: str, client_secret: str):
|
||||
super().__init__(OAuth2Config(
|
||||
token_url="https://github.com/login/oauth/access_token",
|
||||
authorization_url="https://github.com/login/oauth/authorize",
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
default_scopes=["repo", "user"],
|
||||
))
|
||||
|
||||
def exchange_code(self, code: str, redirect_uri: str, **kwargs) -> OAuth2Token:
|
||||
# GitHub returns data as form-encoded by default
|
||||
# Override to handle this
|
||||
...
|
||||
|
||||
Example usage:
|
||||
provider = BaseOAuth2Provider(OAuth2Config(
|
||||
token_url="https://oauth2.example.com/token",
|
||||
client_id="my-client-id",
|
||||
client_secret="my-client-secret",
|
||||
))
|
||||
|
||||
# Get token using client credentials
|
||||
token = provider.client_credentials_grant()
|
||||
|
||||
# Refresh an expired token
|
||||
new_token = provider.refresh_token(old_token.refresh_token)
|
||||
"""
|
||||
|
||||
def __init__(self, config: OAuth2Config, provider_id: str = "oauth2"):
|
||||
"""
|
||||
Initialize the OAuth2 provider.
|
||||
|
||||
Args:
|
||||
config: OAuth2 configuration
|
||||
provider_id: Unique identifier for this provider instance
|
||||
"""
|
||||
self.config = config
|
||||
self._provider_id = provider_id
|
||||
self._client: Any | None = None
|
||||
|
||||
@property
|
||||
def provider_id(self) -> str:
|
||||
return self._provider_id
|
||||
|
||||
@property
|
||||
def supported_types(self) -> list[CredentialType]:
|
||||
return [CredentialType.OAUTH2, CredentialType.BEARER_TOKEN]
|
||||
|
||||
def _get_client(self) -> Any:
|
||||
"""Get or create HTTP client."""
|
||||
if self._client is None:
|
||||
try:
|
||||
import httpx
|
||||
|
||||
self._client = httpx.Client(timeout=self.config.request_timeout)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"OAuth2 provider requires 'httpx'. Install with: pip install httpx"
|
||||
) from e
|
||||
return self._client
|
||||
|
||||
def _close_client(self) -> None:
|
||||
"""Close the HTTP client."""
|
||||
if self._client is not None:
|
||||
self._client.close()
|
||||
self._client = None
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Cleanup HTTP client on deletion."""
|
||||
self._close_client()
|
||||
|
||||
# --- Grant Types ---
|
||||
|
||||
def get_authorization_url(
|
||||
self,
|
||||
state: str,
|
||||
redirect_uri: str,
|
||||
scopes: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Generate authorization URL for user consent (Authorization Code flow).
|
||||
|
||||
Args:
|
||||
state: Anti-CSRF state parameter (should be random and verified)
|
||||
redirect_uri: Callback URL to receive the authorization code
|
||||
scopes: Requested scopes (defaults to config.default_scopes)
|
||||
**kwargs: Additional provider-specific parameters
|
||||
|
||||
Returns:
|
||||
URL to redirect user for authorization
|
||||
|
||||
Raises:
|
||||
ValueError: If authorization_url is not configured
|
||||
"""
|
||||
if not self.config.authorization_url:
|
||||
raise ValueError("authorization_url not configured for this provider")
|
||||
|
||||
params = {
|
||||
"client_id": self.config.client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"state": state,
|
||||
"scope": " ".join(scopes or self.config.default_scopes),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
return f"{self.config.authorization_url}?{urlencode(params)}"
|
||||
|
||||
def exchange_code(
|
||||
self,
|
||||
code: str,
|
||||
redirect_uri: str,
|
||||
**kwargs: Any,
|
||||
) -> OAuth2Token:
|
||||
"""
|
||||
Exchange authorization code for tokens (Authorization Code flow).
|
||||
|
||||
Args:
|
||||
code: Authorization code from callback
|
||||
redirect_uri: Same redirect_uri used in authorization request
|
||||
**kwargs: Additional provider-specific parameters
|
||||
|
||||
Returns:
|
||||
OAuth2Token with access_token and optional refresh_token
|
||||
|
||||
Raises:
|
||||
OAuth2Error: If token exchange fails
|
||||
"""
|
||||
data = {
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": self.config.client_id,
|
||||
"client_secret": self.config.client_secret,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
**self.config.extra_token_params,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
return self._token_request(data)
|
||||
|
||||
def client_credentials_grant(
|
||||
self,
|
||||
scopes: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> OAuth2Token:
|
||||
"""
|
||||
Obtain token using client credentials (Client Credentials flow).
|
||||
|
||||
This is for server-to-server authentication where no user is involved.
|
||||
|
||||
Args:
|
||||
scopes: Requested scopes (defaults to config.default_scopes)
|
||||
**kwargs: Additional provider-specific parameters
|
||||
|
||||
Returns:
|
||||
OAuth2Token (typically without refresh_token)
|
||||
|
||||
Raises:
|
||||
OAuth2Error: If token request fails
|
||||
"""
|
||||
data = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": self.config.client_id,
|
||||
"client_secret": self.config.client_secret,
|
||||
**self.config.extra_token_params,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
if scopes or self.config.default_scopes:
|
||||
data["scope"] = " ".join(scopes or self.config.default_scopes)
|
||||
|
||||
return self._token_request(data)
|
||||
|
||||
def refresh_access_token(
|
||||
self,
|
||||
refresh_token: str,
|
||||
scopes: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> OAuth2Token:
|
||||
"""
|
||||
Refresh an expired access token (Refresh Token flow).
|
||||
|
||||
Args:
|
||||
refresh_token: The refresh token
|
||||
scopes: Scopes to request (defaults to original scopes)
|
||||
**kwargs: Additional provider-specific parameters
|
||||
|
||||
Returns:
|
||||
New OAuth2Token (may include new refresh_token)
|
||||
|
||||
Raises:
|
||||
OAuth2Error: If refresh fails
|
||||
RefreshTokenInvalidError: If refresh token is revoked/invalid
|
||||
"""
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"client_id": self.config.client_id,
|
||||
"client_secret": self.config.client_secret,
|
||||
"refresh_token": refresh_token,
|
||||
**self.config.extra_token_params,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
if scopes:
|
||||
data["scope"] = " ".join(scopes)
|
||||
|
||||
return self._token_request(data)
|
||||
|
||||
def revoke_token(
|
||||
self,
|
||||
token: str,
|
||||
token_type_hint: str = "access_token",
|
||||
) -> bool:
|
||||
"""
|
||||
Revoke a token (RFC 7009).
|
||||
|
||||
Args:
|
||||
token: The token to revoke
|
||||
token_type_hint: "access_token" or "refresh_token"
|
||||
|
||||
Returns:
|
||||
True if revocation succeeded
|
||||
"""
|
||||
if not self.config.revocation_url:
|
||||
logger.warning("revocation_url not configured, cannot revoke token")
|
||||
return False
|
||||
|
||||
try:
|
||||
client = self._get_client()
|
||||
response = client.post(
|
||||
self.config.revocation_url,
|
||||
data={
|
||||
"token": token,
|
||||
"token_type_hint": token_type_hint,
|
||||
"client_id": self.config.client_id,
|
||||
"client_secret": self.config.client_secret,
|
||||
},
|
||||
headers={"Accept": "application/json", **self.config.extra_headers},
|
||||
)
|
||||
# RFC 7009: 200 indicates success (even if token was already invalid)
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
logger.error(f"Token revocation failed: {e}")
|
||||
return False
|
||||
|
||||
# --- CredentialProvider Interface ---
|
||||
|
||||
def refresh(self, credential: CredentialObject) -> CredentialObject:
|
||||
"""
|
||||
Refresh a credential using its refresh token.
|
||||
|
||||
Implements CredentialProvider.refresh().
|
||||
|
||||
Args:
|
||||
credential: The credential to refresh
|
||||
|
||||
Returns:
|
||||
Updated credential with new access_token
|
||||
|
||||
Raises:
|
||||
CredentialRefreshError: If refresh fails
|
||||
"""
|
||||
refresh_tok = credential.get_key("refresh_token")
|
||||
if not refresh_tok:
|
||||
raise CredentialRefreshError(f"Credential '{credential.id}' has no refresh_token")
|
||||
|
||||
try:
|
||||
new_token = self.refresh_access_token(refresh_tok)
|
||||
except OAuth2Error as e:
|
||||
if e.error == "invalid_grant":
|
||||
raise CredentialRefreshError(
|
||||
f"Refresh token for '{credential.id}' is invalid or revoked. "
|
||||
"Re-authorization required."
|
||||
) from e
|
||||
raise CredentialRefreshError(f"Failed to refresh '{credential.id}': {e}") from e
|
||||
|
||||
# Update credential
|
||||
credential.set_key("access_token", new_token.access_token, expires_at=new_token.expires_at)
|
||||
|
||||
# Update refresh token if a new one was issued
|
||||
if new_token.refresh_token and new_token.refresh_token != refresh_tok:
|
||||
credential.set_key("refresh_token", new_token.refresh_token)
|
||||
|
||||
credential.last_refreshed = datetime.now(UTC)
|
||||
logger.info(f"Refreshed OAuth2 credential '{credential.id}'")
|
||||
|
||||
return credential
|
||||
|
||||
def validate(self, credential: CredentialObject) -> bool:
|
||||
"""
|
||||
Validate that credential has a valid (non-expired) access_token.
|
||||
|
||||
Args:
|
||||
credential: The credential to validate
|
||||
|
||||
Returns:
|
||||
True if credential has valid access_token
|
||||
"""
|
||||
access_key = credential.keys.get("access_token")
|
||||
if access_key is None:
|
||||
return False
|
||||
return not access_key.is_expired
|
||||
|
||||
def should_refresh(self, credential: CredentialObject) -> bool:
|
||||
"""
|
||||
Check if credential should be refreshed.
|
||||
|
||||
Returns True if access_token is expired or within 5 minutes of expiry.
|
||||
"""
|
||||
access_key = credential.keys.get("access_token")
|
||||
if access_key is None:
|
||||
return False
|
||||
|
||||
if access_key.expires_at is None:
|
||||
return False
|
||||
|
||||
buffer = timedelta(minutes=5)
|
||||
return datetime.now(UTC) >= (access_key.expires_at - buffer)
|
||||
|
||||
def revoke(self, credential: CredentialObject) -> bool:
|
||||
"""
|
||||
Revoke all tokens in a credential.
|
||||
|
||||
Args:
|
||||
credential: The credential to revoke
|
||||
|
||||
Returns:
|
||||
True if all revocations succeeded
|
||||
"""
|
||||
success = True
|
||||
|
||||
# Revoke access token
|
||||
access_token = credential.get_key("access_token")
|
||||
if access_token:
|
||||
if not self.revoke_token(access_token, "access_token"):
|
||||
success = False
|
||||
|
||||
# Revoke refresh token
|
||||
refresh_token = credential.get_key("refresh_token")
|
||||
if refresh_token:
|
||||
if not self.revoke_token(refresh_token, "refresh_token"):
|
||||
success = False
|
||||
|
||||
return success
|
||||
|
||||
# --- Token Request Helpers ---
|
||||
|
||||
def _token_request(self, data: dict[str, Any]) -> OAuth2Token:
|
||||
"""
|
||||
Make a token request to the OAuth2 server.
|
||||
|
||||
Args:
|
||||
data: Form data for the token request
|
||||
|
||||
Returns:
|
||||
OAuth2Token from the response
|
||||
|
||||
Raises:
|
||||
OAuth2Error: If request fails or returns an error
|
||||
"""
|
||||
client = self._get_client()
|
||||
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
**self.config.extra_headers,
|
||||
}
|
||||
|
||||
response = client.post(self.config.token_url, data=data, headers=headers)
|
||||
|
||||
# Parse response
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "application/json" in content_type:
|
||||
response_data = response.json()
|
||||
else:
|
||||
# Some providers (like GitHub) may return form-encoded
|
||||
response_data = self._parse_form_response(response.text)
|
||||
|
||||
# Check for error
|
||||
if response.status_code != 200 or "error" in response_data:
|
||||
error = response_data.get("error", "unknown_error")
|
||||
description = response_data.get("error_description", response.text)
|
||||
raise OAuth2Error(
|
||||
error=error, description=description, status_code=response.status_code
|
||||
)
|
||||
|
||||
return OAuth2Token.from_token_response(response_data)
|
||||
|
||||
def _parse_form_response(self, text: str) -> dict[str, str]:
|
||||
"""Parse form-encoded response (some providers use this instead of JSON)."""
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
parsed = parse_qs(text)
|
||||
return {k: v[0] if len(v) == 1 else v for k, v in parsed.items()}
|
||||
|
||||
# --- Token Formatting for Requests ---
|
||||
|
||||
def format_for_request(self, token: OAuth2Token) -> dict[str, Any]:
|
||||
"""
|
||||
Format token for use in HTTP requests (bipartisan model).
|
||||
|
||||
Args:
|
||||
token: The OAuth2 token
|
||||
|
||||
Returns:
|
||||
Dict with 'headers', 'params', or 'data' keys as appropriate
|
||||
"""
|
||||
placement = self.config.token_placement
|
||||
|
||||
if placement == TokenPlacement.HEADER_BEARER:
|
||||
return {"headers": {"Authorization": f"{token.token_type} {token.access_token}"}}
|
||||
|
||||
elif placement == TokenPlacement.HEADER_CUSTOM:
|
||||
header_name = self.config.custom_header_name or "X-Access-Token"
|
||||
return {"headers": {header_name: token.access_token}}
|
||||
|
||||
elif placement == TokenPlacement.QUERY_PARAM:
|
||||
return {"params": {self.config.query_param_name: token.access_token}}
|
||||
|
||||
elif placement == TokenPlacement.BODY_PARAM:
|
||||
return {"data": {"access_token": token.access_token}}
|
||||
|
||||
return {}
|
||||
|
||||
def format_credential_for_request(self, credential: CredentialObject) -> dict[str, Any]:
|
||||
"""
|
||||
Format a credential for use in HTTP requests.
|
||||
|
||||
Args:
|
||||
credential: The credential containing access_token
|
||||
|
||||
Returns:
|
||||
Dict with 'headers', 'params', or 'data' keys as appropriate
|
||||
"""
|
||||
access_token = credential.get_key("access_token")
|
||||
if not access_token:
|
||||
return {}
|
||||
|
||||
token = OAuth2Token(
|
||||
access_token=access_token,
|
||||
token_type=credential.keys.get("token_type", "Bearer") or "Bearer",
|
||||
)
|
||||
|
||||
return self.format_for_request(token)
|
||||
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
Token lifecycle management for OAuth2 credentials.
|
||||
|
||||
This module provides the TokenLifecycleManager which coordinates
|
||||
automatic token refresh with the credential store.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from ..models import CredentialKey, CredentialObject, CredentialType
|
||||
from .base_provider import BaseOAuth2Provider
|
||||
from .provider import OAuth2Token
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..store import CredentialStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class TokenRefreshResult:
    """Outcome of a single token refresh attempt."""

    # Whether the refresh completed and produced a usable token.
    success: bool
    # The freshly-obtained token; present only on success.
    token: OAuth2Token | None = None
    # Human-readable failure reason; present only on failure.
    error: str | None = None
    # True when the refresh token itself is rejected and the user must
    # go through the authorization flow again.
    needs_reauthorization: bool = False
|
||||
|
||||
|
||||
class TokenLifecycleManager:
    """
    Manages the complete lifecycle of OAuth2 tokens.

    Responsibilities:
    - Coordinate with CredentialStore for persistence
    - Automatically refresh expired tokens
    - Handle refresh failures gracefully
    - Provide callbacks for monitoring

    This class is useful when you need more control over token management
    than the basic auto-refresh in CredentialStore provides.

    Usage:
        manager = TokenLifecycleManager(
            provider=github_provider,
            credential_id="github_oauth",
            store=credential_store,
        )

        # Get valid token (auto-refreshes if needed)
        token = await manager.get_valid_token()

        # Use token
        headers = provider.format_for_request(token)

    Synchronous usage:
        # For synchronous code, use sync_ methods
        token = manager.sync_get_valid_token()
    """

    def __init__(
        self,
        provider: BaseOAuth2Provider,
        credential_id: str,
        store: CredentialStore,
        refresh_buffer_minutes: int = 5,
        on_token_refreshed: Callable[[OAuth2Token], None] | None = None,
        on_refresh_failed: Callable[[str], None] | None = None,
    ):
        """
        Initialize the lifecycle manager.

        Args:
            provider: OAuth2 provider for token operations
            credential_id: ID of the credential in the store
            store: Credential store for persistence
            refresh_buffer_minutes: Minutes before expiry to trigger refresh
            on_token_refreshed: Callback invoked with the new token after a
                successful refresh
            on_refresh_failed: Callback invoked with the error message when
                a refresh fails (not called for reauthorization errors)
        """
        self.provider = provider
        self.credential_id = credential_id
        self.store = store
        self.refresh_buffer = timedelta(minutes=refresh_buffer_minutes)
        self.on_token_refreshed = on_token_refreshed
        self.on_refresh_failed = on_refresh_failed

        # In-memory cache so repeated calls don't hit the store every time.
        self._cached_token: OAuth2Token | None = None
        self._cache_time: datetime | None = None

    # --- Async Token Access ---

    async def get_valid_token(self) -> OAuth2Token | None:
        """
        Get a valid access token, refreshing if necessary.

        This is the main entry point for async code.

        Returns:
            Valid OAuth2Token or None if unavailable
        """
        # Fast path: cached token still comfortably inside its validity window.
        if self._cached_token and not self._needs_refresh(self._cached_token):
            return self._cached_token

        # Load from store; refresh is coordinated here, not by the store.
        credential = self.store.get_credential(self.credential_id, refresh_if_needed=False)
        if credential is None:
            return None

        # Convert to OAuth2Token
        token = self._credential_to_token(credential)
        if token is None:
            return None

        # Refresh if within the buffer window of expiry.
        if self._needs_refresh(token):
            result = await self._async_refresh_token(credential)
            if result.success and result.token:
                token = result.token
            elif result.needs_reauthorization:
                logger.warning(f"Token for {self.credential_id} needs reauthorization")
                return None
            else:
                # Refresh failed, but because we refresh ahead of actual
                # expiry the existing token may still be usable briefly.
                if token.is_expired:
                    return None
                logger.warning(f"Refresh failed for {self.credential_id}, using existing token")

        self._cached_token = token
        self._cache_time = datetime.now(UTC)
        return token

    async def acquire_token_client_credentials(
        self,
        scopes: list[str] | None = None,
    ) -> OAuth2Token:
        """
        Acquire a new token using client credentials flow.

        For service-to-service authentication.

        Args:
            scopes: Scopes to request

        Returns:
            New OAuth2Token
        """
        # Run the blocking provider call in the default executor.
        # NOTE: get_running_loop() is the correct API inside a coroutine;
        # get_event_loop() is deprecated for this use since Python 3.10.
        loop = asyncio.get_running_loop()
        token = await loop.run_in_executor(
            None, lambda: self.provider.client_credentials_grant(scopes=scopes)
        )

        self._save_token_to_store(token)
        self._cached_token = token
        return token

    async def revoke(self) -> bool:
        """
        Revoke tokens and clear from store.

        Returns:
            True if revocation succeeded
        """
        credential = self.store.get_credential(self.credential_id, refresh_if_needed=False)
        if credential:
            self.provider.revoke(credential)

        self.store.delete_credential(self.credential_id)
        self._cached_token = None
        return True

    # --- Synchronous Token Access ---

    def sync_get_valid_token(self) -> OAuth2Token | None:
        """
        Synchronous version of get_valid_token().

        For use in synchronous code.
        """
        # Check cache
        if self._cached_token and not self._needs_refresh(self._cached_token):
            return self._cached_token

        # Load from store
        credential = self.store.get_credential(self.credential_id, refresh_if_needed=False)
        if credential is None:
            return None

        token = self._credential_to_token(credential)
        if token is None:
            return None

        # Refresh if needed
        if self._needs_refresh(token):
            result = self._sync_refresh_token(credential)
            if result.success and result.token:
                token = result.token
            elif result.needs_reauthorization:
                logger.warning(f"Token for {self.credential_id} needs reauthorization")
                return None
            else:
                # Fall back to the existing token unless it is truly expired.
                if token.is_expired:
                    return None

        self._cached_token = token
        self._cache_time = datetime.now(UTC)
        return token

    def sync_acquire_token_client_credentials(
        self,
        scopes: list[str] | None = None,
    ) -> OAuth2Token:
        """Synchronous version of acquire_token_client_credentials()."""
        token = self.provider.client_credentials_grant(scopes=scopes)
        self._save_token_to_store(token)
        self._cached_token = token
        return token

    # --- Helper Methods ---

    def _needs_refresh(self, token: OAuth2Token) -> bool:
        """Check if token is inside the pre-expiry refresh buffer."""
        if token.expires_at is None:
            # Tokens without an expiry never need refresh.
            return False
        return datetime.now(UTC) >= (token.expires_at - self.refresh_buffer)

    def _credential_to_token(self, credential: CredentialObject) -> OAuth2Token | None:
        """Convert a stored credential to an OAuth2Token, or None if empty."""
        access_token = credential.get_key("access_token")
        if not access_token:
            return None

        # Expiry lives on the access_token key's metadata, if present.
        expires_at = None
        access_key = credential.keys.get("access_token")
        if access_key:
            expires_at = access_key.expires_at

        return OAuth2Token(
            access_token=access_token,
            token_type="Bearer",
            expires_at=expires_at,
            refresh_token=credential.get_key("refresh_token"),
            scope=credential.get_key("scope"),
        )

    def _save_token_to_store(self, token: OAuth2Token) -> None:
        """Persist a token (access/refresh/scope keys) to the credential store."""
        credential = CredentialObject(
            id=self.credential_id,
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(
                    name="access_token",
                    value=SecretStr(token.access_token),
                    expires_at=token.expires_at,
                ),
            },
            provider_id=self.provider.provider_id,
            auto_refresh=True,
        )

        if token.refresh_token:
            credential.keys["refresh_token"] = CredentialKey(
                name="refresh_token",
                value=SecretStr(token.refresh_token),
            )

        if token.scope:
            credential.keys["scope"] = CredentialKey(
                name="scope",
                value=SecretStr(token.scope),
            )

        self.store.save_credential(credential)

    async def _async_refresh_token(self, credential: CredentialObject) -> TokenRefreshResult:
        """Async wrapper for token refresh (runs the sync path in an executor)."""
        # get_running_loop() replaces the deprecated get_event_loop() here.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: self._sync_refresh_token(credential))

    def _sync_refresh_token(self, credential: CredentialObject) -> TokenRefreshResult:
        """Synchronously refresh token.

        Returns:
            TokenRefreshResult describing success, failure, or the need
            for full reauthorization.
        """
        refresh_token = credential.get_key("refresh_token")
        if not refresh_token:
            return TokenRefreshResult(
                success=False,
                error="No refresh token available",
                needs_reauthorization=True,
            )

        try:
            new_token = self.provider.refresh_access_token(refresh_token)

            # Some providers omit refresh_token in refresh responses
            # (it is optional per RFC 6749 §6). Keep the old one so that
            # subsequent refreshes remain possible.
            if not new_token.refresh_token:
                new_token.refresh_token = refresh_token

            # Save to store
            self._save_token_to_store(new_token)

            # Notify callback
            if self.on_token_refreshed:
                self.on_token_refreshed(new_token)

            logger.info(f"Token refreshed for {self.credential_id}")
            return TokenRefreshResult(success=True, token=new_token)

        except Exception as e:
            error_msg = str(e)

            # 'invalid_grant' signals a revoked/expired refresh token:
            # only a full reauthorization can recover.
            if "invalid_grant" in error_msg.lower():
                return TokenRefreshResult(
                    success=False,
                    error=error_msg,
                    needs_reauthorization=True,
                )

            if self.on_refresh_failed:
                self.on_refresh_failed(error_msg)

            logger.error(f"Token refresh failed for {self.credential_id}: {e}")
            return TokenRefreshResult(success=False, error=error_msg)

    def invalidate_cache(self) -> None:
        """Clear cached token so the next access re-reads the store."""
        self._cached_token = None
        self._cache_time = None

    # --- Convenience Methods ---

    def get_request_headers(self) -> dict[str, str]:
        """
        Get headers for HTTP request with current token.

        Returns empty dict if no valid token.
        """
        token = self.sync_get_valid_token()
        if token is None:
            return {}

        result = self.provider.format_for_request(token)
        return result.get("headers", {})

    def get_request_kwargs(self) -> dict:
        """
        Get kwargs for HTTP request (headers, params, etc.).

        Returns empty dict if no valid token.
        """
        token = self.sync_get_valid_token()
        if token is None:
            return {}

        return self.provider.format_for_request(token)
|
||||
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
OAuth2 types and configuration.
|
||||
|
||||
This module defines the core OAuth2 data structures:
|
||||
- OAuth2Token: Represents an access token with metadata
|
||||
- OAuth2Config: Configuration for OAuth2 endpoints
|
||||
- TokenPlacement: Where to place tokens in requests
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class TokenPlacement(str, Enum):
    """Where to place the access token in HTTP requests."""

    # Authorization: Bearer <token> (most common)
    HEADER_BEARER = "header_bearer"
    # Custom header name (e.g., X-Access-Token)
    HEADER_CUSTOM = "header_custom"
    # Query parameter (e.g., ?access_token=<token>)
    QUERY_PARAM = "query_param"
    # Form body parameter
    BODY_PARAM = "body_param"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuth2Token:
|
||||
"""
|
||||
Represents an OAuth2 token with metadata.
|
||||
|
||||
Attributes:
|
||||
access_token: The access token string
|
||||
token_type: Token type (usually "Bearer")
|
||||
expires_at: When the token expires
|
||||
refresh_token: Optional refresh token
|
||||
scope: Granted scopes (space-separated)
|
||||
raw_response: Original token response from server
|
||||
"""
|
||||
|
||||
access_token: str
|
||||
token_type: str = "Bearer"
|
||||
expires_at: datetime | None = None
|
||||
refresh_token: str | None = None
|
||||
scope: str | None = None
|
||||
raw_response: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""
|
||||
Check if token is expired.
|
||||
|
||||
Uses a 5-minute buffer to account for clock skew and
|
||||
request latency.
|
||||
"""
|
||||
if self.expires_at is None:
|
||||
return False
|
||||
buffer = timedelta(minutes=5)
|
||||
return datetime.now(UTC) >= (self.expires_at - buffer)
|
||||
|
||||
@property
|
||||
def can_refresh(self) -> bool:
|
||||
"""Check if token can be refreshed (has refresh_token)."""
|
||||
return self.refresh_token is not None and self.refresh_token.strip() != ""
|
||||
|
||||
@property
|
||||
def expires_in_seconds(self) -> int | None:
|
||||
"""Get seconds until expiration, or None if no expiration."""
|
||||
if self.expires_at is None:
|
||||
return None
|
||||
delta = self.expires_at - datetime.now(UTC)
|
||||
return max(0, int(delta.total_seconds()))
|
||||
|
||||
@classmethod
|
||||
def from_token_response(cls, data: dict[str, Any]) -> OAuth2Token:
|
||||
"""
|
||||
Create OAuth2Token from an OAuth2 token endpoint response.
|
||||
|
||||
Args:
|
||||
data: Token response JSON (access_token, token_type, expires_in, etc.)
|
||||
|
||||
Returns:
|
||||
OAuth2Token instance
|
||||
"""
|
||||
expires_at = None
|
||||
if "expires_in" in data:
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=data["expires_in"])
|
||||
|
||||
return cls(
|
||||
access_token=data["access_token"],
|
||||
token_type=data.get("token_type", "Bearer"),
|
||||
expires_at=expires_at,
|
||||
refresh_token=data.get("refresh_token"),
|
||||
scope=data.get("scope"),
|
||||
raw_response=data,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class OAuth2Config:
    """
    Configuration for an OAuth2 provider.

    Bundles everything needed to run OAuth2 flows against one provider
    (GitHub, Google, Salesforce, ...): endpoint URLs, client credentials,
    default scopes, and how the resulting token is attached to requests.

    Attributes:
        token_url: URL for token endpoint (required)
        authorization_url: URL for authorization endpoint (optional, for auth code flow)
        revocation_url: URL for token revocation (optional)
        introspection_url: URL for token introspection (optional)
        client_id: OAuth2 client ID
        client_secret: OAuth2 client secret
        default_scopes: Default scopes to request
        token_placement: How to include token in requests
        custom_header_name: Header name when using HEADER_CUSTOM placement
        query_param_name: Query param name when using QUERY_PARAM placement
        extra_token_params: Additional parameters for token requests
        request_timeout: Timeout for HTTP requests in seconds
        extra_headers: Additional headers for token requests

    Example:
        config = OAuth2Config(
            token_url="https://github.com/login/oauth/access_token",
            authorization_url="https://github.com/login/oauth/authorize",
            client_id="your-client-id",
            client_secret="your-client-secret",
            default_scopes=["repo", "user"],
        )
    """

    # Endpoints -- only token_url is mandatory.
    token_url: str
    authorization_url: str | None = None
    revocation_url: str | None = None
    introspection_url: str | None = None

    # Client credentials.
    client_id: str = ""
    client_secret: str = ""

    # Scopes requested when the caller specifies none.
    default_scopes: list[str] = field(default_factory=list)

    # Token placement for API calls (bipartisan model).
    token_placement: TokenPlacement = TokenPlacement.HEADER_BEARER
    custom_header_name: str | None = None
    query_param_name: str = "access_token"

    # Request configuration.
    extra_token_params: dict[str, str] = field(default_factory=dict)
    request_timeout: float = 30.0

    # Additional headers for token requests.
    extra_headers: dict[str, str] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Reject configurations that cannot work at request time."""
        if not self.token_url:
            raise ValueError("token_url is required")

        missing_header_name = (
            self.token_placement == TokenPlacement.HEADER_CUSTOM
            and not self.custom_header_name
        )
        if missing_header_name:
            raise ValueError("custom_header_name is required when using HEADER_CUSTOM placement")
|
||||
|
||||
|
||||
class OAuth2Error(Exception):
    """
    Base exception for OAuth2 protocol failures.

    Attributes:
        error: OAuth2 error code (e.g., 'invalid_grant', 'invalid_client')
        description: Human-readable error description
        status_code: HTTP status code from the response
    """

    def __init__(
        self,
        error: str,
        description: str = "",
        status_code: int = 0,
    ):
        self.error = error
        self.description = description
        self.status_code = status_code
        # Exception message mirrors the OAuth2 "error: description" form.
        message = f"{error}: {description}" if description else error
        super().__init__(message)
|
||||
|
||||
|
||||
class TokenExpiredError(OAuth2Error):
    """Raised when a token has expired and cannot be used."""

    def __init__(self, credential_id: str):
        self.credential_id = credential_id
        super().__init__(
            error="token_expired",
            description=f"Token for '{credential_id}' has expired",
        )
|
||||
|
||||
|
||||
class RefreshTokenInvalidError(OAuth2Error):
    """Raised when the refresh token is invalid or revoked."""

    def __init__(self, credential_id: str, reason: str = ""):
        self.credential_id = credential_id
        description = f"Refresh token for '{credential_id}' is invalid"
        if reason:
            description = f"{description}: {reason}"
        super().__init__(error="invalid_grant", description=description)
|
||||
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Provider interface for credential lifecycle management.
|
||||
|
||||
Providers handle credential lifecycle operations:
|
||||
- Refresh: Obtain new tokens when expired
|
||||
- Validate: Check if credentials are still working
|
||||
- Revoke: Invalidate credentials when no longer needed
|
||||
|
||||
OSS users can implement custom providers by subclassing CredentialProvider.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from .models import CredentialObject, CredentialRefreshError, CredentialType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CredentialProvider(ABC):
    """
    Abstract base class for credential providers.

    A provider owns the lifecycle operations for one family of credentials:

    - refresh(): obtain new tokens when expired
    - validate(): check that a credential still works
    - should_refresh(): decide whether a credential needs refresh
    - revoke(): invalidate a credential (optional)

    Example custom provider:
        class MyCustomProvider(CredentialProvider):
            @property
            def provider_id(self) -> str:
                return "my_custom"

            @property
            def supported_types(self) -> list[CredentialType]:
                return [CredentialType.CUSTOM]

            def refresh(self, credential: CredentialObject) -> CredentialObject:
                new_token = my_api.refresh(credential.get_key("api_key"))
                credential.set_key("access_token", new_token)
                return credential

            def validate(self, credential: CredentialObject) -> bool:
                return my_api.validate(credential.get_key("access_token"))
    """

    @property
    @abstractmethod
    def provider_id(self) -> str:
        """Unique identifier for this provider (e.g. 'static', 'oauth2')."""
        ...

    @property
    @abstractmethod
    def supported_types(self) -> list[CredentialType]:
        """The CredentialType values this provider can manage."""
        ...

    @abstractmethod
    def refresh(self, credential: CredentialObject) -> CredentialObject:
        """
        Refresh the credential (e.g., use refresh_token to get new access_token).

        Implementations should use the existing credential data to obtain
        new values, update the credential object (including expiration
        times and last_refreshed), and return it.

        Args:
            credential: The credential to refresh

        Returns:
            Updated credential with new values

        Raises:
            CredentialRefreshError: If refresh fails
        """
        ...

    @abstractmethod
    def validate(self, credential: CredentialObject) -> bool:
        """
        Check whether a credential is still working.

        May inspect expiration times, make a test API call, or verify
        token signatures -- whatever the provider requires.

        Args:
            credential: The credential to validate

        Returns:
            True if credential is valid, False otherwise
        """
        ...

    def should_refresh(self, credential: CredentialObject) -> bool:
        """
        Decide whether a credential should be refreshed.

        Default policy: refresh when any key is expired or expires within
        the next 5 minutes. Override for custom logic.

        Args:
            credential: The credential to check

        Returns:
            True if credential should be refreshed
        """
        cutoff = datetime.now(UTC) + timedelta(minutes=5)
        return any(
            key.expires_at is not None and key.expires_at <= cutoff
            for key in credential.keys.values()
        )

    def revoke(self, credential: CredentialObject) -> bool:
        """
        Revoke a credential (optional operation).

        The default implementation just logs that revocation is
        unsupported and reports failure.

        Args:
            credential: The credential to revoke

        Returns:
            True if revocation succeeded, False otherwise
        """
        logger.warning(f"Provider '{self.provider_id}' does not support revocation")
        return False

    def can_handle(self, credential: CredentialObject) -> bool:
        """
        Check if this provider can manage the given credential.

        Args:
            credential: The credential to check

        Returns:
            True if the credential's type is among supported_types
        """
        return credential.credential_type in self.supported_types
|
||||
|
||||
|
||||
class StaticProvider(CredentialProvider):
    """
    Provider for static credentials that never need refresh.

    Suited to simple API keys that don't expire, such as:
    - Brave Search API key
    - OpenAI API key
    - Basic auth credentials

    A static credential is considered valid whenever it holds at least
    one non-empty key.
    """

    @property
    def provider_id(self) -> str:
        return "static"

    @property
    def supported_types(self) -> list[CredentialType]:
        return [CredentialType.API_KEY, CredentialType.BASIC_AUTH, CredentialType.CUSTOM]

    def refresh(self, credential: CredentialObject) -> CredentialObject:
        """No-op: static credentials have nothing to refresh."""
        logger.debug(f"Static credential '{credential.id}' does not need refresh")
        return credential

    def validate(self, credential: CredentialObject) -> bool:
        """
        Check that at least one key holds a non-blank secret value.

        Without making an API call we cannot verify the key actually
        works, so existence of a value is the whole check.
        """
        for key in credential.keys.values():
            try:
                value = key.get_secret_value()
            except Exception:
                # Unreadable key material -- move on to the next key.
                continue
            if value and value.strip():
                return True
        return False

    def should_refresh(self, credential: CredentialObject) -> bool:
        """Static credentials never need refresh."""
        return False
|
||||
|
||||
|
||||
class BearerTokenProvider(CredentialProvider):
    """
    Provider for bearer tokens without refresh capability.

    Use for JWTs or other tokens that have an expiration time but no
    refresh token -- when they expire, a new token must be obtained out
    of band. Validation is therefore based on expiration time only.
    """

    @property
    def provider_id(self) -> str:
        return "bearer_token"

    @property
    def supported_types(self) -> list[CredentialType]:
        return [CredentialType.BEARER_TOKEN]

    def refresh(self, credential: CredentialObject) -> CredentialObject:
        """
        Always fails: these tokens cannot be refreshed.

        Raises:
            CredentialRefreshError: Always, as refresh is not supported
        """
        raise CredentialRefreshError(
            f"Bearer token '{credential.id}' cannot be refreshed. "
            "Obtain a new token and save it to the credential store."
        )

    def validate(self, credential: CredentialObject) -> bool:
        """
        Return True if a token key exists and is not expired.

        Looks for the token under 'access_token' first, then 'token'.
        """
        token_key = credential.keys.get("access_token") or credential.keys.get("token")
        if token_key is None:
            return False
        return not token_key.is_expired

    def should_refresh(self, credential: CredentialObject) -> bool:
        """
        Return True when the token is expired or within 5 minutes of it.

        Note: refresh() will still fail -- this signal only lets the
        store know the credential needs attention.
        """
        cutoff = datetime.now(UTC) + timedelta(minutes=5)
        for key_name in ("access_token", "token"):
            key = credential.keys.get(key_name)
            if key and key.expires_at and key.expires_at <= cutoff:
                return True
        return False
|
||||
@@ -0,0 +1,516 @@
|
||||
"""
|
||||
Storage backends for the credential store.
|
||||
|
||||
This module provides abstract and concrete storage implementations:
|
||||
- CredentialStorage: Abstract base class
|
||||
- EncryptedFileStorage: Fernet-encrypted JSON files (default for production)
|
||||
- EnvVarStorage: Environment variable reading (backward compatibility)
|
||||
- InMemoryStorage: For testing
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from .models import CredentialDecryptionError, CredentialKey, CredentialObject, CredentialType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CredentialStorage(ABC):
    """
    Abstract storage backend for credentials.

    Concrete backends implement save/load/delete/list_all/exists and are
    responsible for serializing SecretStr values securely.
    """

    @abstractmethod
    def save(self, credential: CredentialObject) -> None:
        """
        Persist a credential.

        Args:
            credential: The credential object to save
        """
        ...

    @abstractmethod
    def load(self, credential_id: str) -> CredentialObject | None:
        """
        Load a credential by ID.

        Args:
            credential_id: The ID of the credential to load

        Returns:
            CredentialObject if found, None otherwise
        """
        ...

    @abstractmethod
    def delete(self, credential_id: str) -> bool:
        """
        Remove a credential from storage.

        Args:
            credential_id: The ID of the credential to delete

        Returns:
            True if the credential existed and was deleted, False otherwise
        """
        ...

    @abstractmethod
    def list_all(self) -> list[str]:
        """
        List the IDs of every stored credential.

        Returns:
            List of credential IDs
        """
        ...

    @abstractmethod
    def exists(self, credential_id: str) -> bool:
        """
        Check whether a credential ID is present in storage.

        Args:
            credential_id: The ID to check

        Returns:
            True if credential exists, False otherwise
        """
        ...
|
||||
|
||||
|
||||
class EncryptedFileStorage(CredentialStorage):
|
||||
"""
|
||||
Encrypted file-based credential storage.
|
||||
|
||||
Uses Fernet symmetric encryption (AES-128-CBC + HMAC) for at-rest encryption.
|
||||
Each credential is stored as a separate encrypted JSON file.
|
||||
|
||||
Directory structure:
|
||||
{base_path}/
|
||||
credentials/
|
||||
{credential_id}.enc # Encrypted credential JSON
|
||||
metadata/
|
||||
index.json # Index of all credentials (unencrypted)
|
||||
|
||||
The encryption key is read from the HIVE_CREDENTIAL_KEY environment variable.
|
||||
If not set, a new key is generated (and must be persisted for data recovery).
|
||||
|
||||
Example:
|
||||
storage = EncryptedFileStorage("/var/hive/credentials")
|
||||
storage.save(credential)
|
||||
credential = storage.load("brave_search")
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    base_path: str | Path,
    encryption_key: bytes | None = None,
    key_env_var: str = "HIVE_CREDENTIAL_KEY",
):
    """
    Initialize encrypted storage.

    Args:
        base_path: Directory for credential files
        encryption_key: 32-byte Fernet key. If None, reads from env var.
        key_env_var: Environment variable containing encryption key

    Raises:
        ImportError: If the optional 'cryptography' dependency is missing.
    """
    # Import lazily so the module can be imported without 'cryptography'
    # installed; only constructing this backend requires it.
    try:
        from cryptography.fernet import Fernet
    except ImportError as e:
        raise ImportError(
            "Encrypted storage requires 'cryptography'. Install with: pip install cryptography"
        ) from e

    self.base_path = Path(base_path)
    self._ensure_dirs()
    self._key_env_var = key_env_var

    # Key resolution order: explicit argument > environment variable >
    # freshly generated key (usable for this process only unless saved).
    if encryption_key:
        self._key = encryption_key
    else:
        key_str = os.environ.get(key_env_var)
        if key_str:
            # Fernet expects the key as bytes.
            self._key = key_str.encode()
        else:
            # Generate new key
            self._key = Fernet.generate_key()
            # NOTE(review): this warning writes the key material to the
            # log so the operator can persist it; consider redacting if
            # logs are shared.
            logger.warning(
                f"Generated new encryption key. To persist credentials across restarts, "
                f"set {key_env_var}={self._key.decode()}"
            )

    self._fernet = Fernet(self._key)
|
||||
|
||||
def _ensure_dirs(self) -> None:
|
||||
"""Create directory structure."""
|
||||
(self.base_path / "credentials").mkdir(parents=True, exist_ok=True)
|
||||
(self.base_path / "metadata").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _cred_path(self, credential_id: str) -> Path:
|
||||
"""Get the file path for a credential."""
|
||||
# Sanitize credential_id to prevent path traversal
|
||||
safe_id = credential_id.replace("/", "_").replace("\\", "_").replace("..", "_")
|
||||
return self.base_path / "credentials" / f"{safe_id}.enc"
|
||||
|
||||
def save(self, credential: CredentialObject) -> None:
|
||||
"""Encrypt and save credential."""
|
||||
# Serialize credential
|
||||
data = self._serialize_credential(credential)
|
||||
json_bytes = json.dumps(data, default=str).encode()
|
||||
|
||||
# Encrypt
|
||||
encrypted = self._fernet.encrypt(json_bytes)
|
||||
|
||||
# Write to file
|
||||
cred_path = self._cred_path(credential.id)
|
||||
with open(cred_path, "wb") as f:
|
||||
f.write(encrypted)
|
||||
|
||||
# Update index
|
||||
self._update_index(credential.id, "save", credential.credential_type.value)
|
||||
logger.debug(f"Saved encrypted credential '{credential.id}'")
|
||||
|
||||
def load(self, credential_id: str) -> CredentialObject | None:
|
||||
"""Load and decrypt credential."""
|
||||
cred_path = self._cred_path(credential_id)
|
||||
if not cred_path.exists():
|
||||
return None
|
||||
|
||||
# Read encrypted data
|
||||
with open(cred_path, "rb") as f:
|
||||
encrypted = f.read()
|
||||
|
||||
# Decrypt
|
||||
try:
|
||||
json_bytes = self._fernet.decrypt(encrypted)
|
||||
data = json.loads(json_bytes.decode())
|
||||
except Exception as e:
|
||||
raise CredentialDecryptionError(
|
||||
f"Failed to decrypt credential '{credential_id}': {e}"
|
||||
) from e
|
||||
|
||||
# Deserialize
|
||||
return self._deserialize_credential(data)
|
||||
|
||||
def delete(self, credential_id: str) -> bool:
|
||||
"""Delete a credential file."""
|
||||
cred_path = self._cred_path(credential_id)
|
||||
if cred_path.exists():
|
||||
cred_path.unlink()
|
||||
self._update_index(credential_id, "delete")
|
||||
logger.debug(f"Deleted credential '{credential_id}'")
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_all(self) -> list[str]:
|
||||
"""List all credential IDs."""
|
||||
index_path = self.base_path / "metadata" / "index.json"
|
||||
if not index_path.exists():
|
||||
return []
|
||||
with open(index_path) as f:
|
||||
index = json.load(f)
|
||||
return list(index.get("credentials", {}).keys())
|
||||
|
||||
def exists(self, credential_id: str) -> bool:
|
||||
"""Check if credential exists."""
|
||||
return self._cred_path(credential_id).exists()
|
||||
|
||||
def _serialize_credential(self, credential: CredentialObject) -> dict[str, Any]:
|
||||
"""Convert credential to JSON-serializable dict, extracting secret values."""
|
||||
data = credential.model_dump(mode="json")
|
||||
|
||||
# Extract actual secret values from SecretStr
|
||||
for key_name, key_data in data.get("keys", {}).items():
|
||||
if "value" in key_data:
|
||||
# SecretStr serializes as "**********", need actual value
|
||||
actual_key = credential.keys.get(key_name)
|
||||
if actual_key:
|
||||
key_data["value"] = actual_key.get_secret_value()
|
||||
|
||||
return data
|
||||
|
||||
def _deserialize_credential(self, data: dict[str, Any]) -> CredentialObject:
|
||||
"""Reconstruct credential from dict, wrapping values in SecretStr."""
|
||||
# Convert plain values back to SecretStr
|
||||
for key_data in data.get("keys", {}).values():
|
||||
if "value" in key_data and isinstance(key_data["value"], str):
|
||||
key_data["value"] = SecretStr(key_data["value"])
|
||||
|
||||
return CredentialObject.model_validate(data)
|
||||
|
||||
def _update_index(
|
||||
self,
|
||||
credential_id: str,
|
||||
operation: str,
|
||||
credential_type: str | None = None,
|
||||
) -> None:
|
||||
"""Update the metadata index."""
|
||||
index_path = self.base_path / "metadata" / "index.json"
|
||||
|
||||
if index_path.exists():
|
||||
with open(index_path) as f:
|
||||
index = json.load(f)
|
||||
else:
|
||||
index = {"credentials": {}, "version": "1.0"}
|
||||
|
||||
if operation == "save":
|
||||
index["credentials"][credential_id] = {
|
||||
"updated_at": datetime.now(UTC).isoformat(),
|
||||
"type": credential_type,
|
||||
}
|
||||
elif operation == "delete":
|
||||
index["credentials"].pop(credential_id, None)
|
||||
|
||||
index["last_modified"] = datetime.now(UTC).isoformat()
|
||||
|
||||
with open(index_path, "w") as f:
|
||||
json.dump(index, f, indent=2)
|
||||
|
||||
|
||||
class EnvVarStorage(CredentialStorage):
    """
    Environment variable-based storage for backward compatibility.

    Maps credential IDs to environment variable patterns and supports
    hot-reload from .env files via python-dotenv.

    This storage is READ-ONLY - credentials cannot be saved at runtime.

    Example:
        storage = EnvVarStorage(
            env_mapping={"brave_search": "BRAVE_SEARCH_API_KEY"},
            dotenv_path=Path(".env")
        )
        credential = storage.load("brave_search")
    """

    def __init__(
        self,
        env_mapping: dict[str, str] | None = None,
        dotenv_path: Path | None = None,
    ):
        """
        Initialize env var storage.

        Args:
            env_mapping: Map of credential_id -> env_var_name
                         e.g., {"brave_search": "BRAVE_SEARCH_API_KEY"}.
                         Unmapped IDs fall back to the {CREDENTIAL_ID}_API_KEY pattern.
            dotenv_path: Path to .env file for hot-reload support
                         (defaults to ./.env).
        """
        self._env_mapping = env_mapping or {}
        self._dotenv_path = dotenv_path or Path.cwd() / ".env"

    def _get_env_var_name(self, credential_id: str) -> str:
        """Map a credential ID to its environment variable name."""
        mapped = self._env_mapping.get(credential_id)
        if mapped is not None:
            return mapped
        # Default pattern: CREDENTIAL_ID_API_KEY
        return f"{credential_id.upper().replace('-', '_')}_API_KEY"

    def _read_env_value(self, env_var: str) -> str | None:
        """Read a value from the process environment or the .env file."""
        # The process environment always wins over the .env file.
        env_value = os.environ.get(env_var)
        if env_value:
            return env_value

        # Fallback: re-read the .env file on each call (hot-reload).
        if not self._dotenv_path.exists():
            return None
        try:
            from dotenv import dotenv_values
        except ImportError:
            logger.debug("python-dotenv not installed, skipping .env file")
            return None
        return dotenv_values(self._dotenv_path).get(env_var)

    def save(self, credential: CredentialObject) -> None:
        """Cannot save to environment variables at runtime."""
        raise NotImplementedError(
            "EnvVarStorage is read-only. Set environment variables "
            "externally or use EncryptedFileStorage."
        )

    def load(self, credential_id: str) -> CredentialObject | None:
        """Build an API-key credential from the mapped environment variable."""
        env_var = self._get_env_var_name(credential_id)
        secret = self._read_env_value(env_var)
        if not secret:
            return None

        return CredentialObject(
            id=credential_id,
            credential_type=CredentialType.API_KEY,
            keys={"api_key": CredentialKey(name="api_key", value=SecretStr(secret))},
            description=f"Loaded from {env_var}",
        )

    def delete(self, credential_id: str) -> bool:
        """Cannot delete environment variables at runtime."""
        raise NotImplementedError(
            "EnvVarStorage is read-only. Unset environment variables externally."
        )

    def list_all(self) -> list[str]:
        """List mapped credential IDs whose environment value is present."""
        # Only explicitly mapped credentials can be enumerated.
        return [cred_id for cred_id in self._env_mapping if self.exists(cred_id)]

    def exists(self, credential_id: str) -> bool:
        """Check if credential is available in the environment or .env file."""
        return self._read_env_value(self._get_env_var_name(credential_id)) is not None

    def add_mapping(self, credential_id: str, env_var: str) -> None:
        """
        Add a credential ID to environment variable mapping.

        Args:
            credential_id: The credential identifier
            env_var: The environment variable name
        """
        self._env_mapping[credential_id] = env_var
||||
|
||||
class InMemoryStorage(CredentialStorage):
    """
    In-memory storage for testing.

    Credentials live in a plain dictionary and are lost when the process exits.

    Example:
        storage = InMemoryStorage()
        storage.save(credential)
        credential = storage.load("test_cred")
    """

    def __init__(self, initial_data: dict[str, CredentialObject] | None = None):
        """
        Initialize in-memory storage.

        Args:
            initial_data: Optional dict of credential_id -> CredentialObject.
                          The dict is used directly (not copied).
        """
        self._data: dict[str, CredentialObject] = initial_data or {}

    def save(self, credential: CredentialObject) -> None:
        """Store the credential, replacing any existing entry with the same ID."""
        self._data[credential.id] = credential

    def load(self, credential_id: str) -> CredentialObject | None:
        """Return the stored credential, or None when absent."""
        return self._data.get(credential_id)

    def delete(self, credential_id: str) -> bool:
        """Remove the credential; report whether it was present."""
        try:
            del self._data[credential_id]
        except KeyError:
            return False
        return True

    def list_all(self) -> list[str]:
        """Return the IDs of every stored credential."""
        return [*self._data]

    def exists(self, credential_id: str) -> bool:
        """Report whether a credential with this ID is stored."""
        return credential_id in self._data

    def clear(self) -> None:
        """Drop every stored credential."""
        self._data.clear()
||||
|
||||
class CompositeStorage(CredentialStorage):
    """
    Composite storage that reads from multiple backends.

    Useful for layering storages, e.g., encrypted file with env var fallback:
    - Writes go to the primary storage
    - Reads check primary first, then fallback storages

    Example:
        storage = CompositeStorage(
            primary=EncryptedFileStorage("/var/hive/credentials"),
            fallbacks=[EnvVarStorage({"brave_search": "BRAVE_SEARCH_API_KEY"})]
        )
    """

    def __init__(
        self,
        primary: CredentialStorage,
        fallbacks: list[CredentialStorage] | None = None,
    ):
        """
        Initialize composite storage.

        Args:
            primary: Primary storage for writes and first read attempt
            fallbacks: List of fallback storages to check if primary doesn't have credential
        """
        self._primary = primary
        self._fallbacks = fallbacks or []

    def _all_storages(self):
        """Yield the primary storage followed by each fallback, in order."""
        yield self._primary
        yield from self._fallbacks

    def save(self, credential: CredentialObject) -> None:
        """Write only to the primary storage."""
        self._primary.save(credential)

    def load(self, credential_id: str) -> CredentialObject | None:
        """Return the first hit, checking primary before fallbacks."""
        for storage in self._all_storages():
            credential = storage.load(credential_id)
            if credential is not None:
                return credential
        return None

    def delete(self, credential_id: str) -> bool:
        """Delete from primary storage only."""
        return self._primary.delete(credential_id)

    def list_all(self) -> list[str]:
        """Return the union of credential IDs across every backend."""
        combined: set[str] = set()
        for storage in self._all_storages():
            combined.update(storage.list_all())
        return list(combined)

    def exists(self, credential_id: str) -> bool:
        """True when any backend (primary or fallback) has the credential."""
        return any(storage.exists(credential_id) for storage in self._all_storages())
||||
@@ -0,0 +1,708 @@
|
||||
"""
|
||||
Main credential store orchestrating storage, providers, and template resolution.
|
||||
|
||||
The CredentialStore is the primary interface for credential management, providing:
|
||||
- Multi-backend storage (file, env, vault)
|
||||
- Provider-based lifecycle management (refresh, validate)
|
||||
- Template resolution for {{cred.key}} patterns
|
||||
- Caching with TTL for performance
|
||||
- Thread-safe operations
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from .models import (
|
||||
CredentialKey,
|
||||
CredentialObject,
|
||||
CredentialRefreshError,
|
||||
CredentialUsageSpec,
|
||||
)
|
||||
from .provider import CredentialProvider, StaticProvider
|
||||
from .storage import CredentialStorage, EnvVarStorage, InMemoryStorage
|
||||
from .template import TemplateResolver
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CredentialStore:
|
||||
"""
|
||||
Main credential store orchestrating storage, providers, and template resolution.
|
||||
|
||||
Features:
|
||||
- Multi-backend storage (file, env, vault)
|
||||
- Provider-based lifecycle management (refresh, validate)
|
||||
- Template resolution for {{cred.key}} patterns
|
||||
- Caching with TTL for performance
|
||||
- Thread-safe operations
|
||||
|
||||
Usage:
|
||||
# Basic usage
|
||||
store = CredentialStore(
|
||||
storage=EncryptedFileStorage("/path/to/creds"),
|
||||
providers=[OAuth2Provider(), StaticProvider()]
|
||||
)
|
||||
|
||||
# Get a credential
|
||||
cred = store.get_credential("github_oauth")
|
||||
|
||||
# Resolve templates in headers
|
||||
headers = store.resolve_headers({
|
||||
"Authorization": "Bearer {{github_oauth.access_token}}"
|
||||
})
|
||||
|
||||
# Register a tool's credential requirements
|
||||
store.register_usage(CredentialUsageSpec(
|
||||
credential_id="brave_search",
|
||||
required_keys=["api_key"],
|
||||
headers={"X-Subscription-Token": "{{brave_search.api_key}}"}
|
||||
))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage: CredentialStorage | None = None,
|
||||
providers: list[CredentialProvider] | None = None,
|
||||
cache_ttl_seconds: int = 300,
|
||||
auto_refresh: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize the credential store.
|
||||
|
||||
Args:
|
||||
storage: Storage backend. Defaults to EnvVarStorage for compatibility.
|
||||
providers: List of credential providers. Defaults to [StaticProvider()].
|
||||
cache_ttl_seconds: How long to cache credentials in memory (default: 5 minutes).
|
||||
auto_refresh: Whether to auto-refresh expired credentials on access.
|
||||
"""
|
||||
self._storage = storage or EnvVarStorage()
|
||||
self._providers: dict[str, CredentialProvider] = {}
|
||||
self._usage_specs: dict[str, CredentialUsageSpec] = {}
|
||||
|
||||
# Cache: credential_id -> (CredentialObject, cached_at)
|
||||
self._cache: dict[str, tuple[CredentialObject, datetime]] = {}
|
||||
self._cache_ttl = cache_ttl_seconds
|
||||
self._lock = threading.RLock()
|
||||
|
||||
self._auto_refresh = auto_refresh
|
||||
|
||||
# Register providers
|
||||
for provider in providers or [StaticProvider()]:
|
||||
self.register_provider(provider)
|
||||
|
||||
# Template resolver
|
||||
self._resolver = TemplateResolver(self)
|
||||
|
||||
# --- Provider Management ---
|
||||
|
||||
    def register_provider(self, provider: CredentialProvider) -> None:
        """
        Register a credential provider.

        A later registration with the same ``provider_id`` replaces the
        earlier one.

        Args:
            provider: The provider to register
        """
        self._providers[provider.provider_id] = provider
        logger.debug(f"Registered credential provider: {provider.provider_id}")
|
||||
def get_provider(self, provider_id: str) -> CredentialProvider | None:
|
||||
"""
|
||||
Get a provider by ID.
|
||||
|
||||
Args:
|
||||
provider_id: The provider identifier
|
||||
|
||||
Returns:
|
||||
The provider if found, None otherwise
|
||||
"""
|
||||
return self._providers.get(provider_id)
|
||||
|
||||
def get_provider_for_credential(
|
||||
self, credential: CredentialObject
|
||||
) -> CredentialProvider | None:
|
||||
"""
|
||||
Get the appropriate provider for a credential.
|
||||
|
||||
Args:
|
||||
credential: The credential to find a provider for
|
||||
|
||||
Returns:
|
||||
The provider if found, None otherwise
|
||||
"""
|
||||
# First, check if credential specifies a provider
|
||||
if credential.provider_id:
|
||||
provider = self._providers.get(credential.provider_id)
|
||||
if provider:
|
||||
return provider
|
||||
|
||||
# Fall back to finding a provider that supports this type
|
||||
for provider in self._providers.values():
|
||||
if provider.can_handle(credential):
|
||||
return provider
|
||||
|
||||
return None
|
||||
|
||||
# --- Usage Spec Management ---
|
||||
|
||||
    def register_usage(self, spec: CredentialUsageSpec) -> None:
        """
        Register how a tool uses credentials.

        A later registration for the same credential_id replaces the
        earlier one.

        Args:
            spec: The usage specification
        """
        self._usage_specs[spec.credential_id] = spec
||||
|
||||
def get_usage_spec(self, credential_id: str) -> CredentialUsageSpec | None:
|
||||
"""
|
||||
Get the usage spec for a credential.
|
||||
|
||||
Args:
|
||||
credential_id: The credential identifier
|
||||
|
||||
Returns:
|
||||
The usage spec if registered, None otherwise
|
||||
"""
|
||||
return self._usage_specs.get(credential_id)
|
||||
|
||||
# --- Credential Access ---
|
||||
|
||||
def get_credential(
|
||||
self,
|
||||
credential_id: str,
|
||||
refresh_if_needed: bool = True,
|
||||
) -> CredentialObject | None:
|
||||
"""
|
||||
Get a credential by ID.
|
||||
|
||||
Args:
|
||||
credential_id: The credential identifier
|
||||
refresh_if_needed: If True, refresh expired credentials
|
||||
|
||||
Returns:
|
||||
CredentialObject or None if not found
|
||||
"""
|
||||
with self._lock:
|
||||
# Check cache
|
||||
cached = self._get_from_cache(credential_id)
|
||||
if cached is not None:
|
||||
if refresh_if_needed and self._should_refresh(cached):
|
||||
return self._refresh_credential(cached)
|
||||
return cached
|
||||
|
||||
# Load from storage
|
||||
credential = self._storage.load(credential_id)
|
||||
if credential is None:
|
||||
return None
|
||||
|
||||
# Refresh if needed
|
||||
if refresh_if_needed and self._should_refresh(credential):
|
||||
credential = self._refresh_credential(credential)
|
||||
|
||||
# Cache
|
||||
self._add_to_cache(credential)
|
||||
|
||||
return credential
|
||||
|
||||
def get_key(self, credential_id: str, key_name: str) -> str | None:
|
||||
"""
|
||||
Convenience method to get a specific key value.
|
||||
|
||||
Args:
|
||||
credential_id: The credential identifier
|
||||
key_name: The key within the credential
|
||||
|
||||
Returns:
|
||||
The key value or None if not found
|
||||
"""
|
||||
credential = self.get_credential(credential_id)
|
||||
if credential is None:
|
||||
return None
|
||||
return credential.get_key(key_name)
|
||||
|
||||
def get(self, credential_id: str) -> str | None:
|
||||
"""
|
||||
Legacy compatibility: get the primary key value.
|
||||
|
||||
For single-key credentials, returns that key.
|
||||
For multi-key, returns 'value', 'api_key', or 'access_token'.
|
||||
|
||||
Args:
|
||||
credential_id: The credential identifier
|
||||
|
||||
Returns:
|
||||
The primary key value or None
|
||||
"""
|
||||
credential = self.get_credential(credential_id)
|
||||
if credential is None:
|
||||
return None
|
||||
return credential.get_default_key()
|
||||
|
||||
# --- Template Resolution ---
|
||||
|
||||
    def resolve(self, template: str) -> str:
        """
        Resolve credential templates in a string.

        Delegates to the TemplateResolver constructed in __init__.

        Args:
            template: String containing {{cred.key}} patterns

        Returns:
            Template with all references resolved

        Example:
            >>> store.resolve("Bearer {{github.access_token}}")
            "Bearer ghp_xxxxxxxxxxxx"
        """
        return self._resolver.resolve(template)
||||
|
||||
    def resolve_headers(self, headers: dict[str, str]) -> dict[str, str]:
        """
        Resolve credential templates in headers dictionary.

        Delegates to the TemplateResolver constructed in __init__.

        Args:
            headers: Dict of header name to template value

        Returns:
            Dict with all templates resolved

        Example:
            >>> store.resolve_headers({
            ...     "Authorization": "Bearer {{github.access_token}}"
            ... })
            {"Authorization": "Bearer ghp_xxx"}
        """
        return self._resolver.resolve_headers(headers)
||||
|
||||
    def resolve_params(self, params: dict[str, str]) -> dict[str, str]:
        """
        Resolve credential templates in query parameters dictionary.

        Delegates to the TemplateResolver constructed in __init__.

        Args:
            params: Dict of param name to template value

        Returns:
            Dict with all templates resolved
        """
        return self._resolver.resolve_params(params)
||||
|
||||
def resolve_for_usage(self, credential_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
Get resolved request kwargs for a registered usage spec.
|
||||
|
||||
Args:
|
||||
credential_id: The credential identifier
|
||||
|
||||
Returns:
|
||||
Dict with 'headers', 'params', etc. keys as appropriate
|
||||
|
||||
Raises:
|
||||
ValueError: If no usage spec is registered for the credential
|
||||
"""
|
||||
spec = self._usage_specs.get(credential_id)
|
||||
if spec is None:
|
||||
raise ValueError(f"No usage spec registered for '{credential_id}'")
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
|
||||
if spec.headers:
|
||||
result["headers"] = self.resolve_headers(spec.headers)
|
||||
|
||||
if spec.query_params:
|
||||
result["params"] = self.resolve_params(spec.query_params)
|
||||
|
||||
if spec.body_fields:
|
||||
result["data"] = {key: self.resolve(value) for key, value in spec.body_fields.items()}
|
||||
|
||||
return result
|
||||
|
||||
# --- Credential Management ---
|
||||
|
||||
    def save_credential(self, credential: CredentialObject) -> None:
        """
        Save a credential to storage.

        Also refreshes the in-memory cache entry so subsequent reads see
        the new value immediately. Guarded by the store lock.

        Args:
            credential: The credential to save
        """
        with self._lock:
            self._storage.save(credential)
            self._add_to_cache(credential)
            logger.info(f"Saved credential '{credential.id}'")
||||
|
||||
def delete_credential(self, credential_id: str) -> bool:
|
||||
"""
|
||||
Delete a credential from storage.
|
||||
|
||||
Args:
|
||||
credential_id: The credential identifier
|
||||
|
||||
Returns:
|
||||
True if the credential existed and was deleted
|
||||
"""
|
||||
with self._lock:
|
||||
self._remove_from_cache(credential_id)
|
||||
result = self._storage.delete(credential_id)
|
||||
if result:
|
||||
logger.info(f"Deleted credential '{credential_id}'")
|
||||
return result
|
||||
|
||||
    def list_credentials(self) -> list[str]:
        """
        List all available credential IDs.

        Delegates to the configured storage backend; the cache is not
        consulted.

        Returns:
            List of credential IDs
        """
        return self._storage.list_all()
||||
|
||||
    def is_available(self, credential_id: str) -> bool:
        """
        Check if a credential is available.

        Does not trigger a refresh (refresh_if_needed=False), so an expired
        but present credential still counts as available.

        Args:
            credential_id: The credential identifier

        Returns:
            True if credential exists and is accessible
        """
        return self.get_credential(credential_id, refresh_if_needed=False) is not None
||||
|
||||
# --- Validation ---
|
||||
|
||||
def validate_for_usage(self, credential_id: str) -> list[str]:
|
||||
"""
|
||||
Validate that a credential meets its usage spec requirements.
|
||||
|
||||
Args:
|
||||
credential_id: The credential identifier
|
||||
|
||||
Returns:
|
||||
List of missing keys or errors. Empty list if valid.
|
||||
"""
|
||||
spec = self._usage_specs.get(credential_id)
|
||||
if spec is None:
|
||||
return [] # No requirements registered
|
||||
|
||||
credential = self.get_credential(credential_id)
|
||||
if credential is None:
|
||||
return [f"Credential '{credential_id}' not found"]
|
||||
|
||||
errors = []
|
||||
for key_name in spec.required_keys:
|
||||
if not credential.has_key(key_name):
|
||||
errors.append(f"Missing required key '{key_name}'")
|
||||
|
||||
return errors
|
||||
|
||||
def validate_all(self) -> dict[str, list[str]]:
|
||||
"""
|
||||
Validate all registered usage specs.
|
||||
|
||||
Returns:
|
||||
Dict mapping credential_id to list of errors.
|
||||
Only includes credentials with errors.
|
||||
"""
|
||||
errors = {}
|
||||
for cred_id in self._usage_specs.keys():
|
||||
cred_errors = self.validate_for_usage(cred_id)
|
||||
if cred_errors:
|
||||
errors[cred_id] = cred_errors
|
||||
return errors
|
||||
|
||||
def validate_credential(self, credential_id: str) -> bool:
|
||||
"""
|
||||
Validate a credential using its provider.
|
||||
|
||||
Args:
|
||||
credential_id: The credential identifier
|
||||
|
||||
Returns:
|
||||
True if credential is valid
|
||||
"""
|
||||
credential = self.get_credential(credential_id, refresh_if_needed=False)
|
||||
if credential is None:
|
||||
return False
|
||||
|
||||
provider = self.get_provider_for_credential(credential)
|
||||
if provider is None:
|
||||
# No provider, assume valid if has keys
|
||||
return bool(credential.keys)
|
||||
|
||||
return provider.validate(credential)
|
||||
|
||||
# --- Lifecycle Management ---
|
||||
|
||||
def _should_refresh(self, credential: CredentialObject) -> bool:
|
||||
"""Check if credential should be refreshed."""
|
||||
if not self._auto_refresh:
|
||||
return False
|
||||
|
||||
if not credential.auto_refresh:
|
||||
return False
|
||||
|
||||
provider = self.get_provider_for_credential(credential)
|
||||
if provider is None:
|
||||
return False
|
||||
|
||||
return provider.should_refresh(credential)
|
||||
|
||||
    def _refresh_credential(self, credential: CredentialObject) -> CredentialObject:
        """Refresh a credential using its provider.

        On success the refreshed credential is persisted to storage and
        re-cached. When no provider is found, or the provider raises
        CredentialRefreshError, the ORIGINAL credential is returned
        unchanged so callers always get a credential back.
        """
        provider = self.get_provider_for_credential(credential)
        if provider is None:
            logger.warning(f"No provider found for credential '{credential.id}'")
            return credential

        try:
            refreshed = provider.refresh(credential)
            refreshed.last_refreshed = datetime.now(UTC)

            # Persist the refreshed credential
            self._storage.save(refreshed)
            self._add_to_cache(refreshed)

            logger.info(f"Refreshed credential '{credential.id}'")
            return refreshed

        except CredentialRefreshError as e:
            logger.error(f"Failed to refresh credential '{credential.id}': {e}")
            return credential
||||
|
||||
def refresh_credential(self, credential_id: str) -> CredentialObject | None:
|
||||
"""
|
||||
Manually refresh a credential.
|
||||
|
||||
Args:
|
||||
credential_id: The credential identifier
|
||||
|
||||
Returns:
|
||||
The refreshed credential, or None if not found
|
||||
|
||||
Raises:
|
||||
CredentialRefreshError: If refresh fails
|
||||
"""
|
||||
credential = self.get_credential(credential_id, refresh_if_needed=False)
|
||||
if credential is None:
|
||||
return None
|
||||
|
||||
return self._refresh_credential(credential)
|
||||
|
||||
# --- Caching ---
|
||||
|
||||
def _get_from_cache(self, credential_id: str) -> CredentialObject | None:
|
||||
"""Get credential from cache if not expired."""
|
||||
if credential_id not in self._cache:
|
||||
return None
|
||||
|
||||
credential, cached_at = self._cache[credential_id]
|
||||
age = (datetime.now(UTC) - cached_at).total_seconds()
|
||||
|
||||
if age > self._cache_ttl:
|
||||
del self._cache[credential_id]
|
||||
return None
|
||||
|
||||
return credential
|
||||
|
||||
    def _add_to_cache(self, credential: CredentialObject) -> None:
        """Cache the credential with the current UTC time as its timestamp."""
        self._cache[credential.id] = (credential, datetime.now(UTC))
||||
|
||||
    def _remove_from_cache(self, credential_id: str) -> None:
        """Remove credential from cache (no-op when absent)."""
        self._cache.pop(credential_id, None)
||||
|
||||
    def clear_cache(self) -> None:
        """Clear the credential cache (thread-safe); storage is untouched."""
        with self._lock:
            self._cache.clear()
||||
|
||||
# --- Factory Methods ---
|
||||
|
||||
    @classmethod
    def for_testing(
        cls,
        credentials: dict[str, dict[str, str]],
    ) -> CredentialStore:
        """
        Create a credential store for testing with mock credentials.

        Args:
            credentials: Dict mapping credential_id to {key_name: value}
                         e.g., {"brave_search": {"api_key": "test-key"}}

        Returns:
            CredentialStore with in-memory credentials

        Example:
            store = CredentialStore.for_testing({
                "brave_search": {"api_key": "test-brave-key"},
                "github_oauth": {
                    "access_token": "test-token",
                    "refresh_token": "test-refresh"
                }
            })
        """
        # Convert test data to CredentialObjects, wrapping each raw value
        # in SecretStr as the real storages do.
        cred_objects: dict[str, CredentialObject] = {}

        for cred_id, keys in credentials.items():
            cred_objects[cred_id] = CredentialObject(
                id=cred_id,
                keys={k: CredentialKey(name=k, value=SecretStr(v)) for k, v in keys.items()},
            )

        # auto_refresh is disabled so tests never trigger a provider refresh.
        return cls(
            storage=InMemoryStorage(cred_objects),
            auto_refresh=False,
        )
|
||||
    @classmethod
    def with_encrypted_storage(
        cls,
        base_path: str,
        providers: list[CredentialProvider] | None = None,
        **kwargs: Any,
    ) -> CredentialStore:
        """
        Create a credential store with encrypted file storage.

        Args:
            base_path: Directory for credential files
            providers: List of credential providers
            **kwargs: Additional arguments passed to CredentialStore

        Returns:
            CredentialStore with EncryptedFileStorage

        Raises:
            ImportError: If the 'cryptography' package is not installed
                (raised by EncryptedFileStorage).
        """
        from .storage import EncryptedFileStorage

        return cls(
            storage=EncryptedFileStorage(base_path),
            providers=providers,
            **kwargs,
        )
||||
|
||||
    @classmethod
    def with_env_storage(
        cls,
        env_mapping: dict[str, str] | None = None,
        providers: list[CredentialProvider] | None = None,
        **kwargs: Any,
    ) -> CredentialStore:
        """
        Create a credential store with environment variable storage.

        The resulting store is read-only for writes (EnvVarStorage raises
        NotImplementedError on save/delete).

        Args:
            env_mapping: Map of credential_id -> env_var_name
            providers: List of credential providers
            **kwargs: Additional arguments passed to CredentialStore

        Returns:
            CredentialStore with EnvVarStorage
        """
        return cls(
            storage=EnvVarStorage(env_mapping),
            providers=providers,
            **kwargs,
        )
||||
|
||||
@classmethod
def with_aden_sync(
    cls,
    base_url: str = "https://hive.adenhq.com",
    cache_ttl_seconds: int = 300,
    local_path: str | None = None,
    auto_sync: bool = True,
    **kwargs: Any,
) -> CredentialStore:
    """
    Create a credential store with Aden server sync.

    Automatically syncs OAuth2 tokens from the Aden authentication server.
    Falls back to local-only storage if ADEN_API_KEY is not set or Aden
    is unreachable.

    Args:
        base_url: Aden server URL (default: https://hive.adenhq.com)
        cache_ttl_seconds: How long to cache credentials locally (default: 5 min)
        local_path: Path for local credential storage (default: ~/.hive/credentials)
        auto_sync: Whether to sync all credentials on startup (default: True)
        **kwargs: Additional arguments passed to CredentialStore

    Returns:
        CredentialStore configured with Aden sync

    Example:
        # Simple usage - just set ADEN_API_KEY env var
        store = CredentialStore.with_aden_sync()

        # Get HubSpot token (auto-refreshed via Aden)
        token = store.get_key("hubspot", "access_token")
    """
    # Local imports keep these dependencies out of module import time.
    import os
    from pathlib import Path

    from .storage import EncryptedFileStorage

    # Determine local storage path (defaults to ~/.hive/credentials).
    if local_path is None:
        local_path = str(Path.home() / ".hive" / "credentials")

    # Encrypted local storage is always created: it is the fallback target
    # for every failure path below, and the cache layer when Aden works.
    local_storage = EncryptedFileStorage(base_path=local_path)

    # Check if Aden is configured; without the API key there is no point
    # attempting network setup, so degrade to local-only immediately.
    api_key = os.environ.get("ADEN_API_KEY")
    if not api_key:
        logger.info("ADEN_API_KEY not set, using local-only credential storage")
        return cls(storage=local_storage, **kwargs)

    # Try to setup Aden sync. Any failure here is non-fatal by design:
    # the method always returns a usable (possibly local-only) store.
    try:
        from .aden import (
            AdenCachedStorage,
            AdenClientConfig,
            AdenCredentialClient,
            AdenSyncProvider,
        )

        # Create Aden client
        client = AdenCredentialClient(AdenClientConfig(base_url=base_url))

        # Create sync provider
        provider = AdenSyncProvider(client=client)

        # Use cached storage for offline resilience: reads are served from
        # local_storage within cache_ttl_seconds, refreshed via the provider.
        cached_storage = AdenCachedStorage(
            local_storage=local_storage,
            aden_provider=provider,
            cache_ttl_seconds=cache_ttl_seconds,
        )

        # auto_refresh=True so expired tokens are refreshed through Aden.
        store = cls(
            storage=cached_storage,
            providers=[provider],
            auto_refresh=True,
            **kwargs,
        )

        # Initial sync (inside the try: a sync failure also falls back).
        if auto_sync:
            synced = provider.sync_all(store)
            logger.info(f"Synced {synced} credentials from Aden server")

        return store

    except ImportError:
        # Aden optional components are not installed.
        logger.warning("Aden components not available, using local storage")
        return cls(storage=local_storage, **kwargs)

    except Exception as e:
        # Network/config errors: log and degrade rather than crash callers.
        logger.warning(f"Failed to setup Aden sync: {e}. Using local storage.")
        return cls(storage=local_storage, **kwargs)
|
||||
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
Template resolution system for credential injection.
|
||||
|
||||
This module handles {{cred.key}} patterns, enabling the bipartisan model
|
||||
where tools specify how credentials are used in HTTP requests.
|
||||
|
||||
Template Syntax:
|
||||
{{credential_id.key_name}} - Access specific key
|
||||
{{credential_id}} - Access default key (value, api_key, or access_token)
|
||||
|
||||
Examples:
|
||||
"Bearer {{github_oauth.access_token}}" -> "Bearer ghp_xxx"
|
||||
"X-API-Key: {{brave_search.api_key}}" -> "X-API-Key: BSAKxxx"
|
||||
"{{brave_search}}" -> "BSAKxxx" (uses default key)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .models import CredentialKeyNotFoundError, CredentialNotFoundError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .store import CredentialStore
|
||||
|
||||
|
||||
class TemplateResolver:
    """
    Turns {{cred.key}} placeholders into concrete secret values.

    Usage:
        resolver = TemplateResolver(credential_store)

        # Resolve one template string
        auth_header = resolver.resolve("Bearer {{github_oauth.access_token}}")

        # Resolve a whole headers mapping
        headers = resolver.resolve_headers({
            "Authorization": "Bearer {{github_oauth.access_token}}",
            "X-API-Key": "{{brave_search.api_key}}"
        })
    """

    # Matches {{credential_id}} or {{credential_id.key_name}}
    TEMPLATE_PATTERN = re.compile(r"\{\{([a-zA-Z0-9_-]+)(?:\.([a-zA-Z0-9_-]+))?\}\}")

    def __init__(self, credential_store: CredentialStore):
        """
        Initialize the resolver.

        Args:
            credential_store: Store that placeholder references are looked up in
        """
        self._store = credential_store

    def resolve(self, template: str, fail_on_missing: bool = True) -> str:
        """
        Replace every credential reference in *template* with its value.

        Args:
            template: String containing {{cred.key}} patterns
            fail_on_missing: If True, raise when a referenced credential is absent

        Returns:
            The template with every reference substituted

        Raises:
            CredentialNotFoundError: Credential missing and fail_on_missing=True
            CredentialKeyNotFoundError: Referenced key absent from the credential

        Example:
            >>> resolver.resolve("Bearer {{github_oauth.access_token}}")
            "Bearer ghp_xxxxxxxxxxxx"
        """

        def _substitute(m: re.Match) -> str:
            cred_id, key_name = m.group(1), m.group(2)

            cred = self._store.get_credential(cred_id, refresh_if_needed=True)
            if cred is None:
                if not fail_on_missing:
                    # Leave the placeholder untouched.
                    return m.group(0)
                raise CredentialNotFoundError(f"Credential '{cred_id}' not found")

            if key_name is None:
                # No explicit key: fall back to the credential's default key.
                secret = cred.get_default_key()
                if secret is None:
                    raise CredentialKeyNotFoundError(f"Credential '{cred_id}' has no keys")
            else:
                secret = cred.get_key(key_name)
                if secret is None:
                    raise CredentialKeyNotFoundError(
                        f"Key '{key_name}' not found in credential '{cred_id}'"
                    )

            # Track that this credential was consumed.
            cred.record_usage()
            return secret

        return self.TEMPLATE_PATTERN.sub(_substitute, template)

    def resolve_headers(
        self,
        header_templates: dict[str, str],
        fail_on_missing: bool = True,
    ) -> dict[str, str]:
        """
        Resolve every template value in a headers mapping.

        Args:
            header_templates: Header name -> template value
            fail_on_missing: If True, raise when a credential is absent

        Returns:
            Headers mapping with all placeholders substituted

        Example:
            >>> resolver.resolve_headers({
            ...     "Authorization": "Bearer {{github_oauth.access_token}}",
            ...     "X-API-Key": "{{brave_search.api_key}}"
            ... })
            {"Authorization": "Bearer ghp_xxx", "X-API-Key": "BSAKxxx"}
        """
        resolved: dict[str, str] = {}
        for name, template in header_templates.items():
            resolved[name] = self.resolve(template, fail_on_missing)
        return resolved

    def resolve_params(
        self,
        param_templates: dict[str, str],
        fail_on_missing: bool = True,
    ) -> dict[str, str]:
        """
        Resolve every template value in a query-parameters mapping.

        Args:
            param_templates: Param name -> template value
            fail_on_missing: If True, raise when a credential is absent

        Returns:
            Params mapping with all placeholders substituted
        """
        resolved: dict[str, str] = {}
        for name, template in param_templates.items():
            resolved[name] = self.resolve(template, fail_on_missing)
        return resolved

    def has_templates(self, text: str) -> bool:
        """
        Report whether *text* contains any credential placeholder.

        Args:
            text: String to inspect

        Returns:
            True when at least one {{...}} pattern is present
        """
        return self.TEMPLATE_PATTERN.search(text) is not None

    def extract_references(self, text: str) -> list[tuple[str, str | None]]:
        """
        Collect every credential reference appearing in *text*.

        Args:
            text: String to scan

        Returns:
            List of (credential_id, key_name) pairs; key_name is None when
            only the credential id was written.

        Example:
            >>> resolver.extract_references("{{github.token}} and {{brave_search.api_key}}")
            [("github", "token"), ("brave_search", "api_key")]
        """
        refs: list[tuple[str, str | None]] = []
        for m in self.TEMPLATE_PATTERN.finditer(text):
            refs.append((m.group(1), m.group(2)))
        return refs

    def validate_references(self, text: str) -> list[str]:
        """
        Check every reference in *text* without substituting anything.

        Args:
            text: String containing template references

        Returns:
            Error messages for each invalid reference; empty when all valid.
        """
        problems: list[str] = []
        for cred_id, key_name in self.extract_references(text):
            cred = self._store.get_credential(cred_id, refresh_if_needed=False)

            if cred is None:
                problems.append(f"Credential '{cred_id}' not found")
                continue

            if key_name:
                if not cred.has_key(key_name):
                    problems.append(f"Key '{key_name}' not found in credential '{cred_id}'")
            elif not cred.keys:
                problems.append(f"Credential '{cred_id}' has no keys")

        return problems

    def get_required_credentials(self, text: str) -> list[str]:
        """
        List the credential ids a template string depends on.

        Args:
            text: String containing template references

        Returns:
            Unique credential ids in first-appearance order
        """
        ordered: dict[str, None] = {}
        for cred_id, _ in self.extract_references(text):
            ordered.setdefault(cred_id)
        return list(ordered)
|
||||
@@ -0,0 +1 @@
|
||||
"""Tests for the credential store module."""
|
||||
@@ -0,0 +1,707 @@
|
||||
"""
|
||||
Comprehensive tests for the credential store module.
|
||||
|
||||
Tests cover:
|
||||
- Core models (CredentialObject, CredentialKey, CredentialUsageSpec)
|
||||
- Template resolution
|
||||
- Storage backends (InMemoryStorage, EnvVarStorage, EncryptedFileStorage)
|
||||
- Providers (StaticProvider, BearerTokenProvider)
|
||||
- Main CredentialStore
|
||||
- OAuth2 module
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from core.framework.credentials import (
|
||||
CompositeStorage,
|
||||
CredentialKey,
|
||||
CredentialKeyNotFoundError,
|
||||
CredentialNotFoundError,
|
||||
CredentialObject,
|
||||
CredentialStore,
|
||||
CredentialType,
|
||||
CredentialUsageSpec,
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
InMemoryStorage,
|
||||
StaticProvider,
|
||||
TemplateResolver,
|
||||
)
|
||||
from pydantic import SecretStr
|
||||
|
||||
|
||||
class TestCredentialKey:
    """Unit tests for the CredentialKey model."""

    def test_create_basic_key(self):
        """A plain key exposes name/value and has no expiry."""
        basic = CredentialKey(name="api_key", value=SecretStr("test-value"))
        assert basic.name == "api_key"
        assert basic.get_secret_value() == "test-value"
        assert basic.expires_at is None
        assert not basic.is_expired

    def test_key_with_expiration(self):
        """A key expiring in the future is not yet expired."""
        one_hour_ahead = datetime.now(UTC) + timedelta(hours=1)
        timed = CredentialKey(name="token", value=SecretStr("xxx"), expires_at=one_hour_ahead)
        assert not timed.is_expired

    def test_expired_key(self):
        """A key whose expiry lies in the past reports expired."""
        one_hour_ago = datetime.now(UTC) - timedelta(hours=1)
        stale = CredentialKey(name="token", value=SecretStr("xxx"), expires_at=one_hour_ago)
        assert stale.is_expired

    def test_key_with_metadata(self):
        """Arbitrary metadata round-trips through the model."""
        annotated = CredentialKey(
            name="token",
            value=SecretStr("xxx"),
            metadata={"client_id": "abc", "scope": "read"},
        )
        assert annotated.metadata["client_id"] == "abc"
||||
|
||||
|
||||
class TestCredentialObject:
    """Unit tests for the CredentialObject model."""

    def test_create_simple_credential(self):
        """A single-key API credential exposes its id, type and key."""
        credential = CredentialObject(
            id="brave_search",
            credential_type=CredentialType.API_KEY,
            keys={"api_key": CredentialKey(name="api_key", value=SecretStr("test-key"))},
        )
        assert credential.id == "brave_search"
        assert credential.credential_type == CredentialType.API_KEY
        assert credential.get_key("api_key") == "test-key"

    def test_create_multi_key_credential(self):
        """An OAuth2 credential can carry several named keys."""
        credential = CredentialObject(
            id="github_oauth",
            credential_type=CredentialType.OAUTH2,
            keys={
                "access_token": CredentialKey(name="access_token", value=SecretStr("ghp_xxx")),
                "refresh_token": CredentialKey(name="refresh_token", value=SecretStr("ghr_xxx")),
            },
        )
        assert credential.get_key("access_token") == "ghp_xxx"
        assert credential.get_key("refresh_token") == "ghr_xxx"
        assert credential.get_key("nonexistent") is None

    def test_set_key(self):
        """set_key stores a value readable via get_key."""
        credential = CredentialObject(id="test", keys={})
        credential.set_key("new_key", "new_value")
        assert credential.get_key("new_key") == "new_value"

    def test_set_key_with_expiration(self):
        """set_key records the supplied expiry on the stored key."""
        credential = CredentialObject(id="test", keys={})
        expiry = datetime.now(UTC) + timedelta(hours=1)
        credential.set_key("token", "xxx", expires_at=expiry)
        assert credential.keys["token"].expires_at == expiry

    def test_needs_refresh(self):
        """A credential holding an expired key needs a refresh."""
        expired_at = datetime.now(UTC) - timedelta(hours=1)
        credential = CredentialObject(
            id="test",
            keys={"token": CredentialKey(name="token", value=SecretStr("xxx"), expires_at=expired_at)},
        )
        assert credential.needs_refresh

    def test_get_default_key(self):
        """get_default_key picks api_key or access_token when present."""
        # With api_key
        with_api_key = CredentialObject(
            id="test",
            keys={"api_key": CredentialKey(name="api_key", value=SecretStr("key-value"))},
        )
        assert with_api_key.get_default_key() == "key-value"

        # With access_token
        with_token = CredentialObject(
            id="test",
            keys={
                "access_token": CredentialKey(name="access_token", value=SecretStr("token-value"))
            },
        )
        assert with_token.get_default_key() == "token-value"

    def test_record_usage(self):
        """record_usage bumps the counter and stamps last_used."""
        credential = CredentialObject(id="test", keys={})
        assert credential.use_count == 0
        assert credential.last_used is None

        credential.record_usage()
        assert credential.use_count == 1
        assert credential.last_used is not None
|
||||
|
||||
|
||||
class TestCredentialUsageSpec:
    """Unit tests for the CredentialUsageSpec model."""

    def test_create_usage_spec(self):
        """A usage spec keeps its id, required keys and header templates."""
        usage = CredentialUsageSpec(
            credential_id="brave_search",
            required_keys=["api_key"],
            headers={"X-Subscription-Token": "{{api_key}}"},
        )
        assert usage.credential_id == "brave_search"
        assert "api_key" in usage.required_keys
        assert "{{api_key}}" in usage.headers.values()
|
||||
|
||||
|
||||
class TestInMemoryStorage:
    """Unit tests for the InMemoryStorage backend."""

    def test_save_and_load(self):
        """A saved credential is returned intact by load."""
        backend = InMemoryStorage()
        original = CredentialObject(
            id="test",
            keys={"key": CredentialKey(name="key", value=SecretStr("value"))},
        )

        backend.save(original)
        round_tripped = backend.load("test")

        assert round_tripped is not None
        assert round_tripped.id == "test"
        assert round_tripped.get_key("key") == "value"

    def test_load_nonexistent(self):
        """Loading an unknown id yields None."""
        assert InMemoryStorage().load("nonexistent") is None

    def test_delete(self):
        """delete removes the credential and reports success once."""
        backend = InMemoryStorage()
        backend.save(CredentialObject(id="test", keys={}))

        assert backend.delete("test")
        assert backend.load("test") is None
        assert not backend.delete("test")

    def test_list_all(self):
        """list_all returns every stored id."""
        backend = InMemoryStorage()
        backend.save(CredentialObject(id="a", keys={}))
        backend.save(CredentialObject(id="b", keys={}))

        stored_ids = backend.list_all()
        assert "a" in stored_ids
        assert "b" in stored_ids

    def test_exists(self):
        """exists distinguishes stored ids from unknown ones."""
        backend = InMemoryStorage()
        backend.save(CredentialObject(id="test", keys={}))

        assert backend.exists("test")
        assert not backend.exists("nonexistent")

    def test_clear(self):
        """clear empties the backend completely."""
        backend = InMemoryStorage()
        backend.save(CredentialObject(id="test", keys={}))
        backend.clear()

        assert backend.list_all() == []
|
||||
|
||||
|
||||
class TestEnvVarStorage:
    """Unit tests for the EnvVarStorage backend."""

    def test_load_from_env(self):
        """An explicitly mapped env var becomes an api_key credential."""
        with patch.dict(os.environ, {"TEST_API_KEY": "test-value"}):
            backend = EnvVarStorage(env_mapping={"test": "TEST_API_KEY"})
            loaded = backend.load("test")

            assert loaded is not None
            assert loaded.get_key("api_key") == "test-value"

    def test_load_nonexistent(self):
        """An unset env var loads as None."""
        backend = EnvVarStorage(env_mapping={"test": "NONEXISTENT_VAR"})
        assert backend.load("test") is None

    def test_default_env_var_pattern(self):
        """Without a mapping, <ID>_API_KEY naming is used."""
        with patch.dict(os.environ, {"MY_SERVICE_API_KEY": "value"}):
            backend = EnvVarStorage()
            loaded = backend.load("my_service")

            assert loaded is not None
            assert loaded.get_key("api_key") == "value"

    def test_save_raises(self):
        """The backend is read-only: save is unsupported."""
        backend = EnvVarStorage()
        with pytest.raises(NotImplementedError):
            backend.save(CredentialObject(id="test", keys={}))

    def test_delete_raises(self):
        """The backend is read-only: delete is unsupported."""
        backend = EnvVarStorage()
        with pytest.raises(NotImplementedError):
            backend.delete("test")
|
||||
|
||||
|
||||
class TestEncryptedFileStorage:
    """Unit tests for the EncryptedFileStorage backend."""

    @pytest.fixture
    def temp_dir(self):
        """Yield a throwaway directory, removed after the test."""
        with tempfile.TemporaryDirectory() as tmpdir:
            yield Path(tmpdir)

    @pytest.fixture
    def storage(self, temp_dir):
        """Provide an EncryptedFileStorage rooted in the temp directory."""
        return EncryptedFileStorage(temp_dir)

    def test_save_and_load(self, storage):
        """An encrypted credential round-trips through save/load."""
        original = CredentialObject(
            id="test",
            credential_type=CredentialType.API_KEY,
            keys={"api_key": CredentialKey(name="api_key", value=SecretStr("secret-value"))},
        )

        storage.save(original)
        decrypted = storage.load("test")

        assert decrypted is not None
        assert decrypted.id == "test"
        assert decrypted.get_key("api_key") == "secret-value"

    def test_encryption_key_from_env(self, temp_dir):
        """Two instances sharing HIVE_CREDENTIAL_KEY can read each other's data."""
        from cryptography.fernet import Fernet

        fernet_key = Fernet.generate_key().decode()
        with patch.dict(os.environ, {"HIVE_CREDENTIAL_KEY": fernet_key}):
            writer = EncryptedFileStorage(temp_dir)
            writer.save(
                CredentialObject(
                    id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}
                )
            )

            # A second instance with the same key must decrypt successfully.
            reader = EncryptedFileStorage(temp_dir)
            decrypted = reader.load("test")
            assert decrypted is not None
            assert decrypted.get_key("k") == "v"

    def test_list_all(self, storage):
        """list_all reports every persisted credential id."""
        storage.save(CredentialObject(id="cred1", keys={}))
        storage.save(CredentialObject(id="cred2", keys={}))

        stored_ids = storage.list_all()
        assert "cred1" in stored_ids
        assert "cred2" in stored_ids

    def test_delete(self, storage):
        """delete removes the credential file."""
        storage.save(CredentialObject(id="test", keys={}))
        assert storage.delete("test")
        assert storage.load("test") is None
|
||||
|
||||
|
||||
class TestCompositeStorage:
    """Unit tests for the CompositeStorage backend."""

    def test_read_from_primary(self):
        """When both layers hold an id, the primary's copy wins."""
        first = InMemoryStorage()
        first.save(
            CredentialObject(
                id="test", keys={"k": CredentialKey(name="k", value=SecretStr("primary"))}
            )
        )

        second = InMemoryStorage()
        second.save(
            CredentialObject(
                id="test", keys={"k": CredentialKey(name="k", value=SecretStr("fallback"))}
            )
        )

        layered = CompositeStorage(first, [second])
        loaded = layered.load("test")

        # Should get from primary
        assert loaded.get_key("k") == "primary"

    def test_fallback_when_not_in_primary(self):
        """An id absent from the primary is served by the fallback."""
        first = InMemoryStorage()
        second = InMemoryStorage()
        second.save(
            CredentialObject(
                id="test", keys={"k": CredentialKey(name="k", value=SecretStr("fallback"))}
            )
        )

        layered = CompositeStorage(first, [second])
        loaded = layered.load("test")

        assert loaded.get_key("k") == "fallback"

    def test_write_to_primary_only(self):
        """save lands in the primary and never touches the fallback."""
        first = InMemoryStorage()
        second = InMemoryStorage()

        layered = CompositeStorage(first, [second])
        layered.save(CredentialObject(id="test", keys={}))

        assert first.exists("test")
        assert not second.exists("test")
|
||||
|
||||
|
||||
class TestStaticProvider:
    """Unit tests for StaticProvider."""

    def test_provider_id(self):
        """The provider identifies itself as 'static'."""
        assert StaticProvider().provider_id == "static"

    def test_supported_types(self):
        """API_KEY and CUSTOM are among the supported types."""
        supported = StaticProvider().supported_types
        assert CredentialType.API_KEY in supported
        assert CredentialType.CUSTOM in supported

    def test_refresh_returns_unchanged(self):
        """refresh is a no-op for static credentials."""
        subject = CredentialObject(
            id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}
        )

        result = StaticProvider().refresh(subject)
        assert result.get_key("k") == "v"

    def test_validate_with_keys(self):
        """A credential holding at least one key validates."""
        subject = CredentialObject(
            id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}
        )

        assert StaticProvider().validate(subject)

    def test_validate_without_keys(self):
        """An empty credential fails validation."""
        subject = CredentialObject(id="test", keys={})

        assert not StaticProvider().validate(subject)

    def test_should_refresh(self):
        """Static credentials never need refreshing."""
        subject = CredentialObject(id="test", keys={})

        assert not StaticProvider().should_refresh(subject)
|
||||
|
||||
|
||||
class TestTemplateResolver:
    """Unit tests for TemplateResolver."""

    @pytest.fixture
    def store(self):
        """Provide an in-memory store seeded with two credentials."""
        return CredentialStore.for_testing(
            {
                "brave_search": {"api_key": "test-brave-key"},
                "github_oauth": {"access_token": "ghp_xxx", "refresh_token": "ghr_xxx"},
            }
        )

    @pytest.fixture
    def resolver(self, store):
        """Provide a resolver bound to the seeded store."""
        return TemplateResolver(store)

    def test_resolve_simple(self, resolver):
        """One placeholder resolves to its key's value."""
        assert resolver.resolve("Bearer {{github_oauth.access_token}}") == "Bearer ghp_xxx"

    def test_resolve_multiple(self, resolver):
        """All placeholders in a string are substituted."""
        resolved = resolver.resolve("{{github_oauth.access_token}} and {{brave_search.api_key}}")
        assert "ghp_xxx" in resolved
        assert "test-brave-key" in resolved

    def test_resolve_default_key(self, resolver):
        """A bare credential id resolves via the default key."""
        resolved = resolver.resolve("Key: {{brave_search}}")
        assert "test-brave-key" in resolved

    def test_resolve_headers(self, resolver):
        """Every header template is resolved in place."""
        resolved = resolver.resolve_headers(
            {
                "Authorization": "Bearer {{github_oauth.access_token}}",
                "X-API-Key": "{{brave_search.api_key}}",
            }
        )
        assert resolved["Authorization"] == "Bearer ghp_xxx"
        assert resolved["X-API-Key"] == "test-brave-key"

    def test_resolve_missing_credential(self, resolver):
        """An unknown credential id raises CredentialNotFoundError."""
        with pytest.raises(CredentialNotFoundError):
            resolver.resolve("{{nonexistent.key}}")

    def test_resolve_missing_key(self, resolver):
        """An unknown key raises CredentialKeyNotFoundError."""
        with pytest.raises(CredentialKeyNotFoundError):
            resolver.resolve("{{github_oauth.nonexistent}}")

    def test_has_templates(self, resolver):
        """has_templates spots {{...}} patterns and nothing else."""
        assert resolver.has_templates("{{cred.key}}")
        assert resolver.has_templates("Bearer {{token}}")
        assert not resolver.has_templates("no templates here")

    def test_extract_references(self, resolver):
        """extract_references returns (id, key) tuples for each placeholder."""
        found = resolver.extract_references("{{github.token}} and {{brave.key}}")
        assert ("github", "token") in found
        assert ("brave", "key") in found
|
||||
|
||||
|
||||
class TestCredentialStore:
|
||||
"""Tests for CredentialStore."""
|
||||
|
||||
def test_for_testing_factory(self):
|
||||
"""Test creating store for testing."""
|
||||
store = CredentialStore.for_testing({"test": {"api_key": "value"}})
|
||||
|
||||
assert store.get("test") == "value"
|
||||
assert store.get_key("test", "api_key") == "value"
|
||||
|
||||
def test_get_credential(self):
|
||||
"""Test getting a credential."""
|
||||
store = CredentialStore.for_testing({"test": {"key": "value"}})
|
||||
|
||||
cred = store.get_credential("test")
|
||||
assert cred is not None
|
||||
assert cred.get_key("key") == "value"
|
||||
|
||||
def test_get_nonexistent(self):
|
||||
"""Test getting nonexistent credential."""
|
||||
store = CredentialStore.for_testing({})
|
||||
assert store.get_credential("nonexistent") is None
|
||||
assert store.get("nonexistent") is None
|
||||
|
||||
def test_save_and_load(self):
|
||||
"""Test saving and loading a credential."""
|
||||
store = CredentialStore.for_testing({})
|
||||
|
||||
cred = CredentialObject(id="new", keys={"k": CredentialKey(name="k", value=SecretStr("v"))})
|
||||
store.save_credential(cred)
|
||||
|
||||
loaded = store.get_credential("new")
|
||||
assert loaded is not None
|
||||
assert loaded.get_key("k") == "v"
|
||||
|
||||
def test_delete_credential(self):
|
||||
"""Test deleting a credential."""
|
||||
store = CredentialStore.for_testing({"test": {"k": "v"}})
|
||||
|
||||
assert store.delete_credential("test")
|
||||
assert store.get_credential("test") is None
|
||||
|
||||
def test_list_credentials(self):
|
||||
"""Test listing all credentials."""
|
||||
store = CredentialStore.for_testing({"a": {"k": "v"}, "b": {"k": "v"}})
|
||||
|
||||
ids = store.list_credentials()
|
||||
assert "a" in ids
|
||||
assert "b" in ids
|
||||
|
||||
def test_is_available(self):
|
||||
"""Test checking credential availability."""
|
||||
store = CredentialStore.for_testing({"test": {"k": "v"}})
|
||||
|
||||
assert store.is_available("test")
|
||||
assert not store.is_available("nonexistent")
|
||||
|
||||
def test_resolve_templates(self):
|
||||
"""Test template resolution through store."""
|
||||
store = CredentialStore.for_testing({"test": {"api_key": "value"}})
|
||||
|
||||
result = store.resolve("Key: {{test.api_key}}")
|
||||
assert result == "Key: value"
|
||||
|
||||
def test_resolve_headers(self):
|
||||
"""Test resolving headers through store."""
|
||||
store = CredentialStore.for_testing({"test": {"token": "xxx"}})
|
||||
|
||||
headers = store.resolve_headers({"Authorization": "Bearer {{test.token}}"})
|
||||
assert headers["Authorization"] == "Bearer xxx"
|
||||
|
||||
def test_register_provider(self):
|
||||
"""Test registering a provider."""
|
||||
store = CredentialStore.for_testing({})
|
||||
provider = StaticProvider()
|
||||
|
||||
store.register_provider(provider)
|
||||
assert store.get_provider("static") is provider
|
||||
|
||||
def test_register_usage_spec(self):
|
||||
"""Test registering a usage spec."""
|
||||
store = CredentialStore.for_testing({})
|
||||
spec = CredentialUsageSpec(
|
||||
credential_id="test",
|
||||
required_keys=["api_key"],
|
||||
headers={"X-Key": "{{api_key}}"},
|
||||
)
|
||||
|
||||
store.register_usage(spec)
|
||||
assert store.get_usage_spec("test") is spec
|
||||
|
||||
def test_validate_for_usage(self):
|
||||
"""Test validating credential for usage spec."""
|
||||
store = CredentialStore.for_testing({"test": {"api_key": "value"}})
|
||||
spec = CredentialUsageSpec(credential_id="test", required_keys=["api_key"])
|
||||
store.register_usage(spec)
|
||||
|
||||
errors = store.validate_for_usage("test")
|
||||
assert errors == []
|
||||
|
||||
def test_validate_for_usage_missing_key(self):
|
||||
"""Test validation with missing required key."""
|
||||
store = CredentialStore.for_testing({"test": {"other_key": "value"}})
|
||||
spec = CredentialUsageSpec(credential_id="test", required_keys=["api_key"])
|
||||
store.register_usage(spec)
|
||||
|
||||
errors = store.validate_for_usage("test")
|
||||
assert "api_key" in errors[0]
|
||||
|
||||
def test_caching(self):
    """A loaded credential is served from cache even after backend deletion."""
    backend = InMemoryStorage()
    credential_store = CredentialStore(storage=backend, cache_ttl_seconds=60)

    backend.save(
        CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))})
    )

    # Populate the cache, then remove the backing record.
    credential_store.get_credential("test")
    backend.delete("test")

    # The cached copy should still be returned.
    cached = credential_store.get_credential("test")
    assert cached is not None
|
||||
|
||||
def test_clear_cache(self):
    """Clearing the cache forces a fresh backend lookup."""
    backend = InMemoryStorage()
    credential_store = CredentialStore(storage=backend)

    backend.save(CredentialObject(id="test", keys={}))
    credential_store.get_credential("test")  # populate the cache

    backend.delete("test")
    credential_store.clear_cache()

    # Cache is empty and the record is gone, so nothing is found.
    assert credential_store.get_credential("test") is None
|
||||
|
||||
|
||||
class TestOAuth2Module:
    """Tests for OAuth2 module."""

    def test_oauth2_token_from_response(self):
        """All fields of a token-endpoint response carry over to the token."""
        from core.framework.credentials.oauth2 import OAuth2Token

        token_response = {
            "access_token": "xxx",
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": "yyy",
            "scope": "read write",
        }

        parsed = OAuth2Token.from_token_response(token_response)
        assert parsed.access_token == "xxx"
        assert parsed.token_type == "Bearer"
        assert parsed.refresh_token == "yyy"
        assert parsed.scope == "read write"
        assert parsed.expires_at is not None

    def test_token_is_expired(self):
        """is_expired reflects whether expires_at lies in the past."""
        from core.framework.credentials.oauth2 import OAuth2Token

        one_hour_ahead = datetime.now(UTC) + timedelta(hours=1)
        live_token = OAuth2Token(access_token="xxx", expires_at=one_hour_ahead)
        assert not live_token.is_expired

        one_hour_ago = datetime.now(UTC) - timedelta(hours=1)
        stale_token = OAuth2Token(access_token="xxx", expires_at=one_hour_ago)
        assert stale_token.is_expired

    def test_token_can_refresh(self):
        """can_refresh is true only when a refresh token is present."""
        from core.framework.credentials.oauth2 import OAuth2Token

        assert OAuth2Token(access_token="xxx", refresh_token="yyy").can_refresh
        assert not OAuth2Token(access_token="xxx").can_refresh

    def test_oauth2_config_validation(self):
        """OAuth2Config accepts complete configs and rejects incomplete ones."""
        from core.framework.credentials.oauth2 import OAuth2Config, TokenPlacement

        # A fully-specified config is accepted.
        valid_config = OAuth2Config(
            token_url="https://example.com/token", client_id="id", client_secret="secret"
        )
        assert valid_config.token_url == "https://example.com/token"

        # An empty token_url is rejected.
        with pytest.raises(ValueError):
            OAuth2Config(token_url="")

        # HEADER_CUSTOM placement without a custom_header_name is rejected.
        with pytest.raises(ValueError):
            OAuth2Config(
                token_url="https://example.com/token",
                token_placement=TokenPlacement.HEADER_CUSTOM,
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly, with verbose pytest output.
    pytest.main([__file__, "-v"])
|
||||
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
HashiCorp Vault integration for the credential store.
|
||||
|
||||
This module provides enterprise-grade secret management through
|
||||
HashiCorp Vault integration.
|
||||
|
||||
Quick Start:
|
||||
from core.framework.credentials import CredentialStore
|
||||
from core.framework.credentials.vault import HashiCorpVaultStorage
|
||||
|
||||
# Configure Vault storage
|
||||
storage = HashiCorpVaultStorage(
|
||||
url="https://vault.example.com:8200",
|
||||
# token read from VAULT_TOKEN env var
|
||||
mount_point="secret",
|
||||
path_prefix="hive/agents/prod"
|
||||
)
|
||||
|
||||
# Create credential store with Vault backend
|
||||
store = CredentialStore(storage=storage)
|
||||
|
||||
# Use normally - credentials are stored in Vault
|
||||
credential = store.get_credential("my_api")
|
||||
|
||||
Requirements:
|
||||
pip install hvac
|
||||
|
||||
Authentication:
|
||||
Set the VAULT_TOKEN environment variable or pass the token directly:
|
||||
|
||||
export VAULT_TOKEN="hvs.xxxxxxxxxxxxx"
|
||||
|
||||
For production, consider using Vault auth methods:
|
||||
- Kubernetes auth
|
||||
- AppRole auth
|
||||
- AWS IAM auth
|
||||
|
||||
Vault Configuration:
|
||||
Ensure KV v2 secrets engine is enabled:
|
||||
|
||||
vault secrets enable -path=secret kv-v2
|
||||
|
||||
Grant appropriate policies:
|
||||
|
||||
path "secret/data/hive/credentials/*" {
|
||||
capabilities = ["create", "read", "update", "delete", "list"]
|
||||
}
|
||||
path "secret/metadata/hive/credentials/*" {
|
||||
capabilities = ["list", "delete"]
|
||||
}
|
||||
"""
|
||||
|
||||
from .hashicorp import HashiCorpVaultStorage
|
||||
|
||||
__all__ = ["HashiCorpVaultStorage"]
|
||||
@@ -0,0 +1,394 @@
|
||||
"""
|
||||
HashiCorp Vault storage adapter.
|
||||
|
||||
Provides integration with HashiCorp Vault for enterprise secret management.
|
||||
Requires the 'hvac' package: pip install hvac
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from ..models import CredentialKey, CredentialObject, CredentialType
|
||||
from ..storage import CredentialStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HashiCorpVaultStorage(CredentialStorage):
    """
    HashiCorp Vault storage adapter.

    Features:
    - KV v2 secrets engine support
    - Namespace support (Enterprise)
    - Automatic secret versioning
    - Audit logging via Vault

    The adapter stores credentials in Vault's KV v2 secrets engine with
    the following structure:

        {mount_point}/data/{path_prefix}/{credential_id}
        └── data:
            ├── _type: "oauth2"
            ├── access_token: "xxx"
            ├── refresh_token: "yyy"
            ├── _expires_access_token: "2024-01-26T12:00:00"
            └── _provider_id: "oauth2"

    Example:
        storage = HashiCorpVaultStorage(
            url="https://vault.example.com:8200",
            token="hvs.xxx",  # Or use VAULT_TOKEN env var
            mount_point="secret",
            path_prefix="hive/credentials"
        )

        store = CredentialStore(storage=storage)

        # Credentials are now stored in Vault
        store.save_credential(credential)
        credential = store.get_credential("my_api")

    Authentication:
        The adapter uses token-based authentication. The token can be provided:
        1. Directly via the 'token' parameter
        2. Via the VAULT_TOKEN environment variable

        For production, consider using:
        - Kubernetes auth method
        - AppRole auth method
        - AWS IAM auth method

    Requirements:
        pip install hvac
    """

    def __init__(
        self,
        url: str,
        token: str | None = None,
        mount_point: str = "secret",
        path_prefix: str = "hive/credentials",
        namespace: str | None = None,
        verify_ssl: bool = True,
    ):
        """
        Initialize Vault storage.

        Args:
            url: Vault server URL (e.g., https://vault.example.com:8200)
            token: Vault token. If None, reads from VAULT_TOKEN env var
            mount_point: KV secrets engine mount point (default: "secret")
            path_prefix: Path prefix for all credentials
            namespace: Vault namespace (Enterprise feature)
            verify_ssl: Whether to verify SSL certificates

        Raises:
            ImportError: If hvac is not installed
            ValueError: If authentication fails
        """
        # hvac is an optional dependency; fail with an actionable message.
        try:
            import hvac
        except ImportError as e:
            raise ImportError(
                "HashiCorp Vault support requires 'hvac'. Install with: pip install hvac"
            ) from e

        self._url = url
        self._token = token or os.environ.get("VAULT_TOKEN")
        self._mount = mount_point
        self._prefix = path_prefix
        self._namespace = namespace

        if not self._token:
            raise ValueError(
                "Vault token required. Set VAULT_TOKEN env var or pass token parameter."
            )

        self._client = hvac.Client(
            url=url,
            token=self._token,
            namespace=namespace,
            verify=verify_ssl,
        )

        # Eagerly verify the token so misconfiguration surfaces at startup
        # rather than on the first save/load.
        if not self._client.is_authenticated():
            raise ValueError("Vault authentication failed. Check token and server URL.")

        logger.info(f"Connected to HashiCorp Vault at {url}")

    def _path(self, credential_id: str) -> str:
        """Build Vault path for credential."""
        # Sanitize credential_id
        # Slashes would create nested Vault paths; flatten them instead.
        safe_id = credential_id.replace("/", "_").replace("\\", "_")
        return f"{self._prefix}/{safe_id}"

    def save(self, credential: CredentialObject) -> None:
        """Save credential to Vault KV v2."""
        path = self._path(credential.id)
        data = self._serialize_for_vault(credential)

        try:
            # KV v2 upsert; Vault retains prior versions automatically.
            self._client.secrets.kv.v2.create_or_update_secret(
                path=path,
                secret=data,
                mount_point=self._mount,
            )
            logger.debug(f"Saved credential '{credential.id}' to Vault at {path}")
        except Exception as e:
            logger.error(f"Failed to save credential '{credential.id}' to Vault: {e}")
            raise

    def load(self, credential_id: str) -> CredentialObject | None:
        """Load credential from Vault.

        Returns None when the secret does not exist; re-raises other errors.
        """
        path = self._path(credential_id)

        try:
            response = self._client.secrets.kv.v2.read_secret_version(
                path=path,
                mount_point=self._mount,
            )
            data = response["data"]["data"]
            return self._deserialize_from_vault(credential_id, data)
        except Exception as e:
            # Check if it's a "not found" error
            # NOTE(review): detection is string-based; hvac raises typed
            # exceptions (e.g. InvalidPath) for missing secrets — confirm
            # before relying on this matching in other locales/versions.
            error_str = str(e).lower()
            if "not found" in error_str or "404" in error_str:
                logger.debug(f"Credential '{credential_id}' not found in Vault")
                return None
            logger.error(f"Failed to load credential '{credential_id}' from Vault: {e}")
            raise

    def delete(self, credential_id: str) -> bool:
        """Delete credential from Vault (all versions).

        Returns True on success, False when the secret was not found.
        """
        path = self._path(credential_id)

        try:
            self._client.secrets.kv.v2.delete_metadata_and_all_versions(
                path=path,
                mount_point=self._mount,
            )
            logger.debug(f"Deleted credential '{credential_id}' from Vault")
            return True
        except Exception as e:
            error_str = str(e).lower()
            if "not found" in error_str or "404" in error_str:
                return False
            logger.error(f"Failed to delete credential '{credential_id}' from Vault: {e}")
            raise

    def list_all(self) -> list[str]:
        """List all credentials under the prefix."""
        try:
            response = self._client.secrets.kv.v2.list_secrets(
                path=self._prefix,
                mount_point=self._mount,
            )
            keys = response.get("data", {}).get("keys", [])
            # Remove trailing slashes from folder names
            return [k.rstrip("/") for k in keys]
        except Exception as e:
            error_str = str(e).lower()
            if "not found" in error_str or "404" in error_str:
                # A missing prefix simply means nothing has been stored yet.
                return []
            logger.error(f"Failed to list credentials from Vault: {e}")
            raise

    def exists(self, credential_id: str) -> bool:
        """Check if credential exists in Vault."""
        try:
            path = self._path(credential_id)
            # A successful read of the latest version means the secret exists.
            self._client.secrets.kv.v2.read_secret_version(
                path=path,
                mount_point=self._mount,
            )
            return True
        except Exception:
            # Any failure (missing, permission, transport) reads as "absent".
            return False

    def _serialize_for_vault(self, credential: CredentialObject) -> dict[str, Any]:
        """Convert credential to Vault secret format.

        Metadata fields carry a leading underscore so they can be separated
        from actual credential keys on deserialization.
        """
        data: dict[str, Any] = {
            "_type": credential.credential_type.value,
        }

        if credential.provider_id:
            data["_provider_id"] = credential.provider_id

        if credential.description:
            data["_description"] = credential.description

        if credential.auto_refresh:
            data["_auto_refresh"] = "true"

        # Store each key
        for key_name, key in credential.keys.items():
            data[key_name] = key.get_secret_value()

            if key.expires_at:
                data[f"_expires_{key_name}"] = key.expires_at.isoformat()

            if key.metadata:
                # NOTE(review): str(dict) round-trips via ast.literal_eval in
                # _deserialize_from_vault; only literal values survive — confirm
                # metadata never holds non-literal objects.
                data[f"_metadata_{key_name}"] = str(key.metadata)

        return data

    def _deserialize_from_vault(self, credential_id: str, data: dict[str, Any]) -> CredentialObject:
        """Reconstruct credential from Vault secret."""
        # Extract metadata fields
        cred_type = CredentialType(data.pop("_type", "api_key"))
        provider_id = data.pop("_provider_id", None)
        description = data.pop("_description", "")
        auto_refresh = data.pop("_auto_refresh", "") == "true"

        # Build keys dict
        keys: dict[str, CredentialKey] = {}

        # Find all non-metadata keys
        key_names = [k for k in data.keys() if not k.startswith("_")]

        for key_name in key_names:
            value = data[key_name]

            # Check for expiration
            expires_at = None
            expires_key = f"_expires_{key_name}"
            if expires_key in data:
                try:
                    expires_at = datetime.fromisoformat(data[expires_key])
                except (ValueError, TypeError):
                    # Unparseable timestamps are treated as "no expiry".
                    pass

            # Check for metadata
            metadata: dict[str, Any] = {}
            metadata_key = f"_metadata_{key_name}"
            if metadata_key in data:
                try:
                    import ast

                    metadata = ast.literal_eval(data[metadata_key])
                except (ValueError, SyntaxError):
                    # Malformed metadata is dropped rather than failing the load.
                    pass

            keys[key_name] = CredentialKey(
                name=key_name,
                value=SecretStr(value),
                expires_at=expires_at,
                metadata=metadata,
            )

        return CredentialObject(
            id=credential_id,
            credential_type=cred_type,
            keys=keys,
            provider_id=provider_id,
            description=description,
            auto_refresh=auto_refresh,
        )

    # --- Vault-Specific Operations ---

    def get_secret_metadata(self, credential_id: str) -> dict[str, Any] | None:
        """
        Get Vault metadata for a secret (version info, timestamps, etc.).

        Args:
            credential_id: The credential identifier

        Returns:
            Metadata dict or None if not found
        """
        path = self._path(credential_id)

        try:
            response = self._client.secrets.kv.v2.read_secret_metadata(
                path=path,
                mount_point=self._mount,
            )
            return response.get("data", {})
        except Exception:
            return None

    def soft_delete(self, credential_id: str, versions: list[int] | None = None) -> bool:
        """
        Soft delete specific versions (can be recovered).

        Args:
            credential_id: The credential identifier
            versions: Version numbers to delete. If None, deletes latest.

        Returns:
            True if successful
        """
        path = self._path(credential_id)

        try:
            if versions:
                self._client.secrets.kv.v2.delete_secret_versions(
                    path=path,
                    versions=versions,
                    mount_point=self._mount,
                )
            else:
                self._client.secrets.kv.v2.delete_latest_version_of_secret(
                    path=path,
                    mount_point=self._mount,
                )
            return True
        except Exception as e:
            logger.error(f"Soft delete failed for '{credential_id}': {e}")
            return False

    def undelete(self, credential_id: str, versions: list[int]) -> bool:
        """
        Recover soft-deleted versions.

        Args:
            credential_id: The credential identifier
            versions: Version numbers to recover

        Returns:
            True if successful
        """
        path = self._path(credential_id)

        try:
            self._client.secrets.kv.v2.undelete_secret_versions(
                path=path,
                versions=versions,
                mount_point=self._mount,
            )
            return True
        except Exception as e:
            logger.error(f"Undelete failed for '{credential_id}': {e}")
            return False

    def load_version(self, credential_id: str, version: int) -> CredentialObject | None:
        """
        Load a specific version of a credential.

        Args:
            credential_id: The credential identifier
            version: Version number to load

        Returns:
            CredentialObject or None
        """
        path = self._path(credential_id)

        try:
            response = self._client.secrets.kv.v2.read_secret_version(
                path=path,
                version=version,
                mount_point=self._mount,
            )
            data = response["data"]["data"]
            return self._deserialize_from_vault(credential_id, data)
        except Exception:
            # Any failure (missing version, permission, transport) maps to None.
            return None
|
||||
@@ -1,32 +1,32 @@
|
||||
"""Graph structures: Goals, Nodes, Edges, and Flexible Execution."""
|
||||
|
||||
from framework.graph.goal import Goal, SuccessCriterion, Constraint, GoalStatus
|
||||
from framework.graph.node import NodeSpec, NodeContext, NodeResult, NodeProtocol
|
||||
from framework.graph.edge import EdgeSpec, EdgeCondition
|
||||
from framework.graph.code_sandbox import CodeSandbox, safe_eval, safe_exec
|
||||
from framework.graph.edge import EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.flexible_executor import ExecutorConfig, FlexibleGraphExecutor
|
||||
from framework.graph.goal import Constraint, Goal, GoalStatus, SuccessCriterion
|
||||
from framework.graph.judge import HybridJudge, create_default_judge
|
||||
from framework.graph.node import NodeContext, NodeProtocol, NodeResult, NodeSpec
|
||||
|
||||
# Flexible execution (Worker-Judge pattern)
|
||||
from framework.graph.plan import (
|
||||
Plan,
|
||||
PlanStep,
|
||||
ActionSpec,
|
||||
ActionType,
|
||||
StepStatus,
|
||||
Judgment,
|
||||
JudgmentAction,
|
||||
EvaluationRule,
|
||||
PlanExecutionResult,
|
||||
ExecutionStatus,
|
||||
load_export,
|
||||
# HITL (Human-in-the-loop)
|
||||
ApprovalDecision,
|
||||
ApprovalRequest,
|
||||
ApprovalResult,
|
||||
EvaluationRule,
|
||||
ExecutionStatus,
|
||||
Judgment,
|
||||
JudgmentAction,
|
||||
Plan,
|
||||
PlanExecutionResult,
|
||||
PlanStep,
|
||||
StepStatus,
|
||||
load_export,
|
||||
)
|
||||
from framework.graph.judge import HybridJudge, create_default_judge
|
||||
from framework.graph.worker_node import WorkerNode, StepExecutionResult
|
||||
from framework.graph.flexible_executor import FlexibleGraphExecutor, ExecutorConfig
|
||||
from framework.graph.code_sandbox import CodeSandbox, safe_exec, safe_eval
|
||||
from framework.graph.worker_node import StepExecutionResult, WorkerNode
|
||||
|
||||
__all__ = [
|
||||
# Goal
|
||||
@@ -42,6 +42,7 @@ __all__ = [
|
||||
# Edge
|
||||
"EdgeSpec",
|
||||
"EdgeCondition",
|
||||
"GraphSpec",
|
||||
# Executor (fixed graph)
|
||||
"GraphExecutor",
|
||||
# Plan (flexible execution)
|
||||
|
||||
@@ -13,11 +13,11 @@ Security measures:
|
||||
"""
|
||||
|
||||
import ast
|
||||
import sys
|
||||
import signal
|
||||
from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
# Safe builtins whitelist
|
||||
SAFE_BUILTINS = {
|
||||
@@ -25,7 +25,6 @@ SAFE_BUILTINS = {
|
||||
"True": True,
|
||||
"False": False,
|
||||
"None": None,
|
||||
|
||||
# Type constructors
|
||||
"bool": bool,
|
||||
"int": int,
|
||||
@@ -36,7 +35,6 @@ SAFE_BUILTINS = {
|
||||
"set": set,
|
||||
"tuple": tuple,
|
||||
"frozenset": frozenset,
|
||||
|
||||
# Basic functions
|
||||
"abs": abs,
|
||||
"all": all,
|
||||
@@ -97,22 +95,26 @@ BLOCKED_AST_NODES = {
|
||||
|
||||
class CodeSandboxError(Exception):
|
||||
"""Error during sandboxed code execution."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TimeoutError(CodeSandboxError):
|
||||
"""Code execution timed out."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SecurityError(CodeSandboxError):
|
||||
"""Code contains potentially dangerous operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SandboxResult:
|
||||
"""Result of sandboxed code execution."""
|
||||
|
||||
success: bool
|
||||
result: Any = None
|
||||
error: str | None = None
|
||||
@@ -134,6 +136,7 @@ class RestrictedImporter:
|
||||
|
||||
if name not in self._cache:
|
||||
import importlib
|
||||
|
||||
self._cache[name] = importlib.import_module(name)
|
||||
|
||||
return self._cache[name]
|
||||
@@ -161,9 +164,8 @@ class CodeValidator:
|
||||
for node in ast.walk(tree):
|
||||
# Check for blocked node types
|
||||
if type(node) in self.blocked_nodes:
|
||||
issues.append(
|
||||
f"Blocked operation: {type(node).__name__} at line {getattr(node, 'lineno', '?')}"
|
||||
)
|
||||
lineno = getattr(node, "lineno", "?")
|
||||
issues.append(f"Blocked operation: {type(node).__name__} at line {lineno}")
|
||||
|
||||
# Check for dangerous attribute access
|
||||
if isinstance(node, ast.Attribute):
|
||||
@@ -212,11 +214,12 @@ class CodeSandbox:
|
||||
@contextmanager
|
||||
def _timeout_context(self, seconds: int):
|
||||
"""Context manager for timeout enforcement."""
|
||||
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError(f"Code execution timed out after {seconds} seconds")
|
||||
|
||||
# Only works on Unix-like systems
|
||||
if hasattr(signal, 'SIGALRM'):
|
||||
if hasattr(signal, "SIGALRM"):
|
||||
old_handler = signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(seconds)
|
||||
try:
|
||||
@@ -275,6 +278,7 @@ class CodeSandbox:
|
||||
|
||||
# Capture stdout
|
||||
import io
|
||||
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = captured_stdout = io.StringIO()
|
||||
|
||||
@@ -296,11 +300,7 @@ class CodeSandbox:
|
||||
|
||||
# Also extract any new variables (not in inputs or builtins)
|
||||
for key, value in namespace.items():
|
||||
if (
|
||||
key not in inputs
|
||||
and key not in self.safe_builtins
|
||||
and not key.startswith("_")
|
||||
):
|
||||
if key not in inputs and key not in self.safe_builtins and not key.startswith("_"):
|
||||
extracted[key] = value
|
||||
|
||||
return SandboxResult(
|
||||
|
||||
@@ -11,9 +11,10 @@ our edges can be created dynamically by a Builder agent based on the goal.
|
||||
|
||||
Edge Types:
|
||||
- always: Always traverse after source completes
|
||||
- always: Always traverse after source completes
|
||||
- on_success: Traverse only if source succeeds
|
||||
- on_failure: Traverse only if source fails
|
||||
- conditional: Traverse based on expression evaluation
|
||||
- conditional: Traverse based on expression evaluation (SAFE SUBSET ONLY)
|
||||
- llm_decide: Let LLM decide based on goal and context (goal-aware routing)
|
||||
|
||||
The llm_decide condition is particularly powerful for goal-driven agents,
|
||||
@@ -21,19 +22,22 @@ allowing the LLM to evaluate whether proceeding along an edge makes sense
|
||||
given the current goal, context, and execution state.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from framework.graph.safe_eval import safe_eval
|
||||
|
||||
|
||||
class EdgeCondition(str, Enum):
|
||||
"""When an edge should be traversed."""
|
||||
ALWAYS = "always" # Always after source completes
|
||||
ON_SUCCESS = "on_success" # Only if source succeeds
|
||||
ON_FAILURE = "on_failure" # Only if source fails
|
||||
CONDITIONAL = "conditional" # Based on expression
|
||||
LLM_DECIDE = "llm_decide" # Let LLM decide based on goal and context
|
||||
|
||||
ALWAYS = "always" # Always after source completes
|
||||
ON_SUCCESS = "on_success" # Only if source succeeds
|
||||
ON_FAILURE = "on_failure" # Only if source fails
|
||||
CONDITIONAL = "conditional" # Based on expression
|
||||
LLM_DECIDE = "llm_decide" # Let LLM decide based on goal and context
|
||||
|
||||
|
||||
class EdgeSpec(BaseModel):
|
||||
@@ -68,6 +72,7 @@ class EdgeSpec(BaseModel):
|
||||
description="Only filter if results need refinement to meet goal",
|
||||
)
|
||||
"""
|
||||
|
||||
id: str
|
||||
source: str = Field(description="Source node ID")
|
||||
target: str = Field(description="Target node ID")
|
||||
@@ -76,20 +81,17 @@ class EdgeSpec(BaseModel):
|
||||
condition: EdgeCondition = EdgeCondition.ALWAYS
|
||||
condition_expr: str | None = Field(
|
||||
default=None,
|
||||
description="Expression for CONDITIONAL edges, e.g., 'output.confidence > 0.8'"
|
||||
description="Expression for CONDITIONAL edges, e.g., 'output.confidence > 0.8'",
|
||||
)
|
||||
|
||||
# Data flow
|
||||
input_mapping: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Map source outputs to target inputs: {target_key: source_key}"
|
||||
description="Map source outputs to target inputs: {target_key: source_key}",
|
||||
)
|
||||
|
||||
# Priority for multiple outgoing edges
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description="Higher priority edges are evaluated first"
|
||||
)
|
||||
priority: int = Field(default=0, description="Higher priority edges are evaluated first")
|
||||
|
||||
# Metadata
|
||||
description: str = ""
|
||||
@@ -164,17 +166,18 @@ class EdgeSpec(BaseModel):
|
||||
"output": output,
|
||||
"memory": memory,
|
||||
"result": output.get("result"),
|
||||
"true": True, # Allow lowercase true/false in conditions
|
||||
"true": True, # Allow lowercase true/false in conditions
|
||||
"false": False,
|
||||
**memory, # Unpack memory keys directly into context
|
||||
}
|
||||
|
||||
try:
|
||||
# Safe evaluation (in production, use a proper expression evaluator)
|
||||
return bool(eval(self.condition_expr, {"__builtins__": {}}, context))
|
||||
# Safe evaluation using AST-based whitelist
|
||||
return bool(safe_eval(self.condition_expr, context))
|
||||
except Exception as e:
|
||||
# Log the error for debugging
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f" ⚠ Condition evaluation failed: {self.condition_expr}")
|
||||
logger.warning(f" Error: {e}")
|
||||
@@ -235,7 +238,8 @@ Respond with ONLY a JSON object:
|
||||
|
||||
# Parse response
|
||||
import re
|
||||
json_match = re.search(r'\{[^{}]*\}', response.content, re.DOTALL)
|
||||
|
||||
json_match = re.search(r"\{[^{}]*\}", response.content, re.DOTALL)
|
||||
if json_match:
|
||||
data = json.loads(json_match.group())
|
||||
proceed = data.get("proceed", False)
|
||||
@@ -243,6 +247,7 @@ Respond with ONLY a JSON object:
|
||||
|
||||
# Log the decision (using basic print for now)
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f" 🤔 LLM routing decision: {'PROCEED' if proceed else 'SKIP'}")
|
||||
logger.info(f" Reason: {reasoning}")
|
||||
@@ -252,6 +257,7 @@ Respond with ONLY a JSON object:
|
||||
except Exception as e:
|
||||
# Fallback: proceed on success
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f" ⚠ LLM routing failed, defaulting to on_success: {e}")
|
||||
return source_success
|
||||
@@ -304,28 +310,24 @@ class AsyncEntryPointSpec(BaseModel):
|
||||
isolation_level="shared",
|
||||
)
|
||||
"""
|
||||
|
||||
id: str = Field(description="Unique identifier for this entry point")
|
||||
name: str = Field(description="Human-readable name")
|
||||
entry_node: str = Field(description="Node ID to start execution from")
|
||||
trigger_type: str = Field(
|
||||
default="manual",
|
||||
description="How this entry point is triggered: webhook, api, timer, event, manual"
|
||||
description="How this entry point is triggered: webhook, api, timer, event, manual",
|
||||
)
|
||||
trigger_config: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Trigger-specific configuration (e.g., webhook URL, timer interval)"
|
||||
description="Trigger-specific configuration (e.g., webhook URL, timer interval)",
|
||||
)
|
||||
isolation_level: str = Field(
|
||||
default="shared",
|
||||
description="State isolation: isolated, shared, or synchronized"
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description="Execution priority (higher = more priority)"
|
||||
default="shared", description="State isolation: isolated, shared, or synchronized"
|
||||
)
|
||||
priority: int = Field(default=0, description="Execution priority (higher = more priority)")
|
||||
max_concurrent: int = Field(
|
||||
default=10,
|
||||
description="Maximum concurrent executions for this entry point"
|
||||
default=10, description="Maximum concurrent executions for this entry point"
|
||||
)
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
@@ -370,6 +372,7 @@ class GraphSpec(BaseModel):
|
||||
edges=[...],
|
||||
)
|
||||
"""
|
||||
|
||||
id: str
|
||||
goal_id: str
|
||||
version: str = "1.0.0"
|
||||
@@ -378,46 +381,43 @@ class GraphSpec(BaseModel):
|
||||
entry_node: str = Field(description="ID of the first node to execute")
|
||||
entry_points: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Named entry points for resuming execution. Format: {name: node_id}"
|
||||
description="Named entry points for resuming execution. Format: {name: node_id}",
|
||||
)
|
||||
async_entry_points: list[AsyncEntryPointSpec] = Field(
|
||||
default_factory=list,
|
||||
description="Asynchronous entry points for concurrent execution streams (used with AgentRuntime)"
|
||||
description=(
|
||||
"Asynchronous entry points for concurrent execution streams (used with AgentRuntime)"
|
||||
),
|
||||
)
|
||||
terminal_nodes: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="IDs of nodes that end execution"
|
||||
default_factory=list, description="IDs of nodes that end execution"
|
||||
)
|
||||
pause_nodes: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="IDs of nodes that pause execution for HITL input"
|
||||
default_factory=list, description="IDs of nodes that pause execution for HITL input"
|
||||
)
|
||||
|
||||
# Components
|
||||
nodes: list[Any] = Field( # NodeSpec, but avoiding circular import
|
||||
default_factory=list,
|
||||
description="All node specifications"
|
||||
)
|
||||
edges: list[EdgeSpec] = Field(
|
||||
default_factory=list,
|
||||
description="All edge specifications"
|
||||
default_factory=list, description="All node specifications"
|
||||
)
|
||||
edges: list[EdgeSpec] = Field(default_factory=list, description="All edge specifications")
|
||||
|
||||
# Shared memory keys
|
||||
memory_keys: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Keys available in shared memory"
|
||||
default_factory=list, description="Keys available in shared memory"
|
||||
)
|
||||
|
||||
# Default LLM settings
|
||||
default_model: str = "claude-haiku-4-5-20251001"
|
||||
max_tokens: int = 1024
|
||||
|
||||
# Cleanup LLM for JSON extraction fallback (fast/cheap model preferred)
|
||||
# If not set, uses CEREBRAS_API_KEY -> cerebras/llama-3.3-70b or
|
||||
# ANTHROPIC_API_KEY -> claude-3-5-haiku as fallback
|
||||
cleanup_llm_model: str | None = None
|
||||
|
||||
# Execution limits
|
||||
max_steps: int = Field(
|
||||
default=100,
|
||||
description="Maximum node executions before timeout"
|
||||
)
|
||||
max_steps: int = Field(default=100, description="Maximum node executions before timeout")
|
||||
max_retries_per_node: int = 3
|
||||
|
||||
# Metadata
|
||||
@@ -453,6 +453,42 @@ class GraphSpec(BaseModel):
|
||||
"""Get all edges entering a node."""
|
||||
return [e for e in self.edges if e.target == node_id]
|
||||
|
||||
def detect_fan_out_nodes(self) -> dict[str, list[str]]:
|
||||
"""
|
||||
Detect nodes that fan-out to multiple targets.
|
||||
|
||||
A fan-out occurs when a node has multiple outgoing edges with the same
|
||||
condition (typically ON_SUCCESS) that should execute in parallel.
|
||||
|
||||
Returns:
|
||||
Dict mapping source_node_id -> list of parallel target_node_ids
|
||||
"""
|
||||
fan_outs: dict[str, list[str]] = {}
|
||||
for node in self.nodes:
|
||||
outgoing = self.get_outgoing_edges(node.id)
|
||||
# Fan-out: multiple edges with ON_SUCCESS condition
|
||||
success_edges = [e for e in outgoing if e.condition == EdgeCondition.ON_SUCCESS]
|
||||
if len(success_edges) > 1:
|
||||
fan_outs[node.id] = [e.target for e in success_edges]
|
||||
return fan_outs
|
||||
|
||||
def detect_fan_in_nodes(self) -> dict[str, list[str]]:
|
||||
"""
|
||||
Detect nodes that receive from multiple sources (fan-in / convergence).
|
||||
|
||||
A fan-in occurs when a node has multiple incoming edges, meaning
|
||||
it should wait for all predecessor branches to complete.
|
||||
|
||||
Returns:
|
||||
Dict mapping target_node_id -> list of source_node_ids
|
||||
"""
|
||||
fan_ins: dict[str, list[str]] = {}
|
||||
for node in self.nodes:
|
||||
incoming = self.get_incoming_edges(node.id)
|
||||
if len(incoming) > 1:
|
||||
fan_ins[node.id] = [e.source for e in incoming]
|
||||
return fan_ins
|
||||
|
||||
def get_entry_point(self, session_state: dict | None = None) -> str:
|
||||
"""
|
||||
Get the appropriate entry point based on session state.
|
||||
@@ -504,7 +540,8 @@ class GraphSpec(BaseModel):
|
||||
# Check entry node exists
|
||||
if not self.get_node(entry_point.entry_node):
|
||||
errors.append(
|
||||
f"Async entry point '{entry_point.id}' references missing node '{entry_point.entry_node}'"
|
||||
f"Async entry point '{entry_point.id}' references "
|
||||
f"missing node '{entry_point.entry_node}'"
|
||||
)
|
||||
|
||||
# Validate isolation level
|
||||
@@ -562,11 +599,13 @@ class GraphSpec(BaseModel):
|
||||
|
||||
for node in self.nodes:
|
||||
if node.id not in reachable:
|
||||
# Skip this error if the node is a pause node, entry point target, or async entry point
|
||||
# (pause/resume architecture and async entry points make these reachable)
|
||||
if (node.id in self.pause_nodes or
|
||||
node.id in self.entry_points.values() or
|
||||
node.id in async_entry_nodes):
|
||||
# Skip if node is a pause node, entry point target, or async entry
|
||||
# (pause/resume architecture and async entry points make reachable)
|
||||
if (
|
||||
node.id in self.pause_nodes
|
||||
or node.id in self.entry_points.values()
|
||||
or node.id in async_entry_nodes
|
||||
):
|
||||
continue
|
||||
errors.append(f"Node '{node.id}' is unreachable from entry")
|
||||
|
||||
|
||||
@@ -9,31 +9,34 @@ The executor:
|
||||
5. Returns the final result
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Callable
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from framework.runtime.core import Runtime
|
||||
from framework.graph.edge import EdgeSpec, GraphSpec
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.node import (
|
||||
NodeSpec,
|
||||
NodeContext,
|
||||
NodeResult,
|
||||
NodeProtocol,
|
||||
SharedMemory,
|
||||
LLMNode,
|
||||
RouterNode,
|
||||
FunctionNode,
|
||||
LLMNode,
|
||||
NodeContext,
|
||||
NodeProtocol,
|
||||
NodeResult,
|
||||
NodeSpec,
|
||||
RouterNode,
|
||||
SharedMemory,
|
||||
)
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.output_cleaner import CleansingConfig, OutputCleaner
|
||||
from framework.graph.validator import OutputValidator
|
||||
from framework.graph.output_cleaner import OutputCleaner, CleansingConfig
|
||||
from framework.llm.provider import LLMProvider, Tool
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionResult:
|
||||
"""Result of executing a graph."""
|
||||
|
||||
success: bool
|
||||
output: dict[str, Any] = field(default_factory=dict)
|
||||
error: str | None = None
|
||||
@@ -45,6 +48,35 @@ class ExecutionResult:
|
||||
session_state: dict[str, Any] = field(default_factory=dict) # State to resume from
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParallelBranch:
|
||||
"""Tracks a single branch in parallel fan-out execution."""
|
||||
|
||||
branch_id: str
|
||||
node_id: str
|
||||
edge: EdgeSpec
|
||||
result: "NodeResult | None" = None
|
||||
status: str = "pending" # pending, running, completed, failed
|
||||
retry_count: int = 0
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParallelExecutionConfig:
|
||||
"""Configuration for parallel execution behavior."""
|
||||
|
||||
# Error handling: "fail_all" cancels all on first failure,
|
||||
# "continue_others" lets remaining branches complete,
|
||||
# "wait_all" waits for all and reports all failures
|
||||
on_branch_failure: str = "fail_all"
|
||||
|
||||
# Memory conflict handling when branches write same key
|
||||
memory_conflict_strategy: str = "last_wins" # "last_wins", "first_wins", "error"
|
||||
|
||||
# Timeout per branch in seconds
|
||||
branch_timeout_seconds: float = 300.0
|
||||
|
||||
|
||||
class GraphExecutor:
|
||||
"""
|
||||
Executes agent graphs.
|
||||
@@ -73,6 +105,8 @@ class GraphExecutor:
|
||||
node_registry: dict[str, NodeProtocol] | None = None,
|
||||
approval_callback: Callable | None = None,
|
||||
cleansing_config: CleansingConfig | None = None,
|
||||
enable_parallel_execution: bool = True,
|
||||
parallel_config: ParallelExecutionConfig | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the executor.
|
||||
@@ -85,6 +119,8 @@ class GraphExecutor:
|
||||
node_registry: Custom node implementations by ID
|
||||
approval_callback: Optional callback for human-in-the-loop approval
|
||||
cleansing_config: Optional output cleansing configuration
|
||||
enable_parallel_execution: Enable parallel fan-out execution (default True)
|
||||
parallel_config: Configuration for parallel execution behavior
|
||||
"""
|
||||
self.runtime = runtime
|
||||
self.llm = llm
|
||||
@@ -102,6 +138,10 @@ class GraphExecutor:
|
||||
llm_provider=llm,
|
||||
)
|
||||
|
||||
# Parallel execution settings
|
||||
self.enable_parallel_execution = enable_parallel_execution
|
||||
self._parallel_config = parallel_config or ParallelExecutionConfig()
|
||||
|
||||
def _validate_tools(self, graph: GraphSpec) -> list[str]:
|
||||
"""
|
||||
Validate that all tools declared by nodes are available.
|
||||
@@ -116,14 +156,15 @@ class GraphExecutor:
|
||||
if node.tools:
|
||||
missing = set(node.tools) - available_tool_names
|
||||
if missing:
|
||||
available = sorted(available_tool_names) if available_tool_names else "none"
|
||||
errors.append(
|
||||
f"Node '{node.name}' (id={node.id}) requires tools {sorted(missing)} "
|
||||
f"but they are not registered. Available tools: {sorted(available_tool_names) if available_tool_names else 'none'}"
|
||||
f"Node '{node.name}' (id={node.id}) requires tools "
|
||||
f"{sorted(missing)} but they are not registered. "
|
||||
f"Available tools: {available}"
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
graph: GraphSpec,
|
||||
@@ -159,7 +200,10 @@ class GraphExecutor:
|
||||
self.logger.error(f" • {err}")
|
||||
return ExecutionResult(
|
||||
success=False,
|
||||
error=f"Missing tools: {'; '.join(tool_errors)}. Register tools via ToolRegistry or remove tool declarations from nodes.",
|
||||
error=(
|
||||
f"Missing tools: {'; '.join(tool_errors)}. "
|
||||
"Register tools via ToolRegistry or remove tool declarations from nodes."
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize execution state
|
||||
@@ -167,10 +211,18 @@ class GraphExecutor:
|
||||
|
||||
# Restore session state if provided
|
||||
if session_state and "memory" in session_state:
|
||||
# Restore memory from previous session
|
||||
for key, value in session_state["memory"].items():
|
||||
memory.write(key, value)
|
||||
self.logger.info(f"📥 Restored session state with {len(session_state['memory'])} memory keys")
|
||||
memory_data = session_state["memory"]
|
||||
# [RESTORED] Type safety check
|
||||
if not isinstance(memory_data, dict):
|
||||
self.logger.warning(
|
||||
f"⚠️ Invalid memory data type in session state: "
|
||||
f"{type(memory_data).__name__}, expected dict"
|
||||
)
|
||||
else:
|
||||
# Restore memory from previous session
|
||||
for key, value in memory_data.items():
|
||||
memory.write(key, value)
|
||||
self.logger.info(f"📥 Restored session state with {len(memory_data)} memory keys")
|
||||
|
||||
# Write new input data to memory (each key individually)
|
||||
if input_data:
|
||||
@@ -227,6 +279,7 @@ class GraphExecutor:
|
||||
memory=memory,
|
||||
goal=goal,
|
||||
input_data=input_data or {},
|
||||
max_tokens=graph.max_tokens,
|
||||
)
|
||||
|
||||
# Log actual input data being read
|
||||
@@ -242,7 +295,7 @@ class GraphExecutor:
|
||||
self.logger.info(f" {key}: {value_str}")
|
||||
|
||||
# Get or create node implementation
|
||||
node_impl = self._get_node_implementation(node_spec)
|
||||
node_impl = self._get_node_implementation(node_spec, graph.cleanup_llm_model)
|
||||
|
||||
# Validate inputs
|
||||
validation_errors = node_impl.validate_input(ctx)
|
||||
@@ -264,6 +317,7 @@ class GraphExecutor:
|
||||
output=result.output,
|
||||
expected_keys=node_spec.output_keys,
|
||||
check_hallucination=True,
|
||||
nullable_keys=node_spec.nullable_output_keys,
|
||||
)
|
||||
if not validation.success:
|
||||
self.logger.error(f" ✗ Output validation failed: {validation.error}")
|
||||
@@ -276,7 +330,10 @@ class GraphExecutor:
|
||||
)
|
||||
|
||||
if result.success:
|
||||
self.logger.info(f" ✓ Success (tokens: {result.tokens_used}, latency: {result.latency_ms}ms)")
|
||||
self.logger.info(
|
||||
f" ✓ Success (tokens: {result.tokens_used}, "
|
||||
f"latency: {result.latency_ms}ms)"
|
||||
)
|
||||
|
||||
# Generate and log human-readable summary
|
||||
summary = result.to_summary(node_spec)
|
||||
@@ -299,28 +356,55 @@ class GraphExecutor:
|
||||
# Handle failure
|
||||
if not result.success:
|
||||
# Track retries per node
|
||||
node_retry_counts[current_node_id] = node_retry_counts.get(current_node_id, 0) + 1
|
||||
node_retry_counts[current_node_id] = (
|
||||
node_retry_counts.get(current_node_id, 0) + 1
|
||||
)
|
||||
|
||||
if node_retry_counts[current_node_id] < node_spec.max_retries:
|
||||
# [CORRECTED] Use node_spec.max_retries instead of hardcoded 3
|
||||
max_retries = getattr(node_spec, "max_retries", 3)
|
||||
|
||||
if node_retry_counts[current_node_id] < max_retries:
|
||||
# Retry - don't increment steps for retries
|
||||
steps -= 1
|
||||
self.logger.info(f" ↻ Retrying ({node_retry_counts[current_node_id]}/{node_spec.max_retries})...")
|
||||
|
||||
# --- EXPONENTIAL BACKOFF ---
|
||||
retry_count = node_retry_counts[current_node_id]
|
||||
# Backoff formula: 1.0 * (2^(retry - 1)) -> 1s, 2s, 4s...
|
||||
delay = 1.0 * (2 ** (retry_count - 1))
|
||||
self.logger.info(f" Using backoff: Sleeping {delay}s before retry...")
|
||||
await asyncio.sleep(delay)
|
||||
# --------------------------------------
|
||||
|
||||
self.logger.info(
|
||||
f" ↻ Retrying ({node_retry_counts[current_node_id]}/{max_retries})..."
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# Max retries exceeded - fail the execution
|
||||
self.logger.error(f" ✗ Max retries ({node_spec.max_retries}) exceeded for node {current_node_id}")
|
||||
self.logger.error(
|
||||
f" ✗ Max retries ({max_retries}) exceeded for node {current_node_id}"
|
||||
)
|
||||
self.runtime.report_problem(
|
||||
severity="critical",
|
||||
description=f"Node {current_node_id} failed after {node_spec.max_retries} attempts: {result.error}",
|
||||
description=(
|
||||
f"Node {current_node_id} failed after "
|
||||
f"{max_retries} attempts: {result.error}"
|
||||
),
|
||||
)
|
||||
self.runtime.end_run(
|
||||
success=False,
|
||||
output_data=memory.read_all(),
|
||||
narrative=f"Failed at {node_spec.name} after {node_spec.max_retries} retries: {result.error}",
|
||||
narrative=(
|
||||
f"Failed at {node_spec.name} after "
|
||||
f"{max_retries} retries: {result.error}"
|
||||
),
|
||||
)
|
||||
return ExecutionResult(
|
||||
success=False,
|
||||
error=f"Node '{node_spec.name}' failed after {node_spec.max_retries} attempts: {result.error}",
|
||||
error=(
|
||||
f"Node '{node_spec.name}' failed after "
|
||||
f"{max_retries} attempts: {result.error}"
|
||||
),
|
||||
output=memory.read_all(),
|
||||
steps_executed=steps,
|
||||
total_tokens=total_tokens,
|
||||
@@ -368,8 +452,8 @@ class GraphExecutor:
|
||||
self.logger.info(f" → Router directing to: {result.next_node}")
|
||||
current_node_id = result.next_node
|
||||
else:
|
||||
# Follow edges
|
||||
next_node = self._follow_edges(
|
||||
# Get all traversable edges for fan-out detection
|
||||
traversable_edges = self._get_all_traversable_edges(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
current_node_id=current_node_id,
|
||||
@@ -377,12 +461,59 @@ class GraphExecutor:
|
||||
result=result,
|
||||
memory=memory,
|
||||
)
|
||||
if next_node is None:
|
||||
|
||||
if not traversable_edges:
|
||||
self.logger.info(" → No more edges, ending execution")
|
||||
break # No valid edge, end execution
|
||||
next_spec = graph.get_node(next_node)
|
||||
self.logger.info(f" → Next: {next_spec.name if next_spec else next_node}")
|
||||
current_node_id = next_node
|
||||
|
||||
# Check for fan-out (multiple traversable edges)
|
||||
if self.enable_parallel_execution and len(traversable_edges) > 1:
|
||||
# Find convergence point (fan-in node)
|
||||
targets = [e.target for e in traversable_edges]
|
||||
fan_in_node = self._find_convergence_node(graph, targets)
|
||||
|
||||
# Execute branches in parallel
|
||||
(
|
||||
_branch_results,
|
||||
branch_tokens,
|
||||
branch_latency,
|
||||
) = await self._execute_parallel_branches(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
edges=traversable_edges,
|
||||
memory=memory,
|
||||
source_result=result,
|
||||
source_node_spec=node_spec,
|
||||
path=path,
|
||||
)
|
||||
|
||||
total_tokens += branch_tokens
|
||||
total_latency += branch_latency
|
||||
|
||||
# Continue from fan-in node
|
||||
if fan_in_node:
|
||||
self.logger.info(f" ⑃ Fan-in: converging at {fan_in_node}")
|
||||
current_node_id = fan_in_node
|
||||
else:
|
||||
# No convergence point - branches are terminal
|
||||
self.logger.info(" → Parallel branches completed (no convergence)")
|
||||
break
|
||||
else:
|
||||
# Sequential: follow single edge (existing logic via _follow_edges)
|
||||
next_node = self._follow_edges(
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
current_node_id=current_node_id,
|
||||
current_node_spec=node_spec,
|
||||
result=result,
|
||||
memory=memory,
|
||||
)
|
||||
if next_node is None:
|
||||
self.logger.info(" → No more edges, ending execution")
|
||||
break
|
||||
next_spec = graph.get_node(next_node)
|
||||
self.logger.info(f" → Next: {next_spec.name if next_spec else next_node}")
|
||||
current_node_id = next_node
|
||||
|
||||
# Update input_data for next node
|
||||
input_data = result.output
|
||||
@@ -433,6 +564,7 @@ class GraphExecutor:
|
||||
memory: SharedMemory,
|
||||
goal: Goal,
|
||||
input_data: dict[str, Any],
|
||||
max_tokens: int = 4096,
|
||||
) -> NodeContext:
|
||||
"""Build execution context for a node."""
|
||||
# Filter tools to those available to this node
|
||||
@@ -456,12 +588,15 @@ class GraphExecutor:
|
||||
available_tools=available_tools,
|
||||
goal_context=goal.to_prompt_context(),
|
||||
goal=goal, # Pass Goal object for LLM-powered routers
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
# Valid node types - no ambiguous "llm" type allowed
|
||||
VALID_NODE_TYPES = {"llm_tool_use", "llm_generate", "router", "function", "human_input"}
|
||||
|
||||
def _get_node_implementation(self, node_spec: NodeSpec) -> NodeProtocol:
|
||||
def _get_node_implementation(
|
||||
self, node_spec: NodeSpec, cleanup_llm_model: str | None = None
|
||||
) -> NodeProtocol:
|
||||
"""Get or create a node implementation."""
|
||||
# Check registry first
|
||||
if node_spec.id in self.node_registry:
|
||||
@@ -482,10 +617,18 @@ class GraphExecutor:
|
||||
f"Node '{node_spec.id}' is type 'llm_tool_use' but declares no tools. "
|
||||
"Either add tools to the node or change type to 'llm_generate'."
|
||||
)
|
||||
return LLMNode(tool_executor=self.tool_executor, require_tools=True)
|
||||
return LLMNode(
|
||||
tool_executor=self.tool_executor,
|
||||
require_tools=True,
|
||||
cleanup_llm_model=cleanup_llm_model,
|
||||
)
|
||||
|
||||
if node_spec.node_type == "llm_generate":
|
||||
return LLMNode(tool_executor=None, require_tools=False)
|
||||
return LLMNode(
|
||||
tool_executor=None,
|
||||
require_tools=False,
|
||||
cleanup_llm_model=cleanup_llm_model,
|
||||
)
|
||||
|
||||
if node_spec.node_type == "router":
|
||||
return RouterNode()
|
||||
@@ -493,13 +636,16 @@ class GraphExecutor:
|
||||
if node_spec.node_type == "function":
|
||||
# Function nodes need explicit registration
|
||||
raise RuntimeError(
|
||||
f"Function node '{node_spec.id}' not registered. "
|
||||
"Register with node_registry."
|
||||
f"Function node '{node_spec.id}' not registered. Register with node_registry."
|
||||
)
|
||||
|
||||
if node_spec.node_type == "human_input":
|
||||
# Human input nodes are handled specially by HITL mechanism
|
||||
return LLMNode(tool_executor=None, require_tools=False)
|
||||
return LLMNode(
|
||||
tool_executor=None,
|
||||
require_tools=False,
|
||||
cleanup_llm_model=cleanup_llm_model,
|
||||
)
|
||||
|
||||
# Should never reach here due to validation above
|
||||
raise RuntimeError(f"Unhandled node type: {node_spec.node_type}")
|
||||
@@ -539,9 +685,7 @@ class GraphExecutor:
|
||||
)
|
||||
|
||||
if not validation.valid:
|
||||
self.logger.warning(
|
||||
f"⚠ Output validation failed: {validation.errors}"
|
||||
)
|
||||
self.logger.warning(f"⚠ Output validation failed: {validation.errors}")
|
||||
|
||||
# Clean the output
|
||||
cleaned_output = self.output_cleaner.clean_output(
|
||||
@@ -554,9 +698,9 @@ class GraphExecutor:
|
||||
# Update result with cleaned output
|
||||
result.output = cleaned_output
|
||||
|
||||
# Write cleaned output back to memory
|
||||
# Write cleaned output back to memory (skip validation for LLM output)
|
||||
for key, value in cleaned_output.items():
|
||||
memory.write(key, value)
|
||||
memory.write(key, value, validate=False)
|
||||
|
||||
# Revalidate
|
||||
revalidation = self.output_cleaner.validate_output(
|
||||
@@ -573,15 +717,249 @@ class GraphExecutor:
|
||||
)
|
||||
# Continue anyway if fallback_to_raw is True
|
||||
|
||||
# Map inputs
|
||||
# Map inputs (skip validation for processed LLM output)
|
||||
mapped = edge.map_inputs(result.output, memory.read_all())
|
||||
for key, value in mapped.items():
|
||||
memory.write(key, value)
|
||||
memory.write(key, value, validate=False)
|
||||
|
||||
return edge.target
|
||||
|
||||
return None
|
||||
|
||||
def _get_all_traversable_edges(
|
||||
self,
|
||||
graph: GraphSpec,
|
||||
goal: Goal,
|
||||
current_node_id: str,
|
||||
current_node_spec: Any,
|
||||
result: NodeResult,
|
||||
memory: SharedMemory,
|
||||
) -> list[EdgeSpec]:
|
||||
"""
|
||||
Get ALL edges that should be traversed (for fan-out detection).
|
||||
|
||||
Unlike _follow_edges which returns the first match, this returns
|
||||
all matching edges to enable parallel execution.
|
||||
"""
|
||||
edges = graph.get_outgoing_edges(current_node_id)
|
||||
traversable = []
|
||||
|
||||
for edge in edges:
|
||||
target_node_spec = graph.get_node(edge.target)
|
||||
if edge.should_traverse(
|
||||
source_success=result.success,
|
||||
source_output=result.output,
|
||||
memory=memory.read_all(),
|
||||
llm=self.llm,
|
||||
goal=goal,
|
||||
source_node_name=current_node_spec.name if current_node_spec else current_node_id,
|
||||
target_node_name=target_node_spec.name if target_node_spec else edge.target,
|
||||
):
|
||||
traversable.append(edge)
|
||||
|
||||
return traversable
|
||||
|
||||
def _find_convergence_node(
|
||||
self,
|
||||
graph: GraphSpec,
|
||||
parallel_targets: list[str],
|
||||
) -> str | None:
|
||||
"""
|
||||
Find the common target node where parallel branches converge (fan-in).
|
||||
|
||||
Args:
|
||||
graph: The graph specification
|
||||
parallel_targets: List of node IDs that are running in parallel
|
||||
|
||||
Returns:
|
||||
Node ID where all branches converge, or None if no convergence
|
||||
"""
|
||||
# Get all nodes that parallel branches lead to
|
||||
next_nodes: dict[str, int] = {} # node_id -> count of branches leading to it
|
||||
|
||||
for target in parallel_targets:
|
||||
outgoing = graph.get_outgoing_edges(target)
|
||||
for edge in outgoing:
|
||||
next_nodes[edge.target] = next_nodes.get(edge.target, 0) + 1
|
||||
|
||||
# Convergence node is where ALL branches lead
|
||||
for node_id, count in next_nodes.items():
|
||||
if count == len(parallel_targets):
|
||||
return node_id
|
||||
|
||||
# Fallback: return most common target if any
|
||||
if next_nodes:
|
||||
return max(next_nodes.keys(), key=lambda k: next_nodes[k])
|
||||
|
||||
return None
|
||||
|
||||
async def _execute_parallel_branches(
|
||||
self,
|
||||
graph: GraphSpec,
|
||||
goal: Goal,
|
||||
edges: list[EdgeSpec],
|
||||
memory: SharedMemory,
|
||||
source_result: NodeResult,
|
||||
source_node_spec: Any,
|
||||
path: list[str],
|
||||
) -> tuple[dict[str, NodeResult], int, int]:
|
||||
"""
|
||||
Execute multiple branches in parallel using asyncio.gather.
|
||||
|
||||
Args:
|
||||
graph: The graph specification
|
||||
goal: The execution goal
|
||||
edges: List of edges to follow in parallel
|
||||
memory: Shared memory instance
|
||||
source_result: Result from the source node
|
||||
source_node_spec: Spec of the source node
|
||||
path: Execution path list to update
|
||||
|
||||
Returns:
|
||||
Tuple of (branch_results dict, total_tokens, total_latency)
|
||||
"""
|
||||
branches: dict[str, ParallelBranch] = {}
|
||||
|
||||
# Create branches for each edge
|
||||
for edge in edges:
|
||||
branch_id = f"{edge.source}_to_{edge.target}"
|
||||
branches[branch_id] = ParallelBranch(
|
||||
branch_id=branch_id,
|
||||
node_id=edge.target,
|
||||
edge=edge,
|
||||
)
|
||||
|
||||
self.logger.info(f" ⑂ Fan-out: executing {len(branches)} branches in parallel")
|
||||
for branch in branches.values():
|
||||
target_spec = graph.get_node(branch.node_id)
|
||||
self.logger.info(f" • {target_spec.name if target_spec else branch.node_id}")
|
||||
|
||||
async def execute_single_branch(
|
||||
branch: ParallelBranch,
|
||||
) -> tuple[ParallelBranch, NodeResult | Exception]:
|
||||
"""Execute a single branch with retry logic."""
|
||||
node_spec = graph.get_node(branch.node_id)
|
||||
if node_spec is None:
|
||||
branch.status = "failed"
|
||||
branch.error = f"Node {branch.node_id} not found in graph"
|
||||
return branch, RuntimeError(branch.error)
|
||||
branch.status = "running"
|
||||
|
||||
try:
|
||||
# Validate and clean output before mapping inputs (same as _follow_edges)
|
||||
if self.cleansing_config.enabled and node_spec:
|
||||
validation = self.output_cleaner.validate_output(
|
||||
output=source_result.output,
|
||||
source_node_id=source_node_spec.id if source_node_spec else "unknown",
|
||||
target_node_spec=node_spec,
|
||||
)
|
||||
|
||||
if not validation.valid:
|
||||
self.logger.warning(
|
||||
f"⚠ Output validation failed for branch "
|
||||
f"{branch.node_id}: {validation.errors}"
|
||||
)
|
||||
cleaned_output = self.output_cleaner.clean_output(
|
||||
output=source_result.output,
|
||||
source_node_id=source_node_spec.id if source_node_spec else "unknown",
|
||||
target_node_spec=node_spec,
|
||||
validation_errors=validation.errors,
|
||||
)
|
||||
# Write cleaned output to memory
|
||||
for key, value in cleaned_output.items():
|
||||
await memory.write_async(key, value)
|
||||
|
||||
# Map inputs via edge
|
||||
mapped = branch.edge.map_inputs(source_result.output, memory.read_all())
|
||||
for key, value in mapped.items():
|
||||
await memory.write_async(key, value)
|
||||
|
||||
# Execute with retries
|
||||
last_result = None
|
||||
for attempt in range(node_spec.max_retries):
|
||||
branch.retry_count = attempt
|
||||
|
||||
# Build context for this branch
|
||||
ctx = self._build_context(node_spec, memory, goal, mapped, graph.max_tokens)
|
||||
node_impl = self._get_node_implementation(node_spec, graph.cleanup_llm_model)
|
||||
|
||||
self.logger.info(
|
||||
f" ▶ Branch {node_spec.name}: executing (attempt {attempt + 1})"
|
||||
)
|
||||
result = await node_impl.execute(ctx)
|
||||
last_result = result
|
||||
|
||||
if result.success:
|
||||
# Write outputs to shared memory using async write
|
||||
for key, value in result.output.items():
|
||||
await memory.write_async(key, value)
|
||||
|
||||
branch.result = result
|
||||
branch.status = "completed"
|
||||
self.logger.info(
|
||||
f" ✓ Branch {node_spec.name}: success "
|
||||
f"(tokens: {result.tokens_used}, latency: {result.latency_ms}ms)"
|
||||
)
|
||||
return branch, result
|
||||
|
||||
self.logger.warning(
|
||||
f" ↻ Branch {node_spec.name}: "
|
||||
f"retry {attempt + 1}/{node_spec.max_retries}"
|
||||
)
|
||||
|
||||
# All retries exhausted
|
||||
branch.status = "failed"
|
||||
branch.error = last_result.error if last_result else "Unknown error"
|
||||
branch.result = last_result
|
||||
self.logger.error(
|
||||
f" ✗ Branch {node_spec.name}: "
|
||||
f"failed after {node_spec.max_retries} attempts"
|
||||
)
|
||||
return branch, last_result
|
||||
|
||||
except Exception as e:
|
||||
branch.status = "failed"
|
||||
branch.error = str(e)
|
||||
self.logger.error(f" ✗ Branch {branch.node_id}: exception - {e}")
|
||||
return branch, e
|
||||
|
||||
# Execute all branches concurrently
|
||||
tasks = [execute_single_branch(b) for b in branches.values()]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
# Process results
|
||||
total_tokens = 0
|
||||
total_latency = 0
|
||||
branch_results: dict[str, NodeResult] = {}
|
||||
failed_branches: list[ParallelBranch] = []
|
||||
|
||||
for branch, result in results:
|
||||
path.append(branch.node_id)
|
||||
|
||||
if isinstance(result, Exception):
|
||||
failed_branches.append(branch)
|
||||
elif result is None or not result.success:
|
||||
failed_branches.append(branch)
|
||||
else:
|
||||
total_tokens += result.tokens_used
|
||||
total_latency += result.latency_ms
|
||||
branch_results[branch.branch_id] = result
|
||||
|
||||
# Handle failures based on config
|
||||
if failed_branches:
|
||||
failed_names = [graph.get_node(b.node_id).name for b in failed_branches]
|
||||
if self._parallel_config.on_branch_failure == "fail_all":
|
||||
raise RuntimeError(f"Parallel execution failed: branches {failed_names} failed")
|
||||
elif self._parallel_config.on_branch_failure == "continue_others":
|
||||
self.logger.warning(
|
||||
f"⚠ Some branches failed ({failed_names}), continuing with successful ones"
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
f" ⑃ Fan-out complete: {len(branch_results)}/{len(branches)} branches succeeded"
|
||||
)
|
||||
return branch_results, total_tokens, total_latency
|
||||
|
||||
def register_node(self, node_id: str, implementation: NodeProtocol) -> None:
|
||||
"""Register a custom node implementation."""
|
||||
self.node_registry[node_id] = implementation
|
||||
|
||||
@@ -15,28 +15,29 @@ using a Worker-Judge loop:
|
||||
This keeps planning external while execution/evaluation is internal.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from framework.runtime.core import Runtime
|
||||
from framework.graph.code_sandbox import CodeSandbox
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.judge import HybridJudge, create_default_judge
|
||||
from framework.graph.plan import (
|
||||
Plan,
|
||||
PlanStep,
|
||||
PlanExecutionResult,
|
||||
ExecutionStatus,
|
||||
StepStatus,
|
||||
Judgment,
|
||||
JudgmentAction,
|
||||
ApprovalDecision,
|
||||
ApprovalRequest,
|
||||
ApprovalResult,
|
||||
ApprovalDecision,
|
||||
ExecutionStatus,
|
||||
Judgment,
|
||||
JudgmentAction,
|
||||
Plan,
|
||||
PlanExecutionResult,
|
||||
PlanStep,
|
||||
StepStatus,
|
||||
)
|
||||
from framework.graph.judge import HybridJudge, create_default_judge
|
||||
from framework.graph.worker_node import WorkerNode, StepExecutionResult
|
||||
from framework.graph.code_sandbox import CodeSandbox
|
||||
from framework.graph.worker_node import StepExecutionResult, WorkerNode
|
||||
from framework.llm.provider import LLMProvider, Tool
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
# Type alias for approval callback
|
||||
ApprovalCallback = Callable[[ApprovalRequest], ApprovalResult]
|
||||
@@ -45,6 +46,7 @@ ApprovalCallback = Callable[[ApprovalRequest], ApprovalResult]
|
||||
@dataclass
|
||||
class ExecutorConfig:
|
||||
"""Configuration for FlexibleGraphExecutor."""
|
||||
|
||||
max_retries_per_step: int = 3
|
||||
max_total_steps: int = 100
|
||||
timeout_seconds: int = 300
|
||||
@@ -165,7 +167,10 @@ class FlexibleGraphExecutor:
|
||||
status=ExecutionStatus.NEEDS_REPLAN,
|
||||
plan=plan,
|
||||
context=context,
|
||||
feedback="No executable steps available but plan not complete. Check dependencies.",
|
||||
feedback=(
|
||||
"No executable steps available but plan not complete. "
|
||||
"Check dependencies."
|
||||
),
|
||||
steps_executed=steps_executed,
|
||||
total_tokens=total_tokens,
|
||||
total_latency=total_latency,
|
||||
@@ -174,7 +179,8 @@ class FlexibleGraphExecutor:
|
||||
# Execute next step (for now, sequential; could be parallel)
|
||||
step = ready_steps[0]
|
||||
# Debug: show ready steps
|
||||
# print(f" [DEBUG] Ready steps: {[s.id for s in ready_steps]}, executing: {step.id}")
|
||||
# ready_ids = [s.id for s in ready_steps]
|
||||
# print(f" [DEBUG] Ready steps: {ready_ids}, executing: {step.id}")
|
||||
|
||||
# APPROVAL CHECK - before execution
|
||||
if step.requires_approval:
|
||||
@@ -360,7 +366,10 @@ class FlexibleGraphExecutor:
|
||||
status=ExecutionStatus.NEEDS_REPLAN,
|
||||
plan=plan,
|
||||
context=context,
|
||||
feedback=f"Step '{step.id}' failed after {step.attempts} attempts: {judgment.feedback}",
|
||||
feedback=(
|
||||
f"Step '{step.id}' failed after {step.attempts} attempts: "
|
||||
f"{judgment.feedback}"
|
||||
),
|
||||
steps_executed=steps_executed,
|
||||
total_tokens=total_tokens,
|
||||
total_latency=total_latency,
|
||||
@@ -450,12 +459,17 @@ class FlexibleGraphExecutor:
|
||||
preview_parts.append(f"Tool: {step.action.tool_name}")
|
||||
if step.action.tool_args:
|
||||
import json
|
||||
|
||||
args_preview = json.dumps(step.action.tool_args, indent=2, default=str)
|
||||
if len(args_preview) > 500:
|
||||
args_preview = args_preview[:500] + "..."
|
||||
preview_parts.append(f"Args: {args_preview}")
|
||||
elif step.action.prompt:
|
||||
prompt_preview = step.action.prompt[:300] + "..." if len(step.action.prompt) > 300 else step.action.prompt
|
||||
prompt_preview = (
|
||||
step.action.prompt[:300] + "..."
|
||||
if len(step.action.prompt) > 300
|
||||
else step.action.prompt
|
||||
)
|
||||
preview_parts.append(f"Prompt: {prompt_preview}")
|
||||
|
||||
# Include step inputs resolved from context (what will be sent/used)
|
||||
|
||||
@@ -12,20 +12,21 @@ Goals are:
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GoalStatus(str, Enum):
|
||||
"""Lifecycle status of a goal."""
|
||||
DRAFT = "draft" # Being defined
|
||||
READY = "ready" # Ready for agent creation
|
||||
ACTIVE = "active" # Has an agent graph, can execute
|
||||
COMPLETED = "completed" # Achieved
|
||||
FAILED = "failed" # Could not be achieved
|
||||
SUSPENDED = "suspended" # Paused for revision
|
||||
|
||||
DRAFT = "draft" # Being defined
|
||||
READY = "ready" # Ready for agent creation
|
||||
ACTIVE = "active" # Has an agent graph, can execute
|
||||
COMPLETED = "completed" # Achieved
|
||||
FAILED = "failed" # Could not be achieved
|
||||
SUSPENDED = "suspended" # Paused for revision
|
||||
|
||||
|
||||
class SuccessCriterion(BaseModel):
|
||||
@@ -37,22 +38,14 @@ class SuccessCriterion(BaseModel):
|
||||
- Measurable: Can be evaluated programmatically or by LLM
|
||||
- Achievable: Within the agent's capabilities
|
||||
"""
|
||||
|
||||
id: str
|
||||
description: str = Field(
|
||||
description="Human-readable description of what success looks like"
|
||||
)
|
||||
description: str = Field(description="Human-readable description of what success looks like")
|
||||
metric: str = Field(
|
||||
description="How to measure: 'output_contains', 'output_equals', 'llm_judge', 'custom'"
|
||||
)
|
||||
target: Any = Field(
|
||||
description="The target value or condition"
|
||||
)
|
||||
weight: float = Field(
|
||||
default=1.0,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Relative importance (0-1)"
|
||||
)
|
||||
target: Any = Field(description="The target value or condition")
|
||||
weight: float = Field(default=1.0, ge=0.0, le=1.0, description="Relative importance (0-1)")
|
||||
met: bool = False
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
@@ -66,18 +59,17 @@ class Constraint(BaseModel):
|
||||
- Hard: Violation means failure
|
||||
- Soft: Violation is discouraged but allowed
|
||||
"""
|
||||
|
||||
id: str
|
||||
description: str
|
||||
constraint_type: str = Field(
|
||||
description="Type: 'hard' (must not violate) or 'soft' (prefer not to violate)"
|
||||
)
|
||||
category: str = Field(
|
||||
default="general",
|
||||
description="Category: 'time', 'cost', 'safety', 'scope', 'quality'"
|
||||
default="general", description="Category: 'time', 'cost', 'safety', 'scope', 'quality'"
|
||||
)
|
||||
check: str = Field(
|
||||
default="",
|
||||
description="How to check: expression, function name, or 'llm_judge'"
|
||||
default="", description="How to check: expression, function name, or 'llm_judge'"
|
||||
)
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
@@ -119,6 +111,7 @@ class Goal(BaseModel):
|
||||
]
|
||||
)
|
||||
"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
@@ -133,23 +126,19 @@ class Goal(BaseModel):
|
||||
# Context for the agent
|
||||
context: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Additional context: domain knowledge, user preferences, etc."
|
||||
description="Additional context: domain knowledge, user preferences, etc.",
|
||||
)
|
||||
|
||||
# Capabilities required
|
||||
required_capabilities: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="What the agent needs: 'llm', 'web_search', 'code_execution', etc."
|
||||
description="What the agent needs: 'llm', 'web_search', 'code_execution', etc.",
|
||||
)
|
||||
|
||||
# Input/output schema
|
||||
input_schema: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Expected input format"
|
||||
)
|
||||
input_schema: dict[str, Any] = Field(default_factory=dict, description="Expected input format")
|
||||
output_schema: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Expected output format"
|
||||
default_factory=dict, description="Expected output format"
|
||||
)
|
||||
|
||||
# Versioning for evolution
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import Any
|
||||
|
||||
class HITLInputType(str, Enum):
|
||||
"""Type of input expected from human."""
|
||||
|
||||
FREE_TEXT = "free_text" # Open-ended text response
|
||||
STRUCTURED = "structured" # Specific fields to fill
|
||||
SELECTION = "selection" # Choose from options
|
||||
@@ -22,6 +23,7 @@ class HITLInputType(str, Enum):
|
||||
@dataclass
|
||||
class HITLQuestion:
|
||||
"""A single question to ask the human."""
|
||||
|
||||
id: str
|
||||
question: str
|
||||
input_type: HITLInputType = HITLInputType.FREE_TEXT
|
||||
@@ -44,6 +46,7 @@ class HITLRequest:
|
||||
|
||||
This is what the agent produces when it needs human input.
|
||||
"""
|
||||
|
||||
# Context
|
||||
objective: str # What we're trying to accomplish
|
||||
current_state: str # Where we are in the process
|
||||
@@ -92,6 +95,7 @@ class HITLResponse:
|
||||
|
||||
This is what gets passed back when resuming from a pause.
|
||||
"""
|
||||
|
||||
# Original request reference
|
||||
request_id: str
|
||||
|
||||
@@ -170,13 +174,13 @@ class HITLProtocol:
|
||||
|
||||
# Use Haiku to extract answers
|
||||
try:
|
||||
import anthropic
|
||||
import json
|
||||
|
||||
questions_str = "\n".join([
|
||||
f"{i+1}. {q.question} (id: {q.id})"
|
||||
for i, q in enumerate(request.questions)
|
||||
])
|
||||
import anthropic
|
||||
|
||||
questions_str = "\n".join(
|
||||
[f"{i + 1}. {q.question} (id: {q.id})" for i, q in enumerate(request.questions)]
|
||||
)
|
||||
|
||||
prompt = f"""Parse the user's response and extract answers for each question.
|
||||
|
||||
@@ -195,13 +199,14 @@ Example format:
|
||||
message = client.messages.create(
|
||||
model="claude-3-5-haiku-20241022",
|
||||
max_tokens=500,
|
||||
messages=[{"role": "user", "content": prompt}]
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
)
|
||||
|
||||
# Parse Haiku's response
|
||||
import re
|
||||
|
||||
response_text = message.content[0].text.strip()
|
||||
json_match = re.search(r'\{[^{}]*\}', response_text, re.DOTALL)
|
||||
json_match = re.search(r"\{[^{}]*\}", response_text, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
parsed = json.loads(json_match.group())
|
||||
|
||||
@@ -8,23 +8,24 @@ The HybridJudge evaluates step execution results using:
|
||||
Escalation path: rules → LLM → human
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from framework.graph.code_sandbox import safe_eval
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.plan import (
|
||||
PlanStep,
|
||||
EvaluationRule,
|
||||
Judgment,
|
||||
JudgmentAction,
|
||||
EvaluationRule,
|
||||
PlanStep,
|
||||
)
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.code_sandbox import safe_eval
|
||||
from framework.llm.provider import LLMProvider
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuleEvaluationResult:
|
||||
"""Result of rule-based evaluation."""
|
||||
|
||||
is_definitive: bool # True if a rule matched definitively
|
||||
judgment: Judgment | None = None
|
||||
context: dict[str, Any] = field(default_factory=dict)
|
||||
@@ -136,9 +137,9 @@ class HybridJudge:
|
||||
|
||||
# Build evaluation context
|
||||
eval_context = {
|
||||
"step": step.model_dump() if hasattr(step, 'model_dump') else step,
|
||||
"step": step.model_dump() if hasattr(step, "model_dump") else step,
|
||||
"result": result,
|
||||
"goal": goal.model_dump() if hasattr(goal, 'model_dump') else goal,
|
||||
"goal": goal.model_dump() if hasattr(goal, "model_dump") else goal,
|
||||
"context": context,
|
||||
"success": isinstance(result, dict) and result.get("success", False),
|
||||
"error": isinstance(result, dict) and result.get("error"),
|
||||
@@ -216,7 +217,10 @@ class HybridJudge:
|
||||
# Low confidence - escalate
|
||||
return Judgment(
|
||||
action=JudgmentAction.ESCALATE,
|
||||
reasoning=f"LLM confidence ({judgment.confidence:.2f}) below threshold ({self.llm_confidence_threshold})",
|
||||
reasoning=(
|
||||
f"LLM confidence ({judgment.confidence:.2f}) "
|
||||
f"below threshold ({self.llm_confidence_threshold})"
|
||||
),
|
||||
feedback=judgment.feedback,
|
||||
confidence=judgment.confidence,
|
||||
llm_used=True,
|
||||
@@ -338,52 +342,65 @@ def create_default_judge(llm: LLMProvider | None = None) -> HybridJudge:
|
||||
judge = HybridJudge(llm=llm)
|
||||
|
||||
# Rule: Accept on explicit success flag
|
||||
judge.add_rule(EvaluationRule(
|
||||
id="explicit_success",
|
||||
description="Step explicitly marked as successful",
|
||||
condition="isinstance(result, dict) and result.get('success') == True",
|
||||
action=JudgmentAction.ACCEPT,
|
||||
priority=100,
|
||||
))
|
||||
judge.add_rule(
|
||||
EvaluationRule(
|
||||
id="explicit_success",
|
||||
description="Step explicitly marked as successful",
|
||||
condition="isinstance(result, dict) and result.get('success') == True",
|
||||
action=JudgmentAction.ACCEPT,
|
||||
priority=100,
|
||||
)
|
||||
)
|
||||
|
||||
# Rule: Retry on transient errors
|
||||
judge.add_rule(EvaluationRule(
|
||||
id="transient_error_retry",
|
||||
description="Transient error that can be retried",
|
||||
condition="isinstance(result, dict) and result.get('error_type') in ['timeout', 'rate_limit', 'connection_error']",
|
||||
action=JudgmentAction.RETRY,
|
||||
feedback_template="Transient error: {result[error]}. Please retry.",
|
||||
priority=90,
|
||||
))
|
||||
judge.add_rule(
|
||||
EvaluationRule(
|
||||
id="transient_error_retry",
|
||||
description="Transient error that can be retried",
|
||||
condition=(
|
||||
"isinstance(result, dict) and "
|
||||
"result.get('error_type') in ['timeout', 'rate_limit', 'connection_error']"
|
||||
),
|
||||
action=JudgmentAction.RETRY,
|
||||
feedback_template="Transient error: {result[error]}. Please retry.",
|
||||
priority=90,
|
||||
)
|
||||
)
|
||||
|
||||
# Rule: Replan on missing data
|
||||
judge.add_rule(EvaluationRule(
|
||||
id="missing_data_replan",
|
||||
description="Required data not available",
|
||||
condition="isinstance(result, dict) and result.get('error_type') == 'missing_data'",
|
||||
action=JudgmentAction.REPLAN,
|
||||
feedback_template="Missing required data: {result[error]}. Plan needs adjustment.",
|
||||
priority=80,
|
||||
))
|
||||
judge.add_rule(
|
||||
EvaluationRule(
|
||||
id="missing_data_replan",
|
||||
description="Required data not available",
|
||||
condition="isinstance(result, dict) and result.get('error_type') == 'missing_data'",
|
||||
action=JudgmentAction.REPLAN,
|
||||
feedback_template="Missing required data: {result[error]}. Plan needs adjustment.",
|
||||
priority=80,
|
||||
)
|
||||
)
|
||||
|
||||
# Rule: Escalate on security issues
|
||||
judge.add_rule(EvaluationRule(
|
||||
id="security_escalate",
|
||||
description="Security issue detected",
|
||||
condition="isinstance(result, dict) and result.get('error_type') == 'security'",
|
||||
action=JudgmentAction.ESCALATE,
|
||||
feedback_template="Security issue detected: {result[error]}",
|
||||
priority=200,
|
||||
))
|
||||
judge.add_rule(
|
||||
EvaluationRule(
|
||||
id="security_escalate",
|
||||
description="Security issue detected",
|
||||
condition="isinstance(result, dict) and result.get('error_type') == 'security'",
|
||||
action=JudgmentAction.ESCALATE,
|
||||
feedback_template="Security issue detected: {result[error]}",
|
||||
priority=200,
|
||||
)
|
||||
)
|
||||
|
||||
# Rule: Fail on max retries exceeded
|
||||
judge.add_rule(EvaluationRule(
|
||||
id="max_retries_fail",
|
||||
description="Maximum retries exceeded",
|
||||
condition="step.get('attempts', 0) >= step.get('max_retries', 3)",
|
||||
action=JudgmentAction.REPLAN,
|
||||
feedback_template="Step '{step[id]}' failed after {step[attempts]} attempts",
|
||||
priority=150,
|
||||
))
|
||||
judge.add_rule(
|
||||
EvaluationRule(
|
||||
id="max_retries_fail",
|
||||
description="Maximum retries exceeded",
|
||||
condition="step.get('attempts', 0) >= step.get('max_retries', 3)",
|
||||
action=JudgmentAction.REPLAN,
|
||||
feedback_template="Step '{step[id]}' failed after {step[attempts]} attempts",
|
||||
priority=150,
|
||||
)
|
||||
)
|
||||
|
||||
return judge
|
||||
|
||||
+676
-137
File diff suppressed because it is too large
Load Diff
@@ -16,6 +16,50 @@ from typing import Any
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _heuristic_repair(text: str) -> dict | None:
|
||||
"""
|
||||
Attempt to repair JSON without an LLM call.
|
||||
|
||||
Handles common errors:
|
||||
- Markdown code blocks
|
||||
- Python booleans/None (True -> true)
|
||||
- Single quotes instead of double quotes
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
return None
|
||||
|
||||
# 1. Strip Markdown code blocks
|
||||
text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.MULTILINE)
|
||||
text = re.sub(r"\s*```$", "", text, flags=re.MULTILINE)
|
||||
text = text.strip()
|
||||
|
||||
# 2. Find outermost JSON-like structure (greedy match)
|
||||
match = re.search(r"(\{.*\}|\[.*\])", text, re.DOTALL)
|
||||
if match:
|
||||
candidate = match.group(1)
|
||||
|
||||
# 3. Common fixes
|
||||
# Fix Python constants
|
||||
candidate = re.sub(r"\bTrue\b", "true", candidate)
|
||||
candidate = re.sub(r"\bFalse\b", "false", candidate)
|
||||
candidate = re.sub(r"\bNone\b", "null", candidate)
|
||||
|
||||
# 4. Attempt load
|
||||
try:
|
||||
return json.loads(candidate)
|
||||
except json.JSONDecodeError:
|
||||
# 5. Advanced: Try swapping single quotes if double quotes fail
|
||||
# This is risky but effective for simple dicts
|
||||
try:
|
||||
if "'" in candidate and '"' not in candidate:
|
||||
candidate_swapped = candidate.replace("'", '"')
|
||||
return json.loads(candidate_swapped)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CleansingConfig:
|
||||
"""Configuration for output cleansing."""
|
||||
@@ -42,30 +86,8 @@ class OutputCleaner:
|
||||
"""
|
||||
Framework-level output validation and cleaning.
|
||||
|
||||
Uses fast LLM (llama-3.3-70b) to clean malformed outputs
|
||||
Uses heuristics and fast LLM to clean malformed outputs
|
||||
before they flow to the next node.
|
||||
|
||||
Example:
|
||||
cleaner = OutputCleaner(
|
||||
config=CleansingConfig(enabled=True),
|
||||
llm_provider=llm,
|
||||
)
|
||||
|
||||
# Validate output
|
||||
validation = cleaner.validate_output(
|
||||
output=node_output,
|
||||
source_node_id="analyze",
|
||||
target_node_spec=next_node_spec,
|
||||
)
|
||||
|
||||
if not validation.valid:
|
||||
# Clean the output
|
||||
cleaned = cleaner.clean_output(
|
||||
output=node_output,
|
||||
source_node_id="analyze",
|
||||
target_node_spec=next_node_spec,
|
||||
validation_errors=validation.errors,
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(self, config: CleansingConfig, llm_provider=None):
|
||||
@@ -74,8 +96,7 @@ class OutputCleaner:
|
||||
|
||||
Args:
|
||||
config: Cleansing configuration
|
||||
llm_provider: Optional LLM provider. If None and cleaning is enabled,
|
||||
will create a LiteLLMProvider with the configured fast_model.
|
||||
llm_provider: Optional LLM provider.
|
||||
"""
|
||||
self.config = config
|
||||
self.success_cache: dict[str, Any] = {} # Cache successful patterns
|
||||
@@ -88,9 +109,10 @@ class OutputCleaner:
|
||||
elif config.enabled:
|
||||
# Create dedicated fast LLM provider for cleaning
|
||||
try:
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
import os
|
||||
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
|
||||
api_key = os.environ.get("CEREBRAS_API_KEY")
|
||||
if api_key:
|
||||
self.llm = LiteLLMProvider(
|
||||
@@ -98,13 +120,9 @@ class OutputCleaner:
|
||||
model=config.fast_model,
|
||||
temperature=0.0, # Deterministic cleaning
|
||||
)
|
||||
logger.info(
|
||||
f"✓ Initialized OutputCleaner with {config.fast_model}"
|
||||
)
|
||||
logger.info(f"✓ Initialized OutputCleaner with {config.fast_model}")
|
||||
else:
|
||||
logger.warning(
|
||||
"⚠ CEREBRAS_API_KEY not found, output cleaning will be disabled"
|
||||
)
|
||||
logger.warning("⚠ CEREBRAS_API_KEY not found, output cleaning will be disabled")
|
||||
self.llm = None
|
||||
except ImportError:
|
||||
logger.warning("⚠ LiteLLMProvider not available, output cleaning disabled")
|
||||
@@ -121,11 +139,6 @@ class OutputCleaner:
|
||||
"""
|
||||
Validate output matches target node's expected input schema.
|
||||
|
||||
Args:
|
||||
output: Output from source node
|
||||
source_node_id: ID of source node
|
||||
target_node_spec: Spec of target node (for input_keys)
|
||||
|
||||
Returns:
|
||||
ValidationResult with errors and optionally cleaned output
|
||||
"""
|
||||
@@ -199,7 +212,7 @@ class OutputCleaner:
|
||||
validation_errors: list[str],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Use fast LLM to clean malformed output.
|
||||
Use heuristics and fast LLM to clean malformed output.
|
||||
|
||||
Args:
|
||||
output: Raw output from source node
|
||||
@@ -209,14 +222,36 @@ class OutputCleaner:
|
||||
|
||||
Returns:
|
||||
Cleaned output matching target schema
|
||||
|
||||
Raises:
|
||||
Exception: If cleaning fails and fallback_to_raw is False
|
||||
"""
|
||||
if not self.config.enabled:
|
||||
logger.warning("⚠ Output cleansing disabled in config")
|
||||
return output
|
||||
|
||||
# --- PHASE 1: Fast Heuristic Repair (Avoids LLM call) ---
|
||||
# Often the output is just a string containing JSON, or has minor syntax errors
|
||||
# If output is a dictionary but malformed, we might need to serialize it first
|
||||
# to try and fix the underlying string representation if it came from raw text
|
||||
|
||||
# Heuristic: Check if any value is actually a JSON string that should be promoted
|
||||
# This handles the "JSON Parsing Trap" where LLM returns {"key": "{\"nested\": ...}"}
|
||||
heuristic_fixed = False
|
||||
fixed_output = output.copy()
|
||||
|
||||
for key, value in output.items():
|
||||
if isinstance(value, str):
|
||||
repaired = _heuristic_repair(value)
|
||||
if repaired and isinstance(repaired, dict | list):
|
||||
# Check if this repaired structure looks like what we want
|
||||
# e.g. if the key is 'data' and the string contained valid JSON
|
||||
fixed_output[key] = repaired
|
||||
heuristic_fixed = True
|
||||
|
||||
# If we fixed something, re-validate manually to see if it's enough
|
||||
if heuristic_fixed:
|
||||
logger.info("⚡ Heuristic repair applied (nested JSON expansion)")
|
||||
return fixed_output
|
||||
|
||||
# --- PHASE 2: LLM-based Repair ---
|
||||
if not self.llm:
|
||||
logger.warning("⚠ No LLM provider available for cleansing")
|
||||
return output
|
||||
@@ -253,22 +288,21 @@ Return ONLY valid JSON matching the expected schema. No explanations, no markdow
|
||||
|
||||
response = self.llm.complete(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
system="You clean malformed agent outputs. Return only valid JSON matching the schema.",
|
||||
system=(
|
||||
"You clean malformed agent outputs. Return only valid JSON matching the schema."
|
||||
),
|
||||
max_tokens=2048, # Sufficient for cleaning most outputs
|
||||
)
|
||||
|
||||
# Parse cleaned output
|
||||
cleaned_text = response.content.strip()
|
||||
|
||||
# Remove markdown if present
|
||||
if cleaned_text.startswith("```"):
|
||||
match = re.search(
|
||||
r"```(?:json)?\s*\n?(.*?)\n?```", cleaned_text, re.DOTALL
|
||||
)
|
||||
if match:
|
||||
cleaned_text = match.group(1).strip()
|
||||
# Apply heuristic repair to the LLM's output too (just in case)
|
||||
cleaned = _heuristic_repair(cleaned_text)
|
||||
|
||||
cleaned = json.loads(cleaned_text)
|
||||
if not cleaned:
|
||||
# Fallback to standard load if heuristic returns None (unlikely for LLM output)
|
||||
cleaned = json.loads(cleaned_text)
|
||||
|
||||
if isinstance(cleaned, dict):
|
||||
self.cleansing_count += 1
|
||||
@@ -278,15 +312,11 @@ Return ONLY valid JSON matching the expected schema. No explanations, no markdow
|
||||
)
|
||||
return cleaned
|
||||
else:
|
||||
logger.warning(
|
||||
f"⚠ Cleaned output is not a dict: {type(cleaned)}"
|
||||
)
|
||||
logger.warning(f"⚠ Cleaned output is not a dict: {type(cleaned)}")
|
||||
if self.config.fallback_to_raw:
|
||||
return output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cleaning produced {type(cleaned)}, expected dict"
|
||||
)
|
||||
raise ValueError(f"Cleaning produced {type(cleaned)}, expected dict")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"✗ Failed to parse cleaned JSON: {e}")
|
||||
@@ -318,7 +348,7 @@ Return ONLY valid JSON matching the expected schema. No explanations, no markdow
|
||||
|
||||
line = f' "{key}": {type_hint}'
|
||||
if description:
|
||||
line += f' // {description}'
|
||||
line += f" // {description}"
|
||||
if required:
|
||||
line += " (required)"
|
||||
lines.append(line + ",")
|
||||
|
||||
@@ -10,24 +10,26 @@ The Plan is the contract between the external planner and the executor:
|
||||
- If replanning needed, returns feedback to external planner
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ActionType(str, Enum):
|
||||
"""Types of actions a PlanStep can perform."""
|
||||
LLM_CALL = "llm_call" # Call LLM for generation
|
||||
TOOL_USE = "tool_use" # Use a registered tool
|
||||
SUB_GRAPH = "sub_graph" # Execute a sub-graph
|
||||
FUNCTION = "function" # Call a Python function
|
||||
|
||||
LLM_CALL = "llm_call" # Call LLM for generation
|
||||
TOOL_USE = "tool_use" # Use a registered tool
|
||||
SUB_GRAPH = "sub_graph" # Execute a sub-graph
|
||||
FUNCTION = "function" # Call a Python function
|
||||
CODE_EXECUTION = "code_execution" # Execute dynamic code (sandboxed)
|
||||
|
||||
|
||||
class StepStatus(str, Enum):
|
||||
"""Status of a plan step."""
|
||||
|
||||
PENDING = "pending"
|
||||
AWAITING_APPROVAL = "awaiting_approval" # Waiting for human approval
|
||||
IN_PROGRESS = "in_progress"
|
||||
@@ -36,17 +38,36 @@ class StepStatus(str, Enum):
|
||||
SKIPPED = "skipped"
|
||||
REJECTED = "rejected" # Human rejected execution
|
||||
|
||||
def is_terminal(self) -> bool:
|
||||
"""Check if this status represents a terminal (finished) state.
|
||||
|
||||
Terminal states are states where the step will not execute further,
|
||||
either because it completed successfully or failed/was skipped.
|
||||
"""
|
||||
return self in (
|
||||
StepStatus.COMPLETED,
|
||||
StepStatus.FAILED,
|
||||
StepStatus.SKIPPED,
|
||||
StepStatus.REJECTED,
|
||||
)
|
||||
|
||||
def is_successful(self) -> bool:
|
||||
"""Check if this status represents successful completion."""
|
||||
return self == StepStatus.COMPLETED
|
||||
|
||||
|
||||
class ApprovalDecision(str, Enum):
|
||||
"""Human decision on a step requiring approval."""
|
||||
APPROVE = "approve" # Execute as planned
|
||||
REJECT = "reject" # Skip this step
|
||||
MODIFY = "modify" # Execute with modifications
|
||||
ABORT = "abort" # Stop entire execution
|
||||
|
||||
APPROVE = "approve" # Execute as planned
|
||||
REJECT = "reject" # Skip this step
|
||||
MODIFY = "modify" # Execute with modifications
|
||||
ABORT = "abort" # Stop entire execution
|
||||
|
||||
|
||||
class ApprovalRequest(BaseModel):
|
||||
"""Request for human approval before executing a step."""
|
||||
|
||||
step_id: str
|
||||
step_description: str
|
||||
action_type: str
|
||||
@@ -62,6 +83,7 @@ class ApprovalRequest(BaseModel):
|
||||
|
||||
class ApprovalResult(BaseModel):
|
||||
"""Result of human approval decision."""
|
||||
|
||||
decision: ApprovalDecision
|
||||
reason: str | None = None
|
||||
modifications: dict[str, Any] = Field(default_factory=dict)
|
||||
@@ -71,10 +93,11 @@ class ApprovalResult(BaseModel):
|
||||
|
||||
class JudgmentAction(str, Enum):
|
||||
"""Actions the judge can take after evaluating a step."""
|
||||
ACCEPT = "accept" # Step completed successfully, continue
|
||||
RETRY = "retry" # Retry the step with feedback
|
||||
REPLAN = "replan" # Return to external planner for new plan
|
||||
ESCALATE = "escalate" # Request human intervention
|
||||
|
||||
ACCEPT = "accept" # Step completed successfully, continue
|
||||
RETRY = "retry" # Retry the step with feedback
|
||||
REPLAN = "replan" # Return to external planner for new plan
|
||||
ESCALATE = "escalate" # Request human intervention
|
||||
|
||||
|
||||
class ActionSpec(BaseModel):
|
||||
@@ -83,6 +106,7 @@ class ActionSpec(BaseModel):
|
||||
|
||||
This is the "what to do" part of a PlanStep.
|
||||
"""
|
||||
|
||||
action_type: ActionType
|
||||
|
||||
# For LLM_CALL
|
||||
@@ -114,6 +138,7 @@ class PlanStep(BaseModel):
|
||||
|
||||
Created by external planner, executed by Worker, evaluated by Judge.
|
||||
"""
|
||||
|
||||
id: str
|
||||
description: str
|
||||
action: ActionSpec
|
||||
@@ -121,27 +146,23 @@ class PlanStep(BaseModel):
|
||||
# Data flow
|
||||
inputs: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Input data for this step (can reference previous step outputs)"
|
||||
description="Input data for this step (can reference previous step outputs)",
|
||||
)
|
||||
expected_outputs: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Keys this step should produce"
|
||||
default_factory=list, description="Keys this step should produce"
|
||||
)
|
||||
|
||||
# Dependencies
|
||||
dependencies: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="IDs of steps that must complete before this one"
|
||||
default_factory=list, description="IDs of steps that must complete before this one"
|
||||
)
|
||||
|
||||
# Human-in-the-loop (HITL)
|
||||
requires_approval: bool = Field(
|
||||
default=False,
|
||||
description="If True, requires human approval before execution"
|
||||
default=False, description="If True, requires human approval before execution"
|
||||
)
|
||||
approval_message: str | None = Field(
|
||||
default=None,
|
||||
description="Message to show human when requesting approval"
|
||||
default=None, description="Message to show human when requesting approval"
|
||||
)
|
||||
|
||||
# Execution state
|
||||
@@ -157,11 +178,23 @@ class PlanStep(BaseModel):
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
def is_ready(self, completed_step_ids: set[str]) -> bool:
|
||||
"""Check if this step is ready to execute (all dependencies met)."""
|
||||
def is_ready(self, terminal_step_ids: set[str]) -> bool:
|
||||
"""Check if this step is ready to execute (all dependencies finished).
|
||||
|
||||
A step is ready when:
|
||||
1. Its status is PENDING (not yet started)
|
||||
2. All its dependencies are in a terminal state (completed, failed, skipped, or rejected)
|
||||
|
||||
Note: This allows dependent steps to become "ready" even if their dependencies
|
||||
failed. The executor should check if any dependencies failed and handle
|
||||
accordingly (e.g., skip the step or mark it as blocked).
|
||||
|
||||
Args:
|
||||
terminal_step_ids: Set of step IDs that are in a terminal state
|
||||
"""
|
||||
if self.status != StepStatus.PENDING:
|
||||
return False
|
||||
return all(dep in completed_step_ids for dep in self.dependencies)
|
||||
return all(dep in terminal_step_ids for dep in self.dependencies)
|
||||
|
||||
|
||||
class Judgment(BaseModel):
|
||||
@@ -170,6 +203,7 @@ class Judgment(BaseModel):
|
||||
|
||||
The Judge evaluates step results and decides what to do next.
|
||||
"""
|
||||
|
||||
action: JudgmentAction
|
||||
reasoning: str
|
||||
feedback: str | None = None # For retry/replan - what went wrong
|
||||
@@ -193,6 +227,7 @@ class EvaluationRule(BaseModel):
|
||||
|
||||
Rules are checked before falling back to LLM evaluation.
|
||||
"""
|
||||
|
||||
id: str
|
||||
description: str
|
||||
|
||||
@@ -216,6 +251,7 @@ class Plan(BaseModel):
|
||||
Created by external planner (Claude Code, etc).
|
||||
Executed by FlexibleGraphExecutor.
|
||||
"""
|
||||
|
||||
id: str
|
||||
goal_id: str
|
||||
description: str
|
||||
@@ -320,18 +356,46 @@ class Plan(BaseModel):
|
||||
return None
|
||||
|
||||
def get_ready_steps(self) -> list[PlanStep]:
|
||||
"""Get all steps that are ready to execute."""
|
||||
completed_ids = {s.id for s in self.steps if s.status == StepStatus.COMPLETED}
|
||||
return [s for s in self.steps if s.is_ready(completed_ids)]
|
||||
"""Get all steps that are ready to execute.
|
||||
|
||||
A step is ready when all its dependencies are in terminal states
|
||||
(completed, failed, skipped, or rejected).
|
||||
"""
|
||||
terminal_ids = {s.id for s in self.steps if s.status.is_terminal()}
|
||||
return [s for s in self.steps if s.is_ready(terminal_ids)]
|
||||
|
||||
def get_completed_steps(self) -> list[PlanStep]:
|
||||
"""Get all completed steps."""
|
||||
return [s for s in self.steps if s.status == StepStatus.COMPLETED]
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Check if all steps are completed."""
|
||||
"""Check if all steps are in terminal states (finished executing).
|
||||
|
||||
Returns True when all steps have reached a terminal state, regardless
|
||||
of whether they succeeded or failed. Use has_failed_steps() to check
|
||||
if any steps failed.
|
||||
"""
|
||||
return all(s.status.is_terminal() for s in self.steps)
|
||||
|
||||
def is_successful(self) -> bool:
|
||||
"""Check if all steps completed successfully."""
|
||||
return all(s.status == StepStatus.COMPLETED for s in self.steps)
|
||||
|
||||
def has_failed_steps(self) -> bool:
|
||||
"""Check if any steps failed, were skipped, or were rejected."""
|
||||
return any(
|
||||
s.status in (StepStatus.FAILED, StepStatus.SKIPPED, StepStatus.REJECTED)
|
||||
for s in self.steps
|
||||
)
|
||||
|
||||
def get_failed_steps(self) -> list[PlanStep]:
|
||||
"""Get all steps that failed, were skipped, or were rejected."""
|
||||
return [
|
||||
s
|
||||
for s in self.steps
|
||||
if s.status in (StepStatus.FAILED, StepStatus.SKIPPED, StepStatus.REJECTED)
|
||||
]
|
||||
|
||||
def to_feedback_context(self) -> dict[str, Any]:
|
||||
"""Create context for replanning."""
|
||||
return {
|
||||
@@ -361,12 +425,13 @@ class Plan(BaseModel):
|
||||
|
||||
class ExecutionStatus(str, Enum):
|
||||
"""Status of plan execution."""
|
||||
|
||||
COMPLETED = "completed"
|
||||
AWAITING_APPROVAL = "awaiting_approval" # Paused for human approval
|
||||
NEEDS_REPLAN = "needs_replan"
|
||||
NEEDS_ESCALATION = "needs_escalation"
|
||||
REJECTED = "rejected" # Human rejected a step
|
||||
ABORTED = "aborted" # Human aborted execution
|
||||
ABORTED = "aborted" # Human aborted execution
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
@@ -376,6 +441,7 @@ class PlanExecutionResult(BaseModel):
|
||||
|
||||
Returned to external planner with status and feedback.
|
||||
"""
|
||||
|
||||
status: ExecutionStatus
|
||||
|
||||
# Results from completed steps
|
||||
@@ -421,6 +487,7 @@ def load_export(data: str | dict) -> tuple["Plan", Any]:
|
||||
result = await executor.execute_plan(plan, goal, context)
|
||||
"""
|
||||
import json as json_module
|
||||
|
||||
from framework.graph.goal import Goal
|
||||
|
||||
if isinstance(data, str):
|
||||
|
||||
@@ -0,0 +1,262 @@
|
||||
import ast
|
||||
import operator
|
||||
from typing import Any
|
||||
|
||||
# Safe operators whitelist
|
||||
SAFE_OPERATORS = {
|
||||
ast.Add: operator.add,
|
||||
ast.Sub: operator.sub,
|
||||
ast.Mult: operator.mul,
|
||||
ast.Div: operator.truediv,
|
||||
ast.FloorDiv: operator.floordiv,
|
||||
ast.Mod: operator.mod,
|
||||
ast.Pow: operator.pow,
|
||||
ast.LShift: operator.lshift,
|
||||
ast.RShift: operator.rshift,
|
||||
ast.BitOr: operator.or_,
|
||||
ast.BitXor: operator.xor,
|
||||
ast.BitAnd: operator.and_,
|
||||
ast.Eq: operator.eq,
|
||||
ast.NotEq: operator.ne,
|
||||
ast.Lt: operator.lt,
|
||||
ast.LtE: operator.le,
|
||||
ast.Gt: operator.gt,
|
||||
ast.GtE: operator.ge,
|
||||
ast.Is: operator.is_,
|
||||
ast.IsNot: operator.is_not,
|
||||
ast.In: lambda x, y: x in y,
|
||||
ast.NotIn: lambda x, y: x not in y,
|
||||
ast.USub: operator.neg,
|
||||
ast.UAdd: operator.pos,
|
||||
ast.Not: operator.not_,
|
||||
ast.Invert: operator.inv,
|
||||
}
|
||||
|
||||
# Safe functions whitelist
|
||||
SAFE_FUNCTIONS = {
|
||||
"len": len,
|
||||
"int": int,
|
||||
"float": float,
|
||||
"str": str,
|
||||
"bool": bool,
|
||||
"list": list,
|
||||
"dict": dict,
|
||||
"tuple": tuple,
|
||||
"set": set,
|
||||
"min": min,
|
||||
"max": max,
|
||||
"sum": sum,
|
||||
"abs": abs,
|
||||
"round": round,
|
||||
"all": all,
|
||||
"any": any,
|
||||
}
|
||||
|
||||
|
||||
class SafeEvalVisitor(ast.NodeVisitor):
|
||||
def __init__(self, context: dict[str, Any]):
|
||||
self.context = context
|
||||
|
||||
def visit(self, node: ast.AST) -> Any:
|
||||
# Override visit to prevent default behavior and ensure only explicitly allowed nodes work
|
||||
method = "visit_" + node.__class__.__name__
|
||||
visitor = getattr(self, method, self.generic_visit)
|
||||
return visitor(node)
|
||||
|
||||
def generic_visit(self, node: ast.AST):
|
||||
raise ValueError(f"Use of {node.__class__.__name__} is not allowed")
|
||||
|
||||
def visit_Expression(self, node: ast.Expression) -> Any:
|
||||
return self.visit(node.body)
|
||||
|
||||
def visit_Expr(self, node: ast.Expr) -> Any:
|
||||
return self.visit(node.value)
|
||||
|
||||
def visit_Constant(self, node: ast.Constant) -> Any:
|
||||
return node.value
|
||||
|
||||
# --- Number/String/Bytes/NameConstant (Python < 3.8 compat if needed) ---
|
||||
def visit_Num(self, node: ast.Num) -> Any:
|
||||
return node.n
|
||||
|
||||
def visit_Str(self, node: ast.Str) -> Any:
|
||||
return node.s
|
||||
|
||||
def visit_NameConstant(self, node: ast.NameConstant) -> Any:
|
||||
return node.value
|
||||
|
||||
# --- Data Structures ---
|
||||
def visit_List(self, node: ast.List) -> list:
|
||||
return [self.visit(elt) for elt in node.elts]
|
||||
|
||||
def visit_Tuple(self, node: ast.Tuple) -> tuple:
|
||||
return tuple(self.visit(elt) for elt in node.elts)
|
||||
|
||||
def visit_Dict(self, node: ast.Dict) -> dict:
|
||||
return {
|
||||
self.visit(k): self.visit(v)
|
||||
for k, v in zip(node.keys, node.values, strict=False)
|
||||
if k is not None
|
||||
}
|
||||
|
||||
# --- Operations ---
|
||||
def visit_BinOp(self, node: ast.BinOp) -> Any:
|
||||
op_func = SAFE_OPERATORS.get(type(node.op))
|
||||
if op_func is None:
|
||||
raise ValueError(f"Operator {type(node.op).__name__} is not allowed")
|
||||
return op_func(self.visit(node.left), self.visit(node.right))
|
||||
|
||||
def visit_UnaryOp(self, node: ast.UnaryOp) -> Any:
|
||||
op_func = SAFE_OPERATORS.get(type(node.op))
|
||||
if op_func is None:
|
||||
raise ValueError(f"Operator {type(node.op).__name__} is not allowed")
|
||||
return op_func(self.visit(node.operand))
|
||||
|
||||
def visit_Compare(self, node: ast.Compare) -> Any:
|
||||
left = self.visit(node.left)
|
||||
for op, comparator in zip(node.ops, node.comparators, strict=False):
|
||||
op_func = SAFE_OPERATORS.get(type(op))
|
||||
if op_func is None:
|
||||
raise ValueError(f"Operator {type(op).__name__} is not allowed")
|
||||
right = self.visit(comparator)
|
||||
if not op_func(left, right):
|
||||
return False
|
||||
left = right # Chain comparisons
|
||||
return True
|
||||
|
||||
def visit_BoolOp(self, node: ast.BoolOp) -> Any:
|
||||
values = [self.visit(v) for v in node.values]
|
||||
if isinstance(node.op, ast.And):
|
||||
return all(values)
|
||||
elif isinstance(node.op, ast.Or):
|
||||
return any(values)
|
||||
raise ValueError(f"Boolean operator {type(node.op).__name__} is not allowed")
|
||||
|
||||
def visit_IfExp(self, node: ast.IfExp) -> Any:
|
||||
# Ternary: true_val if test else false_val
|
||||
if self.visit(node.test):
|
||||
return self.visit(node.body)
|
||||
else:
|
||||
return self.visit(node.orelse)
|
||||
|
||||
# --- Variables and Attributes ---
|
||||
def visit_Name(self, node: ast.Name) -> Any:
|
||||
if isinstance(node.ctx, ast.Load):
|
||||
if node.id in self.context:
|
||||
return self.context[node.id]
|
||||
raise NameError(f"Name '{node.id}' is not defined")
|
||||
raise ValueError("Only reading variables is allowed")
|
||||
|
||||
def visit_Subscript(self, node: ast.Subscript) -> Any:
|
||||
# value[slice]
|
||||
val = self.visit(node.value)
|
||||
idx = self.visit(node.slice)
|
||||
return val[idx]
|
||||
|
||||
def visit_Attribute(self, node: ast.Attribute) -> Any:
|
||||
# value.attr
|
||||
# STIRCT CHECK: No access to private attributes (starting with _)
|
||||
if node.attr.startswith("_"):
|
||||
raise ValueError(f"Access to private attribute '{node.attr}' is not allowed")
|
||||
|
||||
val = self.visit(node.value)
|
||||
|
||||
# Safe attribute access: only allow if it's in the dict (if val is dict)
|
||||
# or it's a safe property of a basic type?
|
||||
# Actually, for flexibility, people often use dot access for dicts in these expressions.
|
||||
# But standard Python dict doesn't support dot access.
|
||||
# If val is a dict, Attribute access usually fails in Python unless wrapped.
|
||||
# If the user context provides objects, we might want to allow attribute access.
|
||||
# BUT we must be careful not to allow access to dangerous things like __class__ etc.
|
||||
# The check starts_with("_") covers __class__, __init__, etc.
|
||||
|
||||
try:
|
||||
return getattr(val, node.attr)
|
||||
except AttributeError:
|
||||
# Fallback: maybe it's a dict and they want dot access?
|
||||
# (Only if we want to support that sugar, usually not standard python)
|
||||
# Let's stick to standard python behavior + strict private check.
|
||||
pass
|
||||
|
||||
raise AttributeError(f"Object has no attribute '{node.attr}'")
|
||||
|
||||
def visit_Call(self, node: ast.Call) -> Any:
|
||||
# Only allow calling whitelisted functions
|
||||
func = self.visit(node.func)
|
||||
|
||||
# Check if the function object itself is in our whitelist values
|
||||
# This is tricky because `func` is the actual function object,
|
||||
# but we also want to verify it came from a safe place.
|
||||
# Easier: Check if node.func is a Name and that name is in SAFE_FUNCTIONS.
|
||||
|
||||
is_safe = False
|
||||
if isinstance(node.func, ast.Name):
|
||||
if node.func.id in SAFE_FUNCTIONS:
|
||||
is_safe = True
|
||||
|
||||
# Also allow methods on objects if they are safe?
|
||||
# E.g. "somestring".lower() or list.append() (if we allowed mutation, but we don't for now)
|
||||
# For now, restrict to SAFE_FUNCTIONS whitelist for global calls and deny method calls
|
||||
# unless we explicitly add safe methods.
|
||||
# Allowing method calls on strings/lists (split, join, get) is commonly needed.
|
||||
|
||||
if isinstance(node.func, ast.Attribute):
|
||||
# Method call.
|
||||
# Allow basic safe methods?
|
||||
# For security, start strict. Only helper functions.
|
||||
# Re-visiting: User might want 'output.get("key")'.
|
||||
method_name = node.func.attr
|
||||
if method_name in [
|
||||
"get",
|
||||
"keys",
|
||||
"values",
|
||||
"items",
|
||||
"lower",
|
||||
"upper",
|
||||
"strip",
|
||||
"split",
|
||||
]:
|
||||
is_safe = True
|
||||
|
||||
if not is_safe and func not in SAFE_FUNCTIONS.values():
|
||||
raise ValueError("Call to function/method is not allowed")
|
||||
|
||||
args = [self.visit(arg) for arg in node.args]
|
||||
keywords = {kw.arg: self.visit(kw.value) for kw in node.keywords}
|
||||
|
||||
return func(*args, **keywords)
|
||||
|
||||
def visit_Index(self, node: ast.Index) -> Any:
|
||||
# Python < 3.9
|
||||
return self.visit(node.value)
|
||||
|
||||
|
||||
def safe_eval(expr: str, context: dict[str, Any] | None = None) -> Any:
|
||||
"""
|
||||
Safely evaluate a python expression string.
|
||||
|
||||
Args:
|
||||
expr: The expression string to evaluate.
|
||||
context: Dictionary of variables available in the expression.
|
||||
|
||||
Returns:
|
||||
The result of the evaluation.
|
||||
|
||||
Raises:
|
||||
ValueError: If unsafe operations or syntax are detected.
|
||||
SyntaxError: If the expression is invalid Python.
|
||||
"""
|
||||
if context is None:
|
||||
context = {}
|
||||
|
||||
# Add safe builtins to context
|
||||
full_context = context.copy()
|
||||
full_context.update(SAFE_FUNCTIONS)
|
||||
|
||||
try:
|
||||
tree = ast.parse(expr, mode="eval")
|
||||
except SyntaxError as e:
|
||||
raise SyntaxError(f"Invalid syntax in expression: {e}") from e
|
||||
|
||||
visitor = SafeEvalVisitor(full_context)
|
||||
return visitor.visit(tree)
|
||||
@@ -6,8 +6,9 @@ Demonstrates how OutputCleaner fixes the JSON parsing trap using llama-3.3-70b.
|
||||
|
||||
import json
|
||||
import os
|
||||
from framework.graph.output_cleaner import OutputCleaner, CleansingConfig
|
||||
|
||||
from framework.graph.node import NodeSpec
|
||||
from framework.graph.output_cleaner import CleansingConfig, OutputCleaner
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
|
||||
|
||||
@@ -42,7 +43,10 @@ def test_cleaning_with_cerebras():
|
||||
# Scenario 1: JSON parsing trap (entire response in one key)
|
||||
print("\n--- Scenario 1: JSON Parsing Trap ---")
|
||||
malformed_output = {
|
||||
"recommendation": '{\n "approval_decision": "APPROVED",\n "risk_score": 3.5,\n "reason": "Standard terms, low risk"\n}',
|
||||
"recommendation": (
|
||||
'{\n "approval_decision": "APPROVED",\n "risk_score": 3.5,\n '
|
||||
'"reason": "Standard terms, low risk"\n}'
|
||||
),
|
||||
}
|
||||
|
||||
target_spec = NodeSpec(
|
||||
@@ -84,14 +88,17 @@ def test_cleaning_with_cerebras():
|
||||
print(json.dumps(cleaned, indent=2))
|
||||
|
||||
assert isinstance(cleaned, dict), "Should return dict"
|
||||
assert "approval_decision" in str(cleaned) or isinstance(
|
||||
cleaned.get("recommendation"), dict
|
||||
), "Should have recommendation structure"
|
||||
assert "approval_decision" in str(cleaned) or isinstance(cleaned.get("recommendation"), dict), (
|
||||
"Should have recommendation structure"
|
||||
)
|
||||
|
||||
# Scenario 2: Multiple keys with JSON string
|
||||
print("\n\n--- Scenario 2: Multiple Keys, JSON String ---")
|
||||
malformed_output2 = {
|
||||
"analysis": '{"high_risk_clauses": ["unlimited liability"], "compliance_issues": [], "category": "high-risk"}',
|
||||
"analysis": (
|
||||
'{"high_risk_clauses": ["unlimited liability"], '
|
||||
'"compliance_issues": [], "category": "high-risk"}'
|
||||
),
|
||||
"risk_score": "7.5", # String instead of number
|
||||
}
|
||||
|
||||
@@ -131,9 +138,7 @@ def test_cleaning_with_cerebras():
|
||||
|
||||
assert isinstance(cleaned2, dict), "Should return dict"
|
||||
assert isinstance(cleaned2.get("analysis"), dict), "analysis should be dict"
|
||||
assert isinstance(
|
||||
cleaned2.get("risk_score"), (int, float)
|
||||
), "risk_score should be number"
|
||||
assert isinstance(cleaned2.get("risk_score"), (int, float)), "risk_score should be number"
|
||||
|
||||
# Stats
|
||||
stats = cleaner.get_stats()
|
||||
|
||||
@@ -8,12 +8,15 @@ import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Result of validating an output."""
|
||||
|
||||
success: bool
|
||||
errors: list[str]
|
||||
|
||||
@@ -30,11 +33,76 @@ class OutputValidator:
|
||||
Used by the executor to catch bad outputs before they pollute memory.
|
||||
"""
|
||||
|
||||
def _contains_code_indicators(self, value: str) -> bool:
|
||||
"""
|
||||
Check for code patterns in a string using sampling for efficiency.
|
||||
|
||||
For strings under 10KB, checks the entire content.
|
||||
For longer strings, samples at strategic positions to balance
|
||||
performance with detection accuracy.
|
||||
|
||||
Args:
|
||||
value: The string to check for code indicators
|
||||
|
||||
Returns:
|
||||
True if code indicators are found, False otherwise
|
||||
"""
|
||||
code_indicators = [
|
||||
# Python
|
||||
"def ",
|
||||
"class ",
|
||||
"import ",
|
||||
"from ",
|
||||
"if __name__",
|
||||
"async def ",
|
||||
"await ",
|
||||
"try:",
|
||||
"except:",
|
||||
# JavaScript/TypeScript
|
||||
"function ",
|
||||
"const ",
|
||||
"let ",
|
||||
"=> {",
|
||||
"require(",
|
||||
"export ",
|
||||
# SQL
|
||||
"SELECT ",
|
||||
"INSERT ",
|
||||
"UPDATE ",
|
||||
"DELETE ",
|
||||
"DROP ",
|
||||
# HTML/Script injection
|
||||
"<script",
|
||||
"<?php",
|
||||
"<%",
|
||||
]
|
||||
|
||||
# For strings under 10KB, check the entire content
|
||||
if len(value) < 10000:
|
||||
return any(indicator in value for indicator in code_indicators)
|
||||
|
||||
# For longer strings, sample at strategic positions
|
||||
sample_positions = [
|
||||
0, # Start
|
||||
len(value) // 4, # 25%
|
||||
len(value) // 2, # 50%
|
||||
3 * len(value) // 4, # 75%
|
||||
max(0, len(value) - 2000), # Near end
|
||||
]
|
||||
|
||||
for pos in sample_positions:
|
||||
chunk = value[pos : pos + 2000]
|
||||
if any(indicator in chunk for indicator in code_indicators):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def validate_output_keys(
|
||||
self,
|
||||
output: dict[str, Any],
|
||||
expected_keys: list[str],
|
||||
allow_empty: bool = False,
|
||||
nullable_keys: list[str] | None = None,
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
Validate that all expected keys are present and non-empty.
|
||||
@@ -43,16 +111,17 @@ class OutputValidator:
|
||||
output: The output dict to validate
|
||||
expected_keys: Keys that must be present
|
||||
allow_empty: If True, allow empty string values
|
||||
nullable_keys: Keys that are allowed to be None
|
||||
|
||||
Returns:
|
||||
ValidationResult with success status and any errors
|
||||
"""
|
||||
errors = []
|
||||
nullable_keys = nullable_keys or []
|
||||
|
||||
if not isinstance(output, dict):
|
||||
return ValidationResult(
|
||||
success=False,
|
||||
errors=[f"Output is not a dict, got {type(output).__name__}"]
|
||||
success=False, errors=[f"Output is not a dict, got {type(output).__name__}"]
|
||||
)
|
||||
|
||||
for key in expected_keys:
|
||||
@@ -61,12 +130,78 @@ class OutputValidator:
|
||||
elif not allow_empty:
|
||||
value = output[key]
|
||||
if value is None:
|
||||
errors.append(f"Output key '{key}' is None")
|
||||
if key not in nullable_keys:
|
||||
errors.append(f"Output key '{key}' is None")
|
||||
elif isinstance(value, str) and len(value.strip()) == 0:
|
||||
errors.append(f"Output key '{key}' is empty string")
|
||||
|
||||
return ValidationResult(success=len(errors) == 0, errors=errors)
|
||||
|
||||
def validate_with_pydantic(
|
||||
self,
|
||||
output: dict[str, Any],
|
||||
model: type[BaseModel],
|
||||
) -> tuple[ValidationResult, BaseModel | None]:
|
||||
"""
|
||||
Validate output against a Pydantic model.
|
||||
|
||||
Args:
|
||||
output: The output dict to validate
|
||||
model: Pydantic model class to validate against
|
||||
|
||||
Returns:
|
||||
Tuple of (ValidationResult, validated_model_instance or None)
|
||||
"""
|
||||
try:
|
||||
validated = model.model_validate(output)
|
||||
return ValidationResult(success=True, errors=[]), validated
|
||||
except ValidationError as e:
|
||||
errors = []
|
||||
for error in e.errors():
|
||||
field_path = ".".join(str(loc) for loc in error["loc"])
|
||||
msg = error["msg"]
|
||||
error_type = error["type"]
|
||||
errors.append(f"{field_path}: {msg} (type: {error_type})")
|
||||
return ValidationResult(success=False, errors=errors), None
|
||||
|
||||
def format_validation_feedback(
|
||||
self,
|
||||
validation_result: ValidationResult,
|
||||
model: type[BaseModel],
|
||||
) -> str:
|
||||
"""
|
||||
Format validation errors as feedback for LLM retry.
|
||||
|
||||
Args:
|
||||
validation_result: The failed validation result
|
||||
model: The Pydantic model that was used for validation
|
||||
|
||||
Returns:
|
||||
Formatted feedback string to include in retry prompt
|
||||
"""
|
||||
# Get the model's JSON schema for reference
|
||||
schema = model.model_json_schema()
|
||||
|
||||
feedback = "Your previous response had validation errors:\n\n"
|
||||
feedback += "ERRORS:\n"
|
||||
for error in validation_result.errors:
|
||||
feedback += f" - {error}\n"
|
||||
|
||||
feedback += "\nEXPECTED SCHEMA:\n"
|
||||
feedback += f" Model: {model.__name__}\n"
|
||||
|
||||
if "properties" in schema:
|
||||
feedback += " Required fields:\n"
|
||||
required = schema.get("required", [])
|
||||
for prop_name, prop_info in schema["properties"].items():
|
||||
req_marker = " (required)" if prop_name in required else ""
|
||||
prop_type = prop_info.get("type", "any")
|
||||
feedback += f" - {prop_name}: {prop_type}{req_marker}\n"
|
||||
|
||||
feedback += "\nPlease fix the errors and respond with valid JSON matching the schema."
|
||||
|
||||
return feedback
|
||||
|
||||
def validate_no_hallucination(
|
||||
self,
|
||||
output: dict[str, Any],
|
||||
@@ -93,16 +228,10 @@ class OutputValidator:
|
||||
if not isinstance(value, str):
|
||||
continue
|
||||
|
||||
# Check for Python-like code
|
||||
code_indicators = [
|
||||
"def ", "class ", "import ", "from ", "if __name__",
|
||||
"async def ", "await ", "try:", "except:"
|
||||
]
|
||||
if any(indicator in value[:500] for indicator in code_indicators):
|
||||
# Check for code patterns in the entire string, not just first 500 chars
|
||||
if self._contains_code_indicators(value):
|
||||
# Could be legitimate, but warn
|
||||
logger.warning(
|
||||
f"Output key '{key}' may contain code - verify this is expected"
|
||||
)
|
||||
logger.warning(f"Output key '{key}' may contain code - verify this is expected")
|
||||
|
||||
# Check for overly long values
|
||||
if len(value) > max_length:
|
||||
@@ -148,6 +277,7 @@ class OutputValidator:
|
||||
expected_keys: list[str] | None = None,
|
||||
schema: dict[str, Any] | None = None,
|
||||
check_hallucination: bool = True,
|
||||
nullable_keys: list[str] | None = None,
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
Run all applicable validations on output.
|
||||
@@ -157,6 +287,7 @@ class OutputValidator:
|
||||
expected_keys: Optional list of required keys
|
||||
schema: Optional JSON schema
|
||||
check_hallucination: Whether to check for hallucination patterns
|
||||
nullable_keys: Keys that are allowed to be None
|
||||
|
||||
Returns:
|
||||
Combined ValidationResult
|
||||
@@ -165,7 +296,7 @@ class OutputValidator:
|
||||
|
||||
# Validate keys if provided
|
||||
if expected_keys:
|
||||
result = self.validate_output_keys(output, expected_keys)
|
||||
result = self.validate_output_keys(output, expected_keys, nullable_keys=nullable_keys)
|
||||
all_errors.extend(result.errors)
|
||||
|
||||
# Validate schema if provided
|
||||
|
||||
@@ -10,20 +10,24 @@ appropriate executor based on action type:
|
||||
- Code execution (sandboxed)
|
||||
"""
|
||||
|
||||
from typing import Any, Callable
|
||||
from dataclasses import dataclass, field
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from framework.graph.code_sandbox import CodeSandbox
|
||||
from framework.graph.plan import (
|
||||
PlanStep,
|
||||
ActionSpec,
|
||||
ActionType,
|
||||
PlanStep,
|
||||
)
|
||||
from framework.graph.code_sandbox import CodeSandbox
|
||||
from framework.runtime.core import Runtime
|
||||
from framework.llm.provider import LLMProvider, Tool
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_llm_json_response(text: str) -> tuple[Any | None, str]:
|
||||
@@ -50,7 +54,7 @@ def parse_llm_json_response(text: str) -> tuple[Any | None, str]:
|
||||
|
||||
# Try to extract JSON from markdown code blocks
|
||||
# Pattern: ```json ... ``` or ``` ... ```
|
||||
code_block_pattern = r'```(?:json)?\s*([\s\S]*?)\s*```'
|
||||
code_block_pattern = r"```(?:json)?\s*([\s\S]*?)\s*```"
|
||||
matches = re.findall(code_block_pattern, cleaned)
|
||||
|
||||
if matches:
|
||||
@@ -59,34 +63,46 @@ def parse_llm_json_response(text: str) -> tuple[Any | None, str]:
|
||||
try:
|
||||
parsed = json.loads(match.strip())
|
||||
return parsed, match.strip()
|
||||
except json.JSONDecodeError:
|
||||
except json.JSONDecodeError as e:
|
||||
logger.debug(
|
||||
f"Failed to parse JSON from code block: {e}. "
|
||||
f"Content preview: {match.strip()[:100]}..."
|
||||
)
|
||||
continue
|
||||
|
||||
# No code blocks or parsing failed - try parsing the whole response
|
||||
try:
|
||||
parsed = json.loads(cleaned)
|
||||
return parsed, cleaned
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
except json.JSONDecodeError as e:
|
||||
logger.debug(
|
||||
f"Failed to parse entire response as JSON: {e}. Content preview: {cleaned[:100]}..."
|
||||
)
|
||||
|
||||
# Try to find JSON-like content (starts with { or [)
|
||||
json_start_pattern = r'(\{[\s\S]*\}|\[[\s\S]*\])'
|
||||
json_start_pattern = r"(\{[\s\S]*\}|\[[\s\S]*\])"
|
||||
json_matches = re.findall(json_start_pattern, cleaned)
|
||||
|
||||
for match in json_matches:
|
||||
try:
|
||||
parsed = json.loads(match)
|
||||
return parsed, match
|
||||
except json.JSONDecodeError:
|
||||
except json.JSONDecodeError as e:
|
||||
logger.debug(f"Failed to parse JSON pattern: {e}. Content preview: {match[:100]}...")
|
||||
continue
|
||||
|
||||
# Could not parse as JSON
|
||||
# Could not parse as JSON - log warning
|
||||
logger.warning(
|
||||
f"Could not parse LLM response as JSON after trying all strategies. "
|
||||
f"Response preview: {cleaned[:200]}..."
|
||||
)
|
||||
return None, cleaned
|
||||
|
||||
|
||||
@dataclass
|
||||
class StepExecutionResult:
|
||||
"""Result of executing a plan step."""
|
||||
|
||||
success: bool
|
||||
outputs: dict[str, Any] = field(default_factory=dict)
|
||||
error: str | None = None
|
||||
@@ -160,11 +176,13 @@ class WorkerNode:
|
||||
# Record decision
|
||||
decision_id = self.runtime.decide(
|
||||
intent=f"Execute plan step: {step.description}",
|
||||
options=[{
|
||||
"id": step.action.action_type.value,
|
||||
"description": f"Execute {step.action.action_type.value} action",
|
||||
"action_type": step.action.action_type.value,
|
||||
}],
|
||||
options=[
|
||||
{
|
||||
"id": step.action.action_type.value,
|
||||
"description": f"Execute {step.action.action_type.value} action",
|
||||
"action_type": step.action.action_type.value,
|
||||
}
|
||||
],
|
||||
chosen=step.action.action_type.value,
|
||||
reasoning=f"Step requires {step.action.action_type.value}",
|
||||
context={"step_id": step.id, "inputs": step.inputs},
|
||||
@@ -288,7 +306,7 @@ class WorkerNode:
|
||||
if inputs:
|
||||
context_section = "\n\n--- Context Data ---\n"
|
||||
for key, value in inputs.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
if isinstance(value, dict | list):
|
||||
context_section += f"{key}: {json.dumps(value, indent=2)}\n"
|
||||
else:
|
||||
context_section += f"{key}: {value}\n"
|
||||
@@ -414,6 +432,7 @@ class WorkerNode:
|
||||
try:
|
||||
# Execute tool via formal executor
|
||||
from framework.llm.provider import ToolUse
|
||||
|
||||
tool_use = ToolUse(
|
||||
id=f"step_{tool_name}",
|
||||
name=tool_name,
|
||||
|
||||
@@ -1,7 +1,26 @@
|
||||
"""LLM provider abstraction."""
|
||||
|
||||
from framework.llm.provider import LLMProvider, LLMResponse
|
||||
from framework.llm.anthropic import AnthropicProvider
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
|
||||
__all__ = ["LLMProvider", "LLMResponse", "AnthropicProvider", "LiteLLMProvider"]
|
||||
__all__ = ["LLMProvider", "LLMResponse"]
|
||||
|
||||
try:
|
||||
from framework.llm.anthropic import AnthropicProvider # noqa: F401
|
||||
|
||||
__all__.append("AnthropicProvider")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from framework.llm.litellm import LiteLLMProvider # noqa: F401
|
||||
|
||||
__all__.append("LiteLLMProvider")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from framework.llm.mock import MockLLMProvider # noqa: F401
|
||||
|
||||
__all__.append("MockLLMProvider")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""Anthropic Claude LLM provider - backward compatible wrapper around LiteLLM."""
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
|
||||
|
||||
def _get_api_key_from_credential_manager() -> str | None:
|
||||
@@ -55,7 +56,7 @@ class AnthropicProvider(LLMProvider):
|
||||
)
|
||||
|
||||
self.model = model
|
||||
|
||||
|
||||
self._provider = LiteLLMProvider(
|
||||
model=model,
|
||||
api_key=self.api_key,
|
||||
@@ -85,7 +86,7 @@ class AnthropicProvider(LLMProvider):
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
tool_executor: callable,
|
||||
tool_executor: Callable[[ToolUse], ToolResult],
|
||||
max_iterations: int = 10,
|
||||
) -> LLMResponse:
|
||||
"""Run a tool-use loop until Claude produces a final response (via LiteLLM)."""
|
||||
|
||||
@@ -8,11 +8,15 @@ See: https://docs.litellm.ai/docs/providers
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
try:
|
||||
import litellm
|
||||
except ImportError:
|
||||
litellm = None # type: ignore[assignment]
|
||||
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolUse
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
|
||||
|
||||
class LiteLLMProvider(LLMProvider):
|
||||
@@ -23,6 +27,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
- OpenAI: gpt-4o, gpt-4o-mini, gpt-4-turbo, gpt-3.5-turbo
|
||||
- Anthropic: claude-3-opus, claude-3-sonnet, claude-3-haiku
|
||||
- Google: gemini-pro, gemini-1.5-pro, gemini-1.5-flash
|
||||
- DeepSeek: deepseek-chat, deepseek-coder, deepseek-reasoner
|
||||
- Mistral: mistral-large, mistral-medium, mistral-small
|
||||
- Groq: llama3-70b, mixtral-8x7b
|
||||
- Local: ollama/llama3, ollama/mistral
|
||||
@@ -38,6 +43,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
# Google Gemini
|
||||
provider = LiteLLMProvider(model="gemini/gemini-1.5-flash")
|
||||
|
||||
# DeepSeek
|
||||
provider = LiteLLMProvider(model="deepseek/deepseek-chat")
|
||||
|
||||
# Local Ollama
|
||||
provider = LiteLLMProvider(model="ollama/llama3")
|
||||
|
||||
@@ -72,6 +80,11 @@ class LiteLLMProvider(LLMProvider):
|
||||
self.api_base = api_base
|
||||
self.extra_kwargs = kwargs
|
||||
|
||||
if litellm is None:
|
||||
raise ImportError(
|
||||
"LiteLLM is not installed. Please install it with: pip install litellm"
|
||||
)
|
||||
|
||||
def complete(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
@@ -90,9 +103,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
|
||||
# Add JSON mode via prompt engineering (works across all providers)
|
||||
if json_mode:
|
||||
json_instruction = (
|
||||
"\n\nPlease respond with a valid JSON object."
|
||||
)
|
||||
json_instruction = "\n\nPlease respond with a valid JSON object."
|
||||
# Append to system message if present, otherwise add as system message
|
||||
if full_messages and full_messages[0]["role"] == "system":
|
||||
full_messages[0]["content"] += json_instruction
|
||||
@@ -122,7 +133,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
kwargs["response_format"] = response_format
|
||||
|
||||
# Make the call
|
||||
response = litellm.completion(**kwargs)
|
||||
response = litellm.completion(**kwargs) # type: ignore[union-attr]
|
||||
|
||||
# Extract content
|
||||
content = response.choices[0].message.content or ""
|
||||
@@ -146,8 +157,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
tool_executor: callable,
|
||||
tool_executor: Callable[[ToolUse], ToolResult],
|
||||
max_iterations: int = 10,
|
||||
max_tokens: int = 4096,
|
||||
) -> LLMResponse:
|
||||
"""Run a tool-use loop until the LLM produces a final response."""
|
||||
# Prepare messages with system prompt
|
||||
@@ -167,7 +179,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": current_messages,
|
||||
"max_tokens": 1024,
|
||||
"max_tokens": max_tokens,
|
||||
"tools": openai_tools,
|
||||
**self.extra_kwargs,
|
||||
}
|
||||
@@ -177,7 +189,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
|
||||
response = litellm.completion(**kwargs)
|
||||
response = litellm.completion(**kwargs) # type: ignore[union-attr]
|
||||
|
||||
# Track tokens
|
||||
usage = response.usage
|
||||
@@ -201,21 +213,23 @@ class LiteLLMProvider(LLMProvider):
|
||||
|
||||
# Process tool calls.
|
||||
# Add assistant message with tool calls.
|
||||
current_messages.append({
|
||||
"role": "assistant",
|
||||
"content": message.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments,
|
||||
},
|
||||
}
|
||||
for tc in message.tool_calls
|
||||
],
|
||||
})
|
||||
current_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": message.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments,
|
||||
},
|
||||
}
|
||||
for tc in message.tool_calls
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
# Execute tools and add results.
|
||||
for tool_call in message.tool_calls:
|
||||
@@ -234,11 +248,13 @@ class LiteLLMProvider(LLMProvider):
|
||||
result = tool_executor(tool_use)
|
||||
|
||||
# Add tool result message
|
||||
current_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": result.tool_use_id,
|
||||
"content": result.content,
|
||||
})
|
||||
current_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": result.tool_use_id,
|
||||
"content": result.content,
|
||||
}
|
||||
)
|
||||
|
||||
# Max iterations reached
|
||||
return LLMResponse(
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
"""Mock LLM Provider for testing and structural validation without real LLM calls."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
|
||||
|
||||
class MockLLMProvider(LLMProvider):
|
||||
"""
|
||||
Mock LLM provider for testing agents without making real API calls.
|
||||
|
||||
This provider generates placeholder responses based on the expected output structure,
|
||||
allowing structural validation and graph execution testing without incurring costs
|
||||
or requiring API keys.
|
||||
|
||||
Example:
|
||||
llm = MockLLMProvider()
|
||||
response = llm.complete(
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
system="Generate JSON with keys: name, age",
|
||||
json_mode=True
|
||||
)
|
||||
# Returns: {"name": "mock_value", "age": "mock_value"}
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = "mock-model"):
|
||||
"""
|
||||
Initialize the mock LLM provider.
|
||||
|
||||
Args:
|
||||
model: Model name to report in responses (default: "mock-model")
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
def _extract_output_keys(self, system: str) -> list[str]:
|
||||
"""
|
||||
Extract expected output keys from the system prompt.
|
||||
|
||||
Looks for patterns like:
|
||||
- "output_keys: [key1, key2]"
|
||||
- "keys: key1, key2"
|
||||
- "Generate JSON with keys: key1, key2"
|
||||
|
||||
Args:
|
||||
system: System prompt text
|
||||
|
||||
Returns:
|
||||
List of extracted key names
|
||||
"""
|
||||
keys = []
|
||||
|
||||
# Pattern 1: output_keys: [key1, key2]
|
||||
match = re.search(r"output_keys:\s*\[(.*?)\]", system, re.IGNORECASE)
|
||||
if match:
|
||||
keys_str = match.group(1)
|
||||
keys = [k.strip().strip("\"'") for k in keys_str.split(",")]
|
||||
return keys
|
||||
|
||||
# Pattern 2: "keys: key1, key2" or "Generate JSON with keys: key1, key2"
|
||||
match = re.search(r"(?:keys|with keys):\s*([a-zA-Z0-9_,\s]+)", system, re.IGNORECASE)
|
||||
if match:
|
||||
keys_str = match.group(1)
|
||||
keys = [k.strip() for k in keys_str.split(",") if k.strip()]
|
||||
return keys
|
||||
|
||||
# Pattern 3: Look for JSON schema in system prompt
|
||||
match = re.search(r'\{[^}]*"([a-zA-Z0-9_]+)":\s*', system)
|
||||
if match:
|
||||
# Found at least one key in a JSON-like structure
|
||||
all_matches = re.findall(r'"([a-zA-Z0-9_]+)":\s*', system)
|
||||
if all_matches:
|
||||
return list(set(all_matches))
|
||||
|
||||
return keys
|
||||
|
||||
def _generate_mock_response(
|
||||
self,
|
||||
system: str = "",
|
||||
json_mode: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a mock response based on the system prompt and mode.
|
||||
|
||||
Args:
|
||||
system: System prompt (may contain output key hints)
|
||||
json_mode: If True, generate JSON response
|
||||
|
||||
Returns:
|
||||
Mock response string
|
||||
"""
|
||||
if json_mode:
|
||||
# Try to extract expected keys from system prompt
|
||||
keys = self._extract_output_keys(system)
|
||||
|
||||
if keys:
|
||||
# Generate JSON with the expected keys
|
||||
mock_data = {key: f"mock_{key}_value" for key in keys}
|
||||
return json.dumps(mock_data, indent=2)
|
||||
else:
|
||||
# Fallback: generic mock response
|
||||
return json.dumps({"result": "mock_result_value"}, indent=2)
|
||||
else:
|
||||
# Plain text mock response
|
||||
return "This is a mock response for testing purposes."
|
||||
|
||||
def complete(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Generate a mock completion without calling a real LLM.
|
||||
|
||||
Args:
|
||||
messages: Conversation history (ignored in mock mode)
|
||||
system: System prompt (used to extract expected output keys)
|
||||
tools: Available tools (ignored in mock mode)
|
||||
max_tokens: Maximum tokens (ignored in mock mode)
|
||||
response_format: Response format (ignored in mock mode)
|
||||
json_mode: If True, generate JSON response
|
||||
|
||||
Returns:
|
||||
LLMResponse with mock content
|
||||
"""
|
||||
content = self._generate_mock_response(system=system, json_mode=json_mode)
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=self.model,
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
stop_reason="mock_complete",
|
||||
)
|
||||
|
||||
def complete_with_tools(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
tool_executor: Callable[[ToolUse], ToolResult],
|
||||
max_iterations: int = 10,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Generate a mock completion without tool use.
|
||||
|
||||
In mock mode, we skip tool execution and return a final response immediately.
|
||||
|
||||
Args:
|
||||
messages: Initial conversation (ignored in mock mode)
|
||||
system: System prompt (used to extract expected output keys)
|
||||
tools: Available tools (ignored in mock mode)
|
||||
tool_executor: Tool executor function (ignored in mock mode)
|
||||
max_iterations: Max iterations (ignored in mock mode)
|
||||
|
||||
Returns:
|
||||
LLMResponse with mock content
|
||||
"""
|
||||
# In mock mode, we don't execute tools - just return a final response
|
||||
# Try to generate JSON if the system prompt suggests structured output
|
||||
json_mode = "json" in system.lower() or "output_keys" in system.lower()
|
||||
|
||||
content = self._generate_mock_response(system=system, json_mode=json_mode)
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=self.model,
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
stop_reason="mock_complete",
|
||||
)
|
||||
@@ -1,6 +1,7 @@
|
||||
"""LLM Provider abstraction for pluggable LLM backends."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
@@ -8,6 +9,7 @@ from typing import Any
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Response from an LLM call."""
|
||||
|
||||
content: str
|
||||
model: str
|
||||
input_tokens: int = 0
|
||||
@@ -19,6 +21,7 @@ class LLMResponse:
|
||||
@dataclass
|
||||
class Tool:
|
||||
"""A tool the LLM can use."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
parameters: dict[str, Any] = field(default_factory=dict)
|
||||
@@ -27,6 +30,7 @@ class Tool:
|
||||
@dataclass
|
||||
class ToolUse:
|
||||
"""A tool call requested by the LLM."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
input: dict[str, Any]
|
||||
@@ -35,6 +39,7 @@ class ToolUse:
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""Result of executing a tool."""
|
||||
|
||||
tool_use_id: str
|
||||
content: str
|
||||
is_error: bool = False
|
||||
@@ -86,7 +91,7 @@ class LLMProvider(ABC):
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
tool_executor: callable,
|
||||
tool_executor: Callable[["ToolUse"], "ToolResult"],
|
||||
max_iterations: int = 10,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,15 +1,15 @@
|
||||
"""Agent Runner - load and run exported agents."""
|
||||
|
||||
from framework.runner.runner import AgentRunner, AgentInfo, ValidationResult
|
||||
from framework.runner.tool_registry import ToolRegistry, tool
|
||||
from framework.runner.orchestrator import AgentOrchestrator
|
||||
from framework.runner.protocol import (
|
||||
AgentMessage,
|
||||
MessageType,
|
||||
CapabilityLevel,
|
||||
CapabilityResponse,
|
||||
MessageType,
|
||||
OrchestratorResult,
|
||||
)
|
||||
from framework.runner.runner import AgentInfo, AgentRunner, ValidationResult
|
||||
from framework.runner.tool_registry import ToolRegistry, tool
|
||||
|
||||
__all__ = [
|
||||
# Single agent
|
||||
|
||||
+103
-59
@@ -22,12 +22,14 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
|
||||
help="Path to agent folder (containing agent.json)",
|
||||
)
|
||||
run_parser.add_argument(
|
||||
"--input", "-i",
|
||||
"--input",
|
||||
"-i",
|
||||
type=str,
|
||||
help="Input context as JSON string",
|
||||
)
|
||||
run_parser.add_argument(
|
||||
"--input-file", "-f",
|
||||
"--input-file",
|
||||
"-f",
|
||||
type=str,
|
||||
help="Input context from JSON file",
|
||||
)
|
||||
@@ -37,17 +39,20 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
|
||||
help="Run in mock mode (no real LLM calls)",
|
||||
)
|
||||
run_parser.add_argument(
|
||||
"--output", "-o",
|
||||
"--output",
|
||||
"-o",
|
||||
type=str,
|
||||
help="Write results to file instead of stdout",
|
||||
)
|
||||
run_parser.add_argument(
|
||||
"--quiet", "-q",
|
||||
"--quiet",
|
||||
"-q",
|
||||
action="store_true",
|
||||
help="Only output the final result JSON",
|
||||
)
|
||||
run_parser.add_argument(
|
||||
"--verbose", "-v",
|
||||
"--verbose",
|
||||
"-v",
|
||||
action="store_true",
|
||||
help="Show detailed execution logs (steps, LLM calls, etc.)",
|
||||
)
|
||||
@@ -113,7 +118,8 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
|
||||
help="Directory containing agent folders (default: exports)",
|
||||
)
|
||||
dispatch_parser.add_argument(
|
||||
"--input", "-i",
|
||||
"--input",
|
||||
"-i",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Input context as JSON string",
|
||||
@@ -124,13 +130,15 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
|
||||
help="Description of what you want to accomplish",
|
||||
)
|
||||
dispatch_parser.add_argument(
|
||||
"--agents", "-a",
|
||||
"--agents",
|
||||
"-a",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Specific agent names to use (default: all in directory)",
|
||||
)
|
||||
dispatch_parser.add_argument(
|
||||
"--quiet", "-q",
|
||||
"--quiet",
|
||||
"-q",
|
||||
action="store_true",
|
||||
help="Only output the final result JSON",
|
||||
)
|
||||
@@ -170,15 +178,16 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
|
||||
def cmd_run(args: argparse.Namespace) -> int:
|
||||
"""Run an exported agent."""
|
||||
import logging
|
||||
|
||||
from framework.runner import AgentRunner
|
||||
|
||||
# Set logging level (quiet by default for cleaner output)
|
||||
if args.quiet:
|
||||
logging.basicConfig(level=logging.ERROR, format='%(message)s')
|
||||
elif getattr(args, 'verbose', False):
|
||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||
logging.basicConfig(level=logging.ERROR, format="%(message)s")
|
||||
elif getattr(args, "verbose", False):
|
||||
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
else:
|
||||
logging.basicConfig(level=logging.WARNING, format='%(message)s')
|
||||
logging.basicConfig(level=logging.WARNING, format="%(message)s")
|
||||
|
||||
# Load input context
|
||||
context = {}
|
||||
@@ -211,6 +220,7 @@ def cmd_run(args: argparse.Namespace) -> int:
|
||||
entry_input_keys = runner.graph.nodes[0].input_keys if runner.graph.nodes else []
|
||||
if "user_id" in entry_input_keys and context.get("user_id") is None:
|
||||
import os
|
||||
|
||||
context["user_id"] = os.environ.get("USER", "default_user")
|
||||
|
||||
if not args.quiet:
|
||||
@@ -279,7 +289,13 @@ def cmd_run(args: argparse.Namespace) -> int:
|
||||
# If no meaningful key found, show all non-internal keys
|
||||
if not shown:
|
||||
for key, value in result.output.items():
|
||||
if not key.startswith("_") and key not in ["user_id", "request", "memory_loaded", "user_profile", "recent_context"]:
|
||||
if not key.startswith("_") and key not in [
|
||||
"user_id",
|
||||
"request",
|
||||
"memory_loaded",
|
||||
"user_profile",
|
||||
"recent_context",
|
||||
]:
|
||||
if isinstance(value, (dict, list)):
|
||||
print(f"\n{key}:")
|
||||
value_str = json.dumps(value, indent=2, default=str)
|
||||
@@ -311,19 +327,24 @@ def cmd_info(args: argparse.Namespace) -> int:
|
||||
info = runner.info()
|
||||
|
||||
if args.json:
|
||||
print(json.dumps({
|
||||
"name": info.name,
|
||||
"description": info.description,
|
||||
"goal_name": info.goal_name,
|
||||
"goal_description": info.goal_description,
|
||||
"node_count": info.node_count,
|
||||
"nodes": info.nodes,
|
||||
"edges": info.edges,
|
||||
"success_criteria": info.success_criteria,
|
||||
"constraints": info.constraints,
|
||||
"required_tools": info.required_tools,
|
||||
"has_tools_module": info.has_tools_module,
|
||||
}, indent=2))
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"name": info.name,
|
||||
"description": info.description,
|
||||
"goal_name": info.goal_name,
|
||||
"goal_description": info.goal_description,
|
||||
"node_count": info.node_count,
|
||||
"nodes": info.nodes,
|
||||
"edges": info.edges,
|
||||
"success_criteria": info.success_criteria,
|
||||
"constraints": info.constraints,
|
||||
"required_tools": info.required_tools,
|
||||
"has_tools_module": info.has_tools_module,
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(f"Agent: {info.name}")
|
||||
print(f"Description: {info.description}")
|
||||
@@ -333,8 +354,8 @@ def cmd_info(args: argparse.Namespace) -> int:
|
||||
print()
|
||||
print(f"Nodes ({info.node_count}):")
|
||||
for node in info.nodes:
|
||||
inputs = f" [in: {', '.join(node['input_keys'])}]" if node.get('input_keys') else ""
|
||||
outputs = f" [out: {', '.join(node['output_keys'])}]" if node.get('output_keys') else ""
|
||||
inputs = f" [in: {', '.join(node['input_keys'])}]" if node.get("input_keys") else ""
|
||||
outputs = f" [out: {', '.join(node['output_keys'])}]" if node.get("output_keys") else ""
|
||||
print(f" - {node['id']}: {node['name']}{inputs}{outputs}")
|
||||
print()
|
||||
print(f"Success Criteria ({len(info.success_criteria)}):")
|
||||
@@ -396,8 +417,9 @@ def cmd_list(args: argparse.Namespace) -> int:
|
||||
|
||||
directory = Path(args.directory)
|
||||
if not directory.exists():
|
||||
print(f"Directory not found: {directory}", file=sys.stderr)
|
||||
return 1
|
||||
# FIX: Handle missing directory gracefully on fresh install
|
||||
print(f"No agents found in {directory}")
|
||||
return 0
|
||||
|
||||
agents = []
|
||||
for path in directory.iterdir():
|
||||
@@ -405,19 +427,25 @@ def cmd_list(args: argparse.Namespace) -> int:
|
||||
try:
|
||||
runner = AgentRunner.load(path)
|
||||
info = runner.info()
|
||||
agents.append({
|
||||
"path": str(path),
|
||||
"name": info.name,
|
||||
"description": info.description[:60] + "..." if len(info.description) > 60 else info.description,
|
||||
"nodes": info.node_count,
|
||||
"tools": len(info.required_tools),
|
||||
})
|
||||
agents.append(
|
||||
{
|
||||
"path": str(path),
|
||||
"name": info.name,
|
||||
"description": info.description[:60] + "..."
|
||||
if len(info.description) > 60
|
||||
else info.description,
|
||||
"nodes": info.node_count,
|
||||
"tools": len(info.required_tools),
|
||||
}
|
||||
)
|
||||
runner.cleanup()
|
||||
except Exception as e:
|
||||
agents.append({
|
||||
"path": str(path),
|
||||
"error": str(e),
|
||||
})
|
||||
agents.append(
|
||||
{
|
||||
"path": str(path),
|
||||
"error": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
if not agents:
|
||||
print(f"No agents found in {directory}")
|
||||
@@ -540,7 +568,7 @@ def cmd_dispatch(args: argparse.Namespace) -> int:
|
||||
|
||||
def _interactive_approval(request):
|
||||
"""Interactive approval callback for HITL mode."""
|
||||
from framework.graph import ApprovalResult, ApprovalDecision
|
||||
from framework.graph import ApprovalDecision, ApprovalResult
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
@@ -561,6 +589,7 @@ def _interactive_approval(request):
|
||||
print(f"\n[{key}]:")
|
||||
if isinstance(value, (dict, list)):
|
||||
import json
|
||||
|
||||
value_str = json.dumps(value, indent=2, default=str)
|
||||
# Show more content for approval - up to 2000 chars
|
||||
if len(value_str) > 2000:
|
||||
@@ -605,11 +634,14 @@ def _interactive_approval(request):
|
||||
print("Invalid choice. Please enter a, r, s, or x.")
|
||||
|
||||
|
||||
def _format_natural_language_to_json(user_input: str, input_keys: list[str], agent_description: str, session_context: dict = None) -> dict:
|
||||
def _format_natural_language_to_json(
|
||||
user_input: str, input_keys: list[str], agent_description: str, session_context: dict = None
|
||||
) -> dict:
|
||||
"""Use Haiku to convert natural language input to JSON based on agent's input schema."""
|
||||
import anthropic
|
||||
import os
|
||||
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
|
||||
|
||||
# Build prompt for Haiku
|
||||
@@ -619,17 +651,22 @@ def _format_natural_language_to_json(user_input: str, input_keys: list[str], age
|
||||
main_field = input_keys[0] if input_keys else "objective"
|
||||
existing_value = session_context.get(main_field, "")
|
||||
|
||||
session_info = f"\n\nExisting {main_field}: \"{existing_value}\"\n\nThe user is providing ADDITIONAL information. Append this new information to the existing {main_field} to create an enriched, more detailed version."
|
||||
session_info = (
|
||||
f'\n\nExisting {main_field}: "{existing_value}"\n\n'
|
||||
f"The user is providing ADDITIONAL information. Append this new "
|
||||
f"information to the existing {main_field} to create an enriched, "
|
||||
"more detailed version."
|
||||
)
|
||||
|
||||
prompt = f"""You are formatting user input for an agent that requires specific input fields.
|
||||
|
||||
Agent: {agent_description}
|
||||
|
||||
Required input fields: {', '.join(input_keys)}{session_info}
|
||||
Required input fields: {", ".join(input_keys)}{session_info}
|
||||
|
||||
User input: {user_input}
|
||||
|
||||
{"If this is a follow-up message, APPEND the new information to the existing field value to create a more complete, detailed version. Do not create new fields." if session_context else ""}
|
||||
{"If this is a follow-up, APPEND new info to the existing field value." if session_context else ""}
|
||||
|
||||
Output ONLY valid JSON, no explanation:"""
|
||||
|
||||
@@ -637,7 +674,7 @@ Output ONLY valid JSON, no explanation:"""
|
||||
message = client.messages.create(
|
||||
model="claude-3-5-haiku-20241022", # Fast and cheap
|
||||
max_tokens=500,
|
||||
messages=[{"role": "user", "content": prompt}]
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
)
|
||||
|
||||
json_str = message.content[0].text.strip()
|
||||
@@ -661,12 +698,13 @@ Output ONLY valid JSON, no explanation:"""
|
||||
def cmd_shell(args: argparse.Namespace) -> int:
|
||||
"""Start an interactive agent session."""
|
||||
import logging
|
||||
|
||||
from framework.runner import AgentRunner
|
||||
|
||||
# Configure logging to show runtime visibility
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(message)s', # Simple format for clean output
|
||||
format="%(message)s", # Simple format for clean output
|
||||
)
|
||||
|
||||
agents_dir = Path(args.agents_dir)
|
||||
@@ -690,7 +728,7 @@ def cmd_shell(args: argparse.Namespace) -> int:
|
||||
return 1
|
||||
|
||||
# Set up approval callback by default (unless --no-approve is set)
|
||||
if not getattr(args, 'no_approve', False):
|
||||
if not getattr(args, "no_approve", False):
|
||||
runner.set_approval_callback(_interactive_approval)
|
||||
print("\n🔔 Human-in-the-loop mode enabled")
|
||||
print(" Steps marked for approval will pause for your review")
|
||||
@@ -748,8 +786,10 @@ def cmd_shell(args: argparse.Namespace) -> int:
|
||||
if user_input == "/nodes":
|
||||
print("\nAgent nodes:")
|
||||
for node in info.nodes:
|
||||
inputs = f" [in: {', '.join(node['input_keys'])}]" if node.get('input_keys') else ""
|
||||
outputs = f" [out: {', '.join(node['output_keys'])}]" if node.get('output_keys') else ""
|
||||
inputs = f" [in: {', '.join(node['input_keys'])}]" if node.get("input_keys") else ""
|
||||
outputs = (
|
||||
f" [out: {', '.join(node['output_keys'])}]" if node.get("output_keys") else ""
|
||||
)
|
||||
print(f" {node['id']}: {node['name']}{inputs}{outputs}")
|
||||
print(f" {node['description']}")
|
||||
print()
|
||||
@@ -784,7 +824,7 @@ def cmd_shell(args: argparse.Namespace) -> int:
|
||||
user_input,
|
||||
entry_input_keys,
|
||||
info.description,
|
||||
session_context=session_memory
|
||||
session_context=session_memory,
|
||||
)
|
||||
print(f"✓ Formatted to: {json.dumps(context)}")
|
||||
except Exception as e:
|
||||
@@ -807,6 +847,7 @@ def cmd_shell(args: argparse.Namespace) -> int:
|
||||
# Auto-inject user_id if missing (for personal assistant agents)
|
||||
if "user_id" in entry_input_keys and run_context.get("user_id") is None:
|
||||
import os
|
||||
|
||||
run_context["user_id"] = os.environ.get("USER", "default_user")
|
||||
|
||||
# Add conversation history to context if agent expects it
|
||||
@@ -872,12 +913,14 @@ def cmd_shell(args: argparse.Namespace) -> int:
|
||||
session_memory[key] = value
|
||||
|
||||
# Track conversation history
|
||||
conversation_history.append({
|
||||
"input": context,
|
||||
"output": result.output if result.output else {},
|
||||
"status": "success" if result.success else "failed",
|
||||
"paused_at": result.paused_at
|
||||
})
|
||||
conversation_history.append(
|
||||
{
|
||||
"input": context,
|
||||
"output": result.output if result.output else {},
|
||||
"status": "success" if result.success else "failed",
|
||||
"paused_at": result.paused_at,
|
||||
}
|
||||
)
|
||||
|
||||
print()
|
||||
|
||||
@@ -904,6 +947,7 @@ def _select_agent(agents_dir: Path) -> str | None:
|
||||
for i, agent_path in enumerate(agents, 1):
|
||||
try:
|
||||
from framework.runner import AgentRunner
|
||||
|
||||
runner = AgentRunner.load(agent_path)
|
||||
info = runner.info()
|
||||
desc = info.description[:50] + "..." if len(info.description) > 50 else info.description
|
||||
|
||||
@@ -146,6 +146,7 @@ class MCPClient:
|
||||
|
||||
try:
|
||||
import threading
|
||||
|
||||
from mcp import StdioServerParameters
|
||||
|
||||
# Create server parameters
|
||||
@@ -180,7 +181,10 @@ class MCPClient:
|
||||
|
||||
# Create persistent stdio client context
|
||||
self._stdio_context = stdio_client(server_params)
|
||||
self._read_stream, self._write_stream = await self._stdio_context.__aenter__()
|
||||
(
|
||||
self._read_stream,
|
||||
self._write_stream,
|
||||
) = await self._stdio_context.__aenter__()
|
||||
|
||||
# Create persistent session
|
||||
self._session = ClientSession(self._read_stream, self._write_stream)
|
||||
@@ -215,7 +219,7 @@ class MCPClient:
|
||||
|
||||
logger.info(f"Connected to MCP server '{self.config.name}' via STDIO (persistent)")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to connect to MCP server: {e}")
|
||||
raise RuntimeError(f"Failed to connect to MCP server: {e}") from e
|
||||
|
||||
def _connect_http(self) -> None:
|
||||
"""Connect to MCP server via HTTP transport."""
|
||||
@@ -232,7 +236,9 @@ class MCPClient:
|
||||
try:
|
||||
response = self._http_client.get("/health")
|
||||
response.raise_for_status()
|
||||
logger.info(f"Connected to MCP server '{self.config.name}' via HTTP at {self.config.url}")
|
||||
logger.info(
|
||||
f"Connected to MCP server '{self.config.name}' via HTTP at {self.config.url}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}")
|
||||
# Continue anyway, server might not have health endpoint
|
||||
@@ -255,7 +261,10 @@ class MCPClient:
|
||||
)
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
logger.info(f"Discovered {len(self._tools)} tools from '{self.config.name}': {list(self._tools.keys())}")
|
||||
tool_names = list(self._tools.keys())
|
||||
logger.info(
|
||||
f"Discovered {len(self._tools)} tools from '{self.config.name}': {tool_names}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to discover tools from '{self.config.name}': {e}")
|
||||
raise
|
||||
@@ -271,11 +280,13 @@ class MCPClient:
|
||||
# Convert tools to dict format
|
||||
tools_list = []
|
||||
for tool in response.tools:
|
||||
tools_list.append({
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.inputSchema,
|
||||
})
|
||||
tools_list.append(
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.inputSchema,
|
||||
}
|
||||
)
|
||||
|
||||
return tools_list
|
||||
|
||||
@@ -303,7 +314,7 @@ class MCPClient:
|
||||
|
||||
return data.get("result", {}).get("tools", [])
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to list tools via HTTP: {e}")
|
||||
raise RuntimeError(f"Failed to list tools via HTTP: {e}") from e
|
||||
|
||||
def list_tools(self) -> list[MCPTool]:
|
||||
"""
|
||||
@@ -353,9 +364,9 @@ class MCPClient:
|
||||
if len(result.content) > 0:
|
||||
content_item = result.content[0]
|
||||
# Check if it's a text content item
|
||||
if hasattr(content_item, 'text'):
|
||||
if hasattr(content_item, "text"):
|
||||
return content_item.text
|
||||
elif hasattr(content_item, 'data'):
|
||||
elif hasattr(content_item, "data"):
|
||||
return content_item.data
|
||||
return result.content
|
||||
|
||||
@@ -387,7 +398,7 @@ class MCPClient:
|
||||
|
||||
return data.get("result", {}).get("content", [])
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to call tool via HTTP: {e}")
|
||||
raise RuntimeError(f"Failed to call tool via HTTP: {e}") from e
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from the MCP server."""
|
||||
|
||||
@@ -72,6 +72,7 @@ class AgentOrchestrator:
|
||||
# Auto-create LLM - LiteLLM auto-detects provider and API key from model name
|
||||
if self._llm is None:
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
|
||||
self._llm = LiteLLMProvider(model=self._model)
|
||||
|
||||
def register(
|
||||
@@ -205,7 +206,7 @@ class AgentOrchestrator:
|
||||
|
||||
responses = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for agent_name, response in zip(routing.selected_agents, responses):
|
||||
for agent_name, response in zip(routing.selected_agents, responses, strict=False):
|
||||
if isinstance(response, Exception):
|
||||
results[agent_name] = {"error": str(response)}
|
||||
else:
|
||||
@@ -326,7 +327,7 @@ class AgentOrchestrator:
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for name, result in zip(agent_names, results):
|
||||
for name, result in zip(agent_names, results, strict=False):
|
||||
if isinstance(result, Exception):
|
||||
responses[name] = AgentMessage(
|
||||
type=MessageType.RESPONSE,
|
||||
@@ -355,7 +356,7 @@ class AgentOrchestrator:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
capabilities = {}
|
||||
for name, result in zip(agent_names, results):
|
||||
for name, result in zip(agent_names, results, strict=False):
|
||||
if isinstance(result, Exception):
|
||||
capabilities[name] = CapabilityResponse(
|
||||
agent_name=name,
|
||||
@@ -429,8 +430,7 @@ class AgentOrchestrator:
|
||||
"""Use LLM to decide routing when multiple agents are capable."""
|
||||
|
||||
agents_info = "\n".join(
|
||||
f"- {name}: {cap.reasoning} (confidence: {cap.confidence:.2f})"
|
||||
for name, cap in capable
|
||||
f"- {name}: {cap.reasoning} (confidence: {cap.confidence:.2f})" for name, cap in capable
|
||||
)
|
||||
|
||||
prompt = f"""Multiple agents can handle this request. Decide the best routing.
|
||||
@@ -463,7 +463,8 @@ Respond with JSON only:
|
||||
)
|
||||
|
||||
import re
|
||||
json_match = re.search(r'\{[^{}]*\}', response.content, re.DOTALL)
|
||||
|
||||
json_match = re.search(r"\{[^{}]*\}", response.content, re.DOTALL)
|
||||
if json_match:
|
||||
data = json.loads(json_match.group())
|
||||
selected = data.get("selected", [])
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""Message protocol for multi-agent communication."""
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
|
||||
class MessageType(Enum):
|
||||
|
||||
@@ -2,24 +2,25 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable, Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from framework.graph import Goal
|
||||
from framework.graph.edge import GraphSpec, EdgeSpec, EdgeCondition, AsyncEntryPointSpec
|
||||
from framework.graph.edge import AsyncEntryPointSpec, EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.executor import ExecutionResult, GraphExecutor
|
||||
from framework.graph.node import NodeSpec
|
||||
from framework.graph.executor import GraphExecutor, ExecutionResult
|
||||
from framework.llm.provider import LLMProvider, Tool
|
||||
from framework.runner.tool_registry import ToolRegistry
|
||||
from framework.runtime.core import Runtime
|
||||
|
||||
# Multi-entry-point runtime imports
|
||||
from framework.runtime.agent_runtime import AgentRuntime, AgentRuntimeConfig, create_agent_runtime
|
||||
from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime
|
||||
from framework.runtime.core import Runtime
|
||||
from framework.runtime.execution_stream import EntryPointSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.runner.protocol import CapabilityResponse, AgentMessage
|
||||
from framework.runner.protocol import AgentMessage, CapabilityResponse
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -87,6 +88,7 @@ def load_agent_export(data: str | dict) -> tuple[GraphSpec, Goal]:
|
||||
"on_success": EdgeCondition.ON_SUCCESS,
|
||||
"on_failure": EdgeCondition.ON_FAILURE,
|
||||
"conditional": EdgeCondition.CONDITIONAL,
|
||||
"llm_decide": EdgeCondition.LLM_DECIDE,
|
||||
}
|
||||
edge = EdgeSpec(
|
||||
id=edge_data["id"],
|
||||
@@ -102,16 +104,18 @@ def load_agent_export(data: str | dict) -> tuple[GraphSpec, Goal]:
|
||||
# Build AsyncEntryPointSpec objects for multi-entry-point support
|
||||
async_entry_points = []
|
||||
for aep_data in graph_data.get("async_entry_points", []):
|
||||
async_entry_points.append(AsyncEntryPointSpec(
|
||||
id=aep_data["id"],
|
||||
name=aep_data.get("name", aep_data["id"]),
|
||||
entry_node=aep_data["entry_node"],
|
||||
trigger_type=aep_data.get("trigger_type", "manual"),
|
||||
trigger_config=aep_data.get("trigger_config", {}),
|
||||
isolation_level=aep_data.get("isolation_level", "shared"),
|
||||
priority=aep_data.get("priority", 0),
|
||||
max_concurrent=aep_data.get("max_concurrent", 10),
|
||||
))
|
||||
async_entry_points.append(
|
||||
AsyncEntryPointSpec(
|
||||
id=aep_data["id"],
|
||||
name=aep_data.get("name", aep_data["id"]),
|
||||
entry_node=aep_data["entry_node"],
|
||||
trigger_type=aep_data.get("trigger_type", "manual"),
|
||||
trigger_config=aep_data.get("trigger_config", {}),
|
||||
isolation_level=aep_data.get("isolation_level", "shared"),
|
||||
priority=aep_data.get("priority", 0),
|
||||
max_concurrent=aep_data.get("max_concurrent", 10),
|
||||
)
|
||||
)
|
||||
|
||||
# Build GraphSpec
|
||||
graph = GraphSpec(
|
||||
@@ -131,27 +135,31 @@ def load_agent_export(data: str | dict) -> tuple[GraphSpec, Goal]:
|
||||
)
|
||||
|
||||
# Build Goal
|
||||
from framework.graph.goal import SuccessCriterion, Constraint
|
||||
from framework.graph.goal import Constraint, SuccessCriterion
|
||||
|
||||
success_criteria = []
|
||||
for sc_data in goal_data.get("success_criteria", []):
|
||||
success_criteria.append(SuccessCriterion(
|
||||
id=sc_data["id"],
|
||||
description=sc_data["description"],
|
||||
metric=sc_data.get("metric", ""),
|
||||
target=sc_data.get("target", ""),
|
||||
weight=sc_data.get("weight", 1.0),
|
||||
))
|
||||
success_criteria.append(
|
||||
SuccessCriterion(
|
||||
id=sc_data["id"],
|
||||
description=sc_data["description"],
|
||||
metric=sc_data.get("metric", ""),
|
||||
target=sc_data.get("target", ""),
|
||||
weight=sc_data.get("weight", 1.0),
|
||||
)
|
||||
)
|
||||
|
||||
constraints = []
|
||||
for c_data in goal_data.get("constraints", []):
|
||||
constraints.append(Constraint(
|
||||
id=c_data["id"],
|
||||
description=c_data["description"],
|
||||
constraint_type=c_data.get("constraint_type", "hard"),
|
||||
category=c_data.get("category", "safety"),
|
||||
check=c_data.get("check", ""),
|
||||
))
|
||||
constraints.append(
|
||||
Constraint(
|
||||
id=c_data["id"],
|
||||
description=c_data["description"],
|
||||
constraint_type=c_data.get("constraint_type", "hard"),
|
||||
category=c_data.get("category", "safety"),
|
||||
check=c_data.get("check", ""),
|
||||
)
|
||||
)
|
||||
|
||||
goal = Goal(
|
||||
id=goal_data.get("id", ""),
|
||||
@@ -379,7 +387,8 @@ class AgentRunner:
|
||||
try:
|
||||
self._tool_registry.register_mcp_server(server_config)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to register MCP server '{server_config.get('name', 'unknown')}': {e}")
|
||||
server_name = server_config.get("name", "unknown")
|
||||
print(f"Warning: Failed to register MCP server '{server_name}': {e}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load MCP servers config from {config_path}: {e}")
|
||||
|
||||
@@ -409,13 +418,19 @@ class AgentRunner:
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Create LLM provider (if not mock mode and API key available)
|
||||
# Create LLM provider
|
||||
# Uses LiteLLM which auto-detects the provider from model name
|
||||
if not self.mock_mode:
|
||||
if self.mock_mode:
|
||||
# Use mock LLM for testing without real API calls
|
||||
from framework.llm.mock import MockLLMProvider
|
||||
|
||||
self._llm = MockLLMProvider(model=self.model)
|
||||
else:
|
||||
# Detect required API key from model name
|
||||
api_key_env = self._get_api_key_env_var(self.model)
|
||||
if api_key_env and os.environ.get(api_key_env):
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
|
||||
self._llm = LiteLLMProvider(model=self.model)
|
||||
elif api_key_env:
|
||||
print(f"Warning: {api_key_env} not set. LLM calls will fail.")
|
||||
@@ -760,7 +775,12 @@ class AgentRunner:
|
||||
entry_node=self.graph.entry_node,
|
||||
terminal_nodes=self.graph.terminal_nodes,
|
||||
success_criteria=[
|
||||
{"id": sc.id, "description": sc.description, "metric": sc.metric, "target": sc.target}
|
||||
{
|
||||
"id": sc.id,
|
||||
"description": sc.description,
|
||||
"metric": sc.metric,
|
||||
"target": sc.target,
|
||||
}
|
||||
for sc in self.goal.success_criteria
|
||||
],
|
||||
constraints=[
|
||||
@@ -810,7 +830,7 @@ class AgentRunner:
|
||||
|
||||
# Check tool credentials (Tier 2)
|
||||
missing_creds = cred_manager.get_missing_for_tools(info.required_tools)
|
||||
for cred_name, spec in missing_creds:
|
||||
for _, spec in missing_creds:
|
||||
missing_credentials.append(spec.env_var)
|
||||
affected_tools = [t for t in info.required_tools if t in spec.tools]
|
||||
tools_str = ", ".join(affected_tools)
|
||||
@@ -820,9 +840,9 @@ class AgentRunner:
|
||||
warnings.append(warning_msg)
|
||||
|
||||
# Check node type credentials (e.g., ANTHROPIC_API_KEY for LLM nodes)
|
||||
node_types = list(set(node.node_type for node in self.graph.nodes))
|
||||
node_types = list({node.node_type for node in self.graph.nodes})
|
||||
missing_node_creds = cred_manager.get_missing_for_node_types(node_types)
|
||||
for cred_name, spec in missing_node_creds:
|
||||
for _, spec in missing_node_creds:
|
||||
if spec.env_var not in missing_credentials: # Avoid duplicates
|
||||
missing_credentials.append(spec.env_var)
|
||||
affected_types = [t for t in node_types if t in spec.node_types]
|
||||
@@ -834,11 +854,16 @@ class AgentRunner:
|
||||
except ImportError:
|
||||
# aden_tools not installed - fall back to direct check
|
||||
has_llm_nodes = any(
|
||||
node.node_type in ("llm_generate", "llm_tool_use")
|
||||
for node in self.graph.nodes
|
||||
node.node_type in ("llm_generate", "llm_tool_use") for node in self.graph.nodes
|
||||
)
|
||||
if has_llm_nodes and not os.environ.get("ANTHROPIC_API_KEY"):
|
||||
warnings.append("Agent has LLM nodes but ANTHROPIC_API_KEY not set")
|
||||
if has_llm_nodes:
|
||||
api_key_env = self._get_api_key_env_var(self.model)
|
||||
if api_key_env and not os.environ.get(api_key_env):
|
||||
if api_key_env not in missing_credentials:
|
||||
missing_credentials.append(api_key_env)
|
||||
warnings.append(
|
||||
f"Agent has LLM nodes but {api_key_env} not set (model: {self.model})"
|
||||
)
|
||||
|
||||
return ValidationResult(
|
||||
valid=len(errors) == 0,
|
||||
@@ -848,7 +873,9 @@ class AgentRunner:
|
||||
missing_credentials=missing_credentials,
|
||||
)
|
||||
|
||||
async def can_handle(self, request: dict, llm: LLMProvider | None = None) -> "CapabilityResponse":
|
||||
async def can_handle(
|
||||
self, request: dict, llm: LLMProvider | None = None
|
||||
) -> "CapabilityResponse":
|
||||
"""
|
||||
Ask the agent if it can handle this request.
|
||||
|
||||
@@ -861,7 +888,7 @@ class AgentRunner:
|
||||
Returns:
|
||||
CapabilityResponse with level, confidence, and reasoning
|
||||
"""
|
||||
from framework.runner.protocol import CapabilityResponse, CapabilityLevel
|
||||
from framework.runner.protocol import CapabilityLevel, CapabilityResponse
|
||||
|
||||
# Use provided LLM or set up our own
|
||||
eval_llm = llm
|
||||
@@ -918,7 +945,8 @@ Respond with JSON only:
|
||||
|
||||
# Parse response
|
||||
import re
|
||||
json_match = re.search(r'\{[^{}]*\}', response.content, re.DOTALL)
|
||||
|
||||
json_match = re.search(r"\{[^{}]*\}", response.content, re.DOTALL)
|
||||
if json_match:
|
||||
data = json.loads(json_match.group())
|
||||
level_map = {
|
||||
@@ -942,7 +970,7 @@ Respond with JSON only:
|
||||
|
||||
def _keyword_capability_check(self, request: dict) -> "CapabilityResponse":
|
||||
"""Simple keyword-based capability check (fallback when no LLM)."""
|
||||
from framework.runner.protocol import CapabilityResponse, CapabilityLevel
|
||||
from framework.runner.protocol import CapabilityLevel, CapabilityResponse
|
||||
|
||||
info = self.info()
|
||||
request_str = json.dumps(request).lower()
|
||||
|
||||
@@ -4,11 +4,12 @@ import importlib.util
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
from framework.llm.provider import Tool, ToolUse, ToolResult
|
||||
from framework.llm.provider import Tool, ToolResult, ToolUse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -142,7 +143,7 @@ class ToolRegistry:
|
||||
|
||||
# Check for TOOLS dict
|
||||
if hasattr(module, "TOOLS"):
|
||||
tools_dict = getattr(module, "TOOLS")
|
||||
tools_dict = module.TOOLS
|
||||
executor_func = getattr(module, "tool_executor", None)
|
||||
|
||||
for name, tool in tools_dict.items():
|
||||
|
||||
@@ -7,15 +7,16 @@ while preserving the goal-driven approach.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from framework.graph.executor import ExecutionResult
|
||||
from framework.runtime.shared_state import SharedStateManager
|
||||
from framework.runtime.outcome_aggregator import OutcomeAggregator
|
||||
from framework.runtime.event_bus import EventBus
|
||||
from framework.runtime.execution_stream import ExecutionStream, EntryPointSpec
|
||||
from framework.runtime.execution_stream import EntryPointSpec, ExecutionStream
|
||||
from framework.runtime.outcome_aggregator import OutcomeAggregator
|
||||
from framework.runtime.shared_state import SharedStateManager
|
||||
from framework.storage.concurrent import ConcurrentStorage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -29,10 +30,13 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class AgentRuntimeConfig:
|
||||
"""Configuration for AgentRuntime."""
|
||||
|
||||
max_concurrent_executions: int = 100
|
||||
cache_ttl: float = 60.0
|
||||
batch_interval: float = 0.1
|
||||
max_history: int = 1000
|
||||
execution_result_max: int = 1000
|
||||
execution_result_ttl_seconds: float | None = None
|
||||
|
||||
|
||||
class AgentRuntime:
|
||||
@@ -206,6 +210,8 @@ class AgentRuntime:
|
||||
llm=self._llm,
|
||||
tools=self._tools,
|
||||
tool_executor=self._tool_executor,
|
||||
result_retention_max=self._config.execution_result_max,
|
||||
result_retention_ttl_seconds=self._config.execution_result_ttl_seconds,
|
||||
)
|
||||
await stream.start()
|
||||
self._streams[ep_id] = stream
|
||||
@@ -285,7 +291,9 @@ class AgentRuntime:
|
||||
ExecutionResult or None if timeout
|
||||
"""
|
||||
exec_id = await self.trigger(entry_point_id, input_data, session_state=session_state)
|
||||
stream = self._streams[entry_point_id]
|
||||
stream = self._streams.get(entry_point_id)
|
||||
if stream is None:
|
||||
raise ValueError(f"Entry point '{entry_point_id}' not found")
|
||||
return await stream.wait_for_completion(exec_id, timeout)
|
||||
|
||||
async def get_goal_progress(self) -> dict[str, Any]:
|
||||
@@ -411,6 +419,7 @@ class AgentRuntime:
|
||||
|
||||
# === CONVENIENCE FACTORY ===
|
||||
|
||||
|
||||
def create_agent_runtime(
|
||||
graph: "GraphSpec",
|
||||
goal: "Goal",
|
||||
|
||||
@@ -6,13 +6,14 @@ that Builder can analyze. The agent calls simple methods, and the runtime
|
||||
handles all the structured logging.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.schemas.decision import Decision, Option, Outcome, DecisionType
|
||||
from framework.schemas.decision import Decision, DecisionType, Option, Outcome
|
||||
from framework.schemas.run import Run, RunStatus
|
||||
from framework.storage.backend import FileStorage
|
||||
|
||||
@@ -164,7 +165,7 @@ class Runtime:
|
||||
context: Additional context available when deciding
|
||||
|
||||
Returns:
|
||||
The decision ID (use this to record outcome later), or empty string if no run in progress
|
||||
The decision ID (use to record outcome later), or empty string if no run
|
||||
"""
|
||||
if self._current_run is None:
|
||||
# Gracefully handle case where run ended during exception handling
|
||||
@@ -174,15 +175,17 @@ class Runtime:
|
||||
# Build Option objects
|
||||
option_objects = []
|
||||
for opt in options:
|
||||
option_objects.append(Option(
|
||||
id=opt["id"],
|
||||
description=opt.get("description", ""),
|
||||
action_type=opt.get("action_type", "unknown"),
|
||||
action_params=opt.get("action_params", {}),
|
||||
pros=opt.get("pros", []),
|
||||
cons=opt.get("cons", []),
|
||||
confidence=opt.get("confidence", 0.5),
|
||||
))
|
||||
option_objects.append(
|
||||
Option(
|
||||
id=opt["id"],
|
||||
description=opt.get("description", ""),
|
||||
action_type=opt.get("action_type", "unknown"),
|
||||
action_params=opt.get("action_params", {}),
|
||||
pros=opt.get("pros", []),
|
||||
cons=opt.get("cons", []),
|
||||
confidence=opt.get("confidence", 0.5),
|
||||
)
|
||||
)
|
||||
|
||||
# Create decision
|
||||
decision_id = f"dec_{len(self._current_run.decisions)}"
|
||||
@@ -230,7 +233,9 @@ class Runtime:
|
||||
if self._current_run is None:
|
||||
# Gracefully handle case where run ended during exception handling
|
||||
# This can happen in cascading error scenarios
|
||||
logger.warning(f"record_outcome called but no run in progress (decision_id={decision_id})")
|
||||
logger.warning(
|
||||
f"record_outcome called but no run in progress (decision_id={decision_id})"
|
||||
)
|
||||
return
|
||||
|
||||
outcome = Outcome(
|
||||
@@ -274,7 +279,9 @@ class Runtime:
|
||||
if self._current_run is None:
|
||||
# Gracefully handle case where run ended during exception handling
|
||||
# Log the problem since we can't store it, then return empty ID
|
||||
logger.warning(f"report_problem called but no run in progress: [{severity}] {description}")
|
||||
logger.warning(
|
||||
f"report_problem called but no run in progress: [{severity}] {description}"
|
||||
)
|
||||
return ""
|
||||
|
||||
return self._current_run.add_problem(
|
||||
@@ -293,7 +300,7 @@ class Runtime:
|
||||
options: list[dict[str, Any]],
|
||||
chosen: str,
|
||||
reasoning: str,
|
||||
executor: callable,
|
||||
executor: Callable,
|
||||
**kwargs,
|
||||
) -> tuple[str, Any]:
|
||||
"""
|
||||
@@ -370,11 +377,13 @@ class Runtime:
|
||||
"""
|
||||
return self.decide(
|
||||
intent=intent,
|
||||
options=[{
|
||||
"id": "action",
|
||||
"description": action,
|
||||
"action_type": "execute",
|
||||
}],
|
||||
options=[
|
||||
{
|
||||
"id": "action",
|
||||
"description": action,
|
||||
"action_type": "execute",
|
||||
}
|
||||
],
|
||||
chosen="action",
|
||||
reasoning=reasoning,
|
||||
node_id=node_id,
|
||||
|
||||
@@ -9,11 +9,11 @@ Allows streams to:
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -48,6 +48,7 @@ class EventType(str, Enum):
|
||||
@dataclass
|
||||
class AgentEvent:
|
||||
"""An event in the agent system."""
|
||||
|
||||
type: EventType
|
||||
stream_id: str
|
||||
execution_id: str | None = None
|
||||
@@ -74,6 +75,7 @@ EventHandler = Callable[[AgentEvent], Awaitable[None]]
|
||||
@dataclass
|
||||
class Subscription:
|
||||
"""A subscription to events."""
|
||||
|
||||
id: str
|
||||
event_types: set[EventType]
|
||||
handler: EventHandler
|
||||
@@ -193,7 +195,7 @@ class EventBus:
|
||||
async with self._lock:
|
||||
self._event_history.append(event)
|
||||
if len(self._event_history) > self._max_history:
|
||||
self._event_history = self._event_history[-self._max_history:]
|
||||
self._event_history = self._event_history[-self._max_history :]
|
||||
|
||||
# Find matching subscriptions
|
||||
matching_handlers: list[EventHandler] = []
|
||||
@@ -249,13 +251,15 @@ class EventBus:
|
||||
correlation_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit execution started event."""
|
||||
await self.publish(AgentEvent(
|
||||
type=EventType.EXECUTION_STARTED,
|
||||
stream_id=stream_id,
|
||||
execution_id=execution_id,
|
||||
data={"input": input_data or {}},
|
||||
correlation_id=correlation_id,
|
||||
))
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.EXECUTION_STARTED,
|
||||
stream_id=stream_id,
|
||||
execution_id=execution_id,
|
||||
data={"input": input_data or {}},
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_execution_completed(
|
||||
self,
|
||||
@@ -265,13 +269,15 @@ class EventBus:
|
||||
correlation_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit execution completed event."""
|
||||
await self.publish(AgentEvent(
|
||||
type=EventType.EXECUTION_COMPLETED,
|
||||
stream_id=stream_id,
|
||||
execution_id=execution_id,
|
||||
data={"output": output or {}},
|
||||
correlation_id=correlation_id,
|
||||
))
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.EXECUTION_COMPLETED,
|
||||
stream_id=stream_id,
|
||||
execution_id=execution_id,
|
||||
data={"output": output or {}},
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_execution_failed(
|
||||
self,
|
||||
@@ -281,13 +287,15 @@ class EventBus:
|
||||
correlation_id: str | None = None,
|
||||
) -> None:
|
||||
"""Emit execution failed event."""
|
||||
await self.publish(AgentEvent(
|
||||
type=EventType.EXECUTION_FAILED,
|
||||
stream_id=stream_id,
|
||||
execution_id=execution_id,
|
||||
data={"error": error},
|
||||
correlation_id=correlation_id,
|
||||
))
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.EXECUTION_FAILED,
|
||||
stream_id=stream_id,
|
||||
execution_id=execution_id,
|
||||
data={"error": error},
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_goal_progress(
|
||||
self,
|
||||
@@ -296,14 +304,16 @@ class EventBus:
|
||||
criteria_status: dict[str, Any],
|
||||
) -> None:
|
||||
"""Emit goal progress event."""
|
||||
await self.publish(AgentEvent(
|
||||
type=EventType.GOAL_PROGRESS,
|
||||
stream_id=stream_id,
|
||||
data={
|
||||
"progress": progress,
|
||||
"criteria_status": criteria_status,
|
||||
},
|
||||
))
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.GOAL_PROGRESS,
|
||||
stream_id=stream_id,
|
||||
data={
|
||||
"progress": progress,
|
||||
"criteria_status": criteria_status,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_constraint_violation(
|
||||
self,
|
||||
@@ -313,15 +323,17 @@ class EventBus:
|
||||
description: str,
|
||||
) -> None:
|
||||
"""Emit constraint violation event."""
|
||||
await self.publish(AgentEvent(
|
||||
type=EventType.CONSTRAINT_VIOLATION,
|
||||
stream_id=stream_id,
|
||||
execution_id=execution_id,
|
||||
data={
|
||||
"constraint_id": constraint_id,
|
||||
"description": description,
|
||||
},
|
||||
))
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.CONSTRAINT_VIOLATION,
|
||||
stream_id=stream_id,
|
||||
execution_id=execution_id,
|
||||
data={
|
||||
"constraint_id": constraint_id,
|
||||
"description": description,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def emit_state_changed(
|
||||
self,
|
||||
@@ -333,17 +345,19 @@ class EventBus:
|
||||
scope: str,
|
||||
) -> None:
|
||||
"""Emit state changed event."""
|
||||
await self.publish(AgentEvent(
|
||||
type=EventType.STATE_CHANGED,
|
||||
stream_id=stream_id,
|
||||
execution_id=execution_id,
|
||||
data={
|
||||
"key": key,
|
||||
"old_value": old_value,
|
||||
"new_value": new_value,
|
||||
"scope": scope,
|
||||
},
|
||||
))
|
||||
await self.publish(
|
||||
AgentEvent(
|
||||
type=EventType.STATE_CHANGED,
|
||||
stream_id=stream_id,
|
||||
execution_id=execution_id,
|
||||
data={
|
||||
"key": key,
|
||||
"old_value": old_value,
|
||||
"new_value": new_value,
|
||||
"scope": scope,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# === QUERY OPERATIONS ===
|
||||
|
||||
@@ -432,7 +446,7 @@ class EventBus:
|
||||
if timeout:
|
||||
try:
|
||||
await asyncio.wait_for(event_received.wait(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
return None
|
||||
else:
|
||||
await event_received.wait()
|
||||
|
||||
@@ -9,22 +9,25 @@ Each stream has:
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from framework.graph.executor import GraphExecutor, ExecutionResult
|
||||
from framework.graph.executor import ExecutionResult, GraphExecutor
|
||||
from framework.runtime.shared_state import IsolationLevel, SharedStateManager
|
||||
from framework.runtime.stream_runtime import StreamRuntime, StreamRuntimeAdapter
|
||||
from framework.runtime.shared_state import SharedStateManager, IsolationLevel, StreamMemory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.goal import Goal
|
||||
from framework.storage.concurrent import ConcurrentStorage
|
||||
from framework.runtime.outcome_aggregator import OutcomeAggregator
|
||||
from framework.runtime.event_bus import EventBus
|
||||
from framework.llm.provider import LLMProvider, Tool
|
||||
from framework.runtime.event_bus import EventBus
|
||||
from framework.runtime.outcome_aggregator import OutcomeAggregator
|
||||
from framework.storage.concurrent import ConcurrentStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -32,6 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class EntryPointSpec:
|
||||
"""Specification for an entry point."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
entry_node: str # Node ID to start from
|
||||
@@ -49,6 +53,7 @@ class EntryPointSpec:
|
||||
@dataclass
|
||||
class ExecutionContext:
|
||||
"""Context for a single execution."""
|
||||
|
||||
id: str
|
||||
correlation_id: str
|
||||
stream_id: str
|
||||
@@ -105,6 +110,8 @@ class ExecutionStream:
|
||||
llm: "LLMProvider | None" = None,
|
||||
tools: list["Tool"] | None = None,
|
||||
tool_executor: Callable | None = None,
|
||||
result_retention_max: int | None = 1000,
|
||||
result_retention_ttl_seconds: float | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize execution stream.
|
||||
@@ -133,6 +140,8 @@ class ExecutionStream:
|
||||
self._llm = llm
|
||||
self._tools = tools or []
|
||||
self._tool_executor = tool_executor
|
||||
self._result_retention_max = result_retention_max
|
||||
self._result_retention_ttl_seconds = result_retention_ttl_seconds
|
||||
|
||||
# Create stream-scoped runtime
|
||||
self._runtime = StreamRuntime(
|
||||
@@ -144,7 +153,8 @@ class ExecutionStream:
|
||||
# Execution tracking
|
||||
self._active_executions: dict[str, ExecutionContext] = {}
|
||||
self._execution_tasks: dict[str, asyncio.Task] = {}
|
||||
self._execution_results: dict[str, ExecutionResult] = {}
|
||||
self._execution_results: OrderedDict[str, ExecutionResult] = OrderedDict()
|
||||
self._execution_result_times: dict[str, float] = {}
|
||||
self._completion_events: dict[str, asyncio.Event] = {}
|
||||
|
||||
# Concurrency control
|
||||
@@ -164,12 +174,36 @@ class ExecutionStream:
|
||||
|
||||
# Emit stream started event
|
||||
if self._event_bus:
|
||||
from framework.runtime.event_bus import EventType, AgentEvent
|
||||
await self._event_bus.publish(AgentEvent(
|
||||
type=EventType.STREAM_STARTED,
|
||||
stream_id=self.stream_id,
|
||||
data={"entry_point": self.entry_spec.id},
|
||||
))
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
|
||||
await self._event_bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.STREAM_STARTED,
|
||||
stream_id=self.stream_id,
|
||||
data={"entry_point": self.entry_spec.id},
|
||||
)
|
||||
)
|
||||
|
||||
def _record_execution_result(self, execution_id: str, result: ExecutionResult) -> None:
|
||||
"""Record a completed execution result with retention pruning."""
|
||||
self._execution_results[execution_id] = result
|
||||
self._execution_results.move_to_end(execution_id)
|
||||
self._execution_result_times[execution_id] = time.time()
|
||||
self._prune_execution_results()
|
||||
|
||||
def _prune_execution_results(self) -> None:
|
||||
"""Prune completed results based on TTL and max retention."""
|
||||
if self._result_retention_ttl_seconds is not None:
|
||||
cutoff = time.time() - self._result_retention_ttl_seconds
|
||||
for exec_id, recorded_at in list(self._execution_result_times.items()):
|
||||
if recorded_at < cutoff:
|
||||
self._execution_result_times.pop(exec_id, None)
|
||||
self._execution_results.pop(exec_id, None)
|
||||
|
||||
if self._result_retention_max is not None:
|
||||
while len(self._execution_results) > self._result_retention_max:
|
||||
old_exec_id, _ = self._execution_results.popitem(last=False)
|
||||
self._execution_result_times.pop(old_exec_id, None)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the execution stream and cancel active executions."""
|
||||
@@ -179,7 +213,7 @@ class ExecutionStream:
|
||||
self._running = False
|
||||
|
||||
# Cancel all active executions
|
||||
for exec_id, task in self._execution_tasks.items():
|
||||
for _, task in self._execution_tasks.items():
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
@@ -194,11 +228,14 @@ class ExecutionStream:
|
||||
|
||||
# Emit stream stopped event
|
||||
if self._event_bus:
|
||||
from framework.runtime.event_bus import EventType, AgentEvent
|
||||
await self._event_bus.publish(AgentEvent(
|
||||
type=EventType.STREAM_STOPPED,
|
||||
stream_id=self.stream_id,
|
||||
))
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
|
||||
await self._event_bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.STREAM_STOPPED,
|
||||
stream_id=self.stream_id,
|
||||
)
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
@@ -268,7 +305,7 @@ class ExecutionStream:
|
||||
)
|
||||
|
||||
# Create execution-scoped memory
|
||||
memory = self._state_manager.create_memory(
|
||||
self._state_manager.create_memory(
|
||||
execution_id=execution_id,
|
||||
stream_id=self.stream_id,
|
||||
isolation=ctx.isolation_level,
|
||||
@@ -297,8 +334,8 @@ class ExecutionStream:
|
||||
session_state=ctx.session_state,
|
||||
)
|
||||
|
||||
# Store result
|
||||
self._execution_results[execution_id] = result
|
||||
# Store result with retention
|
||||
self._record_execution_result(execution_id, result)
|
||||
|
||||
# Update context
|
||||
ctx.completed_at = datetime.now()
|
||||
@@ -333,10 +370,13 @@ class ExecutionStream:
|
||||
ctx.status = "failed"
|
||||
logger.error(f"Execution {execution_id} failed: {e}")
|
||||
|
||||
# Store error result
|
||||
self._execution_results[execution_id] = ExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
# Store error result with retention
|
||||
self._record_execution_result(
|
||||
execution_id,
|
||||
ExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
),
|
||||
)
|
||||
|
||||
# Emit failure event
|
||||
@@ -356,6 +396,12 @@ class ExecutionStream:
|
||||
if execution_id in self._completion_events:
|
||||
self._completion_events[execution_id].set()
|
||||
|
||||
# Remove in-flight bookkeeping
|
||||
async with self._lock:
|
||||
self._active_executions.pop(execution_id, None)
|
||||
self._completion_events.pop(execution_id, None)
|
||||
self._execution_tasks.pop(execution_id, None)
|
||||
|
||||
def _create_modified_graph(self) -> "GraphSpec":
|
||||
"""Create a graph with the entry point overridden."""
|
||||
# Use the existing graph but override entry_node
|
||||
@@ -378,6 +424,7 @@ class ExecutionStream:
|
||||
default_model=self.graph.default_model,
|
||||
max_tokens=self.graph.max_tokens,
|
||||
max_steps=self.graph.max_steps,
|
||||
cleanup_llm_model=self.graph.cleanup_llm_model,
|
||||
)
|
||||
|
||||
async def wait_for_completion(
|
||||
@@ -398,6 +445,7 @@ class ExecutionStream:
|
||||
event = self._completion_events.get(execution_id)
|
||||
if event is None:
|
||||
# Execution not found or already cleaned up
|
||||
self._prune_execution_results()
|
||||
return self._execution_results.get(execution_id)
|
||||
|
||||
try:
|
||||
@@ -406,13 +454,15 @@ class ExecutionStream:
|
||||
else:
|
||||
await event.wait()
|
||||
|
||||
self._prune_execution_results()
|
||||
return self._execution_results.get(execution_id)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
return None
|
||||
|
||||
def get_result(self, execution_id: str) -> ExecutionResult | None:
|
||||
"""Get result of a completed execution."""
|
||||
self._prune_execution_results()
|
||||
return self._execution_results.get(execution_id)
|
||||
|
||||
def get_context(self, execution_id: str) -> ExecutionContext | None:
|
||||
@@ -443,10 +493,7 @@ class ExecutionStream:
|
||||
|
||||
def get_active_count(self) -> int:
|
||||
"""Get count of active executions."""
|
||||
return len([
|
||||
ctx for ctx in self._active_executions.values()
|
||||
if ctx.status == "running"
|
||||
])
|
||||
return len([ctx for ctx in self._active_executions.values() if ctx.status == "running"])
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get stream statistics."""
|
||||
@@ -454,6 +501,10 @@ class ExecutionStream:
|
||||
for ctx in self._active_executions.values():
|
||||
statuses[ctx.status] = statuses.get(ctx.status, 0) + 1
|
||||
|
||||
# Calculate available slots from running count instead of accessing private _value
|
||||
running_count = statuses.get("running", 0)
|
||||
available_slots = self.entry_spec.max_concurrent - running_count
|
||||
|
||||
return {
|
||||
"stream_id": self.stream_id,
|
||||
"entry_point": self.entry_spec.id,
|
||||
@@ -462,5 +513,5 @@ class ExecutionStream:
|
||||
"completed_executions": len(self._execution_results),
|
||||
"status_counts": statuses,
|
||||
"max_concurrent": self.entry_spec.max_concurrent,
|
||||
"available_slots": self._semaphore._value,
|
||||
"available_slots": available_slots,
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from framework.schemas.decision import Decision, Outcome
|
||||
|
||||
@@ -23,6 +23,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class CriterionStatus:
|
||||
"""Status of a success criterion."""
|
||||
|
||||
criterion_id: str
|
||||
description: str
|
||||
met: bool
|
||||
@@ -34,6 +35,7 @@ class CriterionStatus:
|
||||
@dataclass
|
||||
class ConstraintCheck:
|
||||
"""Result of a constraint check."""
|
||||
|
||||
constraint_id: str
|
||||
description: str
|
||||
violated: bool
|
||||
@@ -46,6 +48,7 @@ class ConstraintCheck:
|
||||
@dataclass
|
||||
class DecisionRecord:
|
||||
"""Record of a decision for aggregation."""
|
||||
|
||||
stream_id: str
|
||||
execution_id: str
|
||||
decision: Decision
|
||||
@@ -284,10 +287,11 @@ class OutcomeAggregator:
|
||||
"successful_outcomes": self._successful_outcomes,
|
||||
"failed_outcomes": self._failed_outcomes,
|
||||
"success_rate": (
|
||||
self._successful_outcomes / max(1, self._successful_outcomes + self._failed_outcomes)
|
||||
self._successful_outcomes
|
||||
/ max(1, self._successful_outcomes + self._failed_outcomes)
|
||||
),
|
||||
"streams_active": len(set(d.stream_id for d in self._decisions)),
|
||||
"executions_total": len(set((d.stream_id, d.execution_id) for d in self._decisions)),
|
||||
"streams_active": len({d.stream_id for d in self._decisions}),
|
||||
"executions_total": len({(d.stream_id, d.execution_id) for d in self._decisions}),
|
||||
}
|
||||
|
||||
# Determine recommendation
|
||||
@@ -296,7 +300,7 @@ class OutcomeAggregator:
|
||||
# Publish progress event
|
||||
if self._event_bus:
|
||||
# Get any stream ID for the event
|
||||
stream_ids = set(d.stream_id for d in self._decisions)
|
||||
stream_ids = {d.stream_id for d in self._decisions}
|
||||
if stream_ids:
|
||||
await self._event_bus.emit_goal_progress(
|
||||
stream_id=list(stream_ids)[0],
|
||||
@@ -323,7 +327,8 @@ class OutcomeAggregator:
|
||||
|
||||
# Get relevant decisions (those mentioning this criterion or related intents)
|
||||
relevant_decisions = [
|
||||
d for d in self._decisions
|
||||
d
|
||||
for d in self._decisions
|
||||
if criterion.id in str(d.decision.active_constraints)
|
||||
or self._is_related_to_criterion(d.decision, criterion)
|
||||
]
|
||||
@@ -341,7 +346,9 @@ class OutcomeAggregator:
|
||||
# Add evidence
|
||||
for d in relevant_decisions[:5]: # Limit evidence
|
||||
if d.outcome:
|
||||
evidence = f"{d.decision.intent}: {'success' if d.outcome.success else 'failed'}"
|
||||
evidence = (
|
||||
f"{d.decision.intent}: {'success' if d.outcome.success else 'failed'}"
|
||||
)
|
||||
status.evidence.append(evidence)
|
||||
|
||||
# Check if criterion is met based on target
|
||||
@@ -373,10 +380,7 @@ class OutcomeAggregator:
|
||||
violations = result["constraint_violations"]
|
||||
|
||||
# Check for hard constraint violations
|
||||
hard_violations = [
|
||||
v for v in violations
|
||||
if self._is_hard_constraint(v["constraint_id"])
|
||||
]
|
||||
hard_violations = [v for v in violations if self._is_hard_constraint(v["constraint_id"])]
|
||||
|
||||
if hard_violations:
|
||||
return "adjust" # Must address violations
|
||||
@@ -409,7 +413,8 @@ class OutcomeAggregator:
|
||||
) -> list[DecisionRecord]:
|
||||
"""Get all decisions from a specific execution."""
|
||||
return [
|
||||
d for d in self._decisions
|
||||
d
|
||||
for d in self._decisions
|
||||
if d.stream_id == stream_id and d.execution_id == execution_id
|
||||
]
|
||||
|
||||
@@ -429,7 +434,7 @@ class OutcomeAggregator:
|
||||
"failed_outcomes": self._failed_outcomes,
|
||||
"constraint_violations": len(self._constraint_violations),
|
||||
"criteria_tracked": len(self._criterion_status),
|
||||
"streams_seen": len(set(d.stream_id for d in self._decisions)),
|
||||
"streams_seen": len({d.stream_id for d in self._decisions}),
|
||||
}
|
||||
|
||||
# === RESET OPERATIONS ===
|
||||
|
||||
@@ -19,21 +19,24 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class IsolationLevel(str, Enum):
|
||||
"""State isolation level for concurrent executions."""
|
||||
ISOLATED = "isolated" # Private state per execution
|
||||
SHARED = "shared" # Shared state (eventual consistency)
|
||||
SYNCHRONIZED = "synchronized" # Shared with write locks (strong consistency)
|
||||
|
||||
ISOLATED = "isolated" # Private state per execution
|
||||
SHARED = "shared" # Shared state (eventual consistency)
|
||||
SYNCHRONIZED = "synchronized" # Shared with write locks (strong consistency)
|
||||
|
||||
|
||||
class StateScope(str, Enum):
|
||||
"""Scope for state operations."""
|
||||
EXECUTION = "execution" # Local to a single execution
|
||||
STREAM = "stream" # Shared within a stream
|
||||
GLOBAL = "global" # Shared across all streams
|
||||
|
||||
EXECUTION = "execution" # Local to a single execution
|
||||
STREAM = "stream" # Shared within a stream
|
||||
GLOBAL = "global" # Shared across all streams
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateChange:
|
||||
"""Record of a state change."""
|
||||
|
||||
key: str
|
||||
old_value: Any
|
||||
new_value: Any
|
||||
@@ -212,14 +215,16 @@ class SharedStateManager:
|
||||
await self._write_direct(key, value, execution_id, stream_id, scope)
|
||||
|
||||
# Record change
|
||||
self._record_change(StateChange(
|
||||
key=key,
|
||||
old_value=old_value,
|
||||
new_value=value,
|
||||
scope=scope,
|
||||
execution_id=execution_id,
|
||||
stream_id=stream_id,
|
||||
))
|
||||
self._record_change(
|
||||
StateChange(
|
||||
key=key,
|
||||
old_value=old_value,
|
||||
new_value=value,
|
||||
scope=scope,
|
||||
execution_id=execution_id,
|
||||
stream_id=stream_id,
|
||||
)
|
||||
)
|
||||
|
||||
async def _write_direct(
|
||||
self,
|
||||
@@ -278,7 +283,7 @@ class SharedStateManager:
|
||||
|
||||
# Trim history if too long
|
||||
if len(self._change_history) > self._max_history:
|
||||
self._change_history = self._change_history[-self._max_history:]
|
||||
self._change_history = self._change_history[-self._max_history :]
|
||||
|
||||
# === BULK OPERATIONS ===
|
||||
|
||||
|
||||
@@ -10,9 +10,9 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from framework.schemas.decision import Decision, Option, Outcome, DecisionType
|
||||
from framework.schemas.decision import Decision, DecisionType, Option, Outcome
|
||||
from framework.schemas.run import Run, RunStatus
|
||||
from framework.storage.concurrent import ConcurrentStorage
|
||||
|
||||
@@ -117,7 +117,8 @@ class StreamRuntime:
|
||||
Returns:
|
||||
The run ID
|
||||
"""
|
||||
run_id = f"run_{self.stream_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
run_id = f"run_{self.stream_id}_{timestamp}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
run = Run(
|
||||
id=run_id,
|
||||
@@ -130,7 +131,9 @@ class StreamRuntime:
|
||||
self._run_locks[execution_id] = asyncio.Lock()
|
||||
self._current_nodes[execution_id] = "unknown"
|
||||
|
||||
logger.debug(f"Started run {run_id} for execution {execution_id} in stream {self.stream_id}")
|
||||
logger.debug(
|
||||
f"Started run {run_id} for execution {execution_id} in stream {self.stream_id}"
|
||||
)
|
||||
return run_id
|
||||
|
||||
def end_run(
|
||||
@@ -224,15 +227,17 @@ class StreamRuntime:
|
||||
# Build Option objects
|
||||
option_objects = []
|
||||
for opt in options:
|
||||
option_objects.append(Option(
|
||||
id=opt["id"],
|
||||
description=opt.get("description", ""),
|
||||
action_type=opt.get("action_type", "unknown"),
|
||||
action_params=opt.get("action_params", {}),
|
||||
pros=opt.get("pros", []),
|
||||
cons=opt.get("cons", []),
|
||||
confidence=opt.get("confidence", 0.5),
|
||||
))
|
||||
option_objects.append(
|
||||
Option(
|
||||
id=opt["id"],
|
||||
description=opt.get("description", ""),
|
||||
action_type=opt.get("action_type", "unknown"),
|
||||
action_params=opt.get("action_params", {}),
|
||||
pros=opt.get("pros", []),
|
||||
cons=opt.get("cons", []),
|
||||
confidence=opt.get("confidence", 0.5),
|
||||
)
|
||||
)
|
||||
|
||||
# Create decision
|
||||
decision_id = f"dec_{len(run.decisions)}"
|
||||
@@ -341,7 +346,10 @@ class StreamRuntime:
|
||||
"""
|
||||
run = self._runs.get(execution_id)
|
||||
if run is None:
|
||||
logger.warning(f"report_problem called but no run for execution {execution_id}: [{severity}] {description}")
|
||||
logger.warning(
|
||||
f"report_problem called but no run for execution {execution_id}: "
|
||||
f"[{severity}] {description}"
|
||||
)
|
||||
return ""
|
||||
|
||||
return run.add_problem(
|
||||
@@ -377,11 +385,13 @@ class StreamRuntime:
|
||||
return self.decide(
|
||||
execution_id=execution_id,
|
||||
intent=intent,
|
||||
options=[{
|
||||
"id": "action",
|
||||
"description": action,
|
||||
"action_type": "execute",
|
||||
}],
|
||||
options=[
|
||||
{
|
||||
"id": "action",
|
||||
"description": action,
|
||||
"action_type": "execute",
|
||||
}
|
||||
],
|
||||
chosen="action",
|
||||
reasoning=reasoning,
|
||||
node_id=node_id,
|
||||
|
||||
@@ -11,24 +11,24 @@ Tests:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from framework.graph import Goal
|
||||
from framework.graph.goal import SuccessCriterion, Constraint
|
||||
from framework.graph.edge import GraphSpec, EdgeSpec, EdgeCondition, AsyncEntryPointSpec
|
||||
from framework.graph.node import NodeSpec
|
||||
from framework.runtime.agent_runtime import AgentRuntime, AgentRuntimeConfig, create_agent_runtime
|
||||
from framework.runtime.execution_stream import EntryPointSpec
|
||||
from framework.runtime.shared_state import SharedStateManager, IsolationLevel
|
||||
from framework.runtime.event_bus import EventBus, EventType, AgentEvent
|
||||
from framework.runtime.outcome_aggregator import OutcomeAggregator
|
||||
from framework.runtime.stream_runtime import StreamRuntime
|
||||
import pytest
|
||||
|
||||
from framework.graph import Goal
|
||||
from framework.graph.edge import AsyncEntryPointSpec, EdgeCondition, EdgeSpec, GraphSpec
|
||||
from framework.graph.goal import Constraint, SuccessCriterion
|
||||
from framework.graph.node import NodeSpec
|
||||
from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime
|
||||
from framework.runtime.event_bus import AgentEvent, EventBus, EventType
|
||||
from framework.runtime.execution_stream import EntryPointSpec
|
||||
from framework.runtime.outcome_aggregator import OutcomeAggregator
|
||||
from framework.runtime.shared_state import IsolationLevel, SharedStateManager
|
||||
|
||||
# === Test Fixtures ===
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_goal():
|
||||
"""Create a sample goal for testing."""
|
||||
@@ -141,6 +141,7 @@ def temp_storage():
|
||||
|
||||
# === SharedStateManager Tests ===
|
||||
|
||||
|
||||
class TestSharedStateManager:
|
||||
"""Tests for SharedStateManager."""
|
||||
|
||||
@@ -175,8 +176,8 @@ class TestSharedStateManager:
|
||||
"""Test shared state is visible across executions."""
|
||||
manager = SharedStateManager()
|
||||
|
||||
mem1 = manager.create_memory("exec-1", "stream-1", IsolationLevel.SHARED)
|
||||
mem2 = manager.create_memory("exec-2", "stream-1", IsolationLevel.SHARED)
|
||||
manager.create_memory("exec-1", "stream-1", IsolationLevel.SHARED)
|
||||
manager.create_memory("exec-2", "stream-1", IsolationLevel.SHARED)
|
||||
|
||||
# Write to global scope
|
||||
await manager.write(
|
||||
@@ -209,6 +210,7 @@ class TestSharedStateManager:
|
||||
|
||||
# === EventBus Tests ===
|
||||
|
||||
|
||||
class TestEventBus:
|
||||
"""Tests for EventBus pub/sub."""
|
||||
|
||||
@@ -226,12 +228,14 @@ class TestEventBus:
|
||||
handler=handler,
|
||||
)
|
||||
|
||||
await bus.publish(AgentEvent(
|
||||
type=EventType.EXECUTION_STARTED,
|
||||
stream_id="webhook",
|
||||
execution_id="exec-1",
|
||||
data={"test": "data"},
|
||||
))
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.EXECUTION_STARTED,
|
||||
stream_id="webhook",
|
||||
execution_id="exec-1",
|
||||
data={"test": "data"},
|
||||
)
|
||||
)
|
||||
|
||||
# Allow handler to run
|
||||
await asyncio.sleep(0.1)
|
||||
@@ -256,16 +260,20 @@ class TestEventBus:
|
||||
)
|
||||
|
||||
# Publish to webhook stream (should be received)
|
||||
await bus.publish(AgentEvent(
|
||||
type=EventType.EXECUTION_STARTED,
|
||||
stream_id="webhook",
|
||||
))
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.EXECUTION_STARTED,
|
||||
stream_id="webhook",
|
||||
)
|
||||
)
|
||||
|
||||
# Publish to api stream (should NOT be received)
|
||||
await bus.publish(AgentEvent(
|
||||
type=EventType.EXECUTION_STARTED,
|
||||
stream_id="api",
|
||||
))
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.EXECUTION_STARTED,
|
||||
stream_id="api",
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@@ -308,11 +316,13 @@ class TestEventBus:
|
||||
|
||||
# Publish the event
|
||||
await asyncio.sleep(0.1)
|
||||
await bus.publish(AgentEvent(
|
||||
type=EventType.EXECUTION_COMPLETED,
|
||||
stream_id="webhook",
|
||||
execution_id="exec-1",
|
||||
))
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.EXECUTION_COMPLETED,
|
||||
stream_id="webhook",
|
||||
execution_id="exec-1",
|
||||
)
|
||||
)
|
||||
|
||||
event = await wait_task
|
||||
|
||||
@@ -322,6 +332,7 @@ class TestEventBus:
|
||||
|
||||
# === OutcomeAggregator Tests ===
|
||||
|
||||
|
||||
class TestOutcomeAggregator:
|
||||
"""Tests for OutcomeAggregator."""
|
||||
|
||||
@@ -376,6 +387,7 @@ class TestOutcomeAggregator:
|
||||
|
||||
# === AgentRuntime Tests ===
|
||||
|
||||
|
||||
class TestAgentRuntime:
|
||||
"""Tests for AgentRuntime orchestration."""
|
||||
|
||||
@@ -491,6 +503,7 @@ class TestAgentRuntime:
|
||||
|
||||
# === GraphSpec Validation Tests ===
|
||||
|
||||
|
||||
class TestGraphSpecValidation:
|
||||
"""Tests for GraphSpec with async_entry_points."""
|
||||
|
||||
@@ -595,6 +608,7 @@ class TestGraphSpecValidation:
|
||||
|
||||
# === Integration Tests ===
|
||||
|
||||
|
||||
class TestCreateAgentRuntime:
|
||||
"""Tests for the create_agent_runtime factory."""
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Schema definitions for runtime data."""
|
||||
|
||||
from framework.schemas.decision import Decision, Option, Outcome, DecisionEvaluation
|
||||
from framework.schemas.run import Run, RunSummary, Problem
|
||||
from framework.schemas.decision import Decision, DecisionEvaluation, Option, Outcome
|
||||
from framework.schemas.run import Problem, Run, RunSummary
|
||||
|
||||
__all__ = [
|
||||
"Decision",
|
||||
|
||||
@@ -10,22 +10,23 @@ This is MORE important than actions because:
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
|
||||
|
||||
class DecisionType(str, Enum):
|
||||
"""Types of decisions an agent can make."""
|
||||
TOOL_SELECTION = "tool_selection" # Which tool to use
|
||||
|
||||
TOOL_SELECTION = "tool_selection" # Which tool to use
|
||||
PARAMETER_CHOICE = "parameter_choice" # What parameters to pass
|
||||
PATH_CHOICE = "path_choice" # Which branch to take
|
||||
OUTPUT_FORMAT = "output_format" # How to format output
|
||||
RETRY_STRATEGY = "retry_strategy" # How to handle failure
|
||||
DELEGATION = "delegation" # Whether to delegate to another node
|
||||
TERMINATION = "termination" # Whether to stop or continue
|
||||
CUSTOM = "custom" # User-defined decision type
|
||||
PATH_CHOICE = "path_choice" # Which branch to take
|
||||
OUTPUT_FORMAT = "output_format" # How to format output
|
||||
RETRY_STRATEGY = "retry_strategy" # How to handle failure
|
||||
DELEGATION = "delegation" # Whether to delegate to another node
|
||||
TERMINATION = "termination" # Whether to stop or continue
|
||||
CUSTOM = "custom" # User-defined decision type
|
||||
|
||||
|
||||
class Option(BaseModel):
|
||||
@@ -35,9 +36,10 @@ class Option(BaseModel):
|
||||
Capturing options is crucial - it shows what the agent considered
|
||||
and enables us to evaluate whether the right choice was made.
|
||||
"""
|
||||
|
||||
id: str
|
||||
description: str # Human-readable: "Call search API"
|
||||
action_type: str # "tool_call", "generate", "delegate"
|
||||
description: str # Human-readable: "Call search API"
|
||||
action_type: str # "tool_call", "generate", "delegate"
|
||||
action_params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# Why might this be good or bad?
|
||||
@@ -57,9 +59,10 @@ class Outcome(BaseModel):
|
||||
This is filled in AFTER the action completes, allowing us to
|
||||
correlate decisions with their results.
|
||||
"""
|
||||
|
||||
success: bool
|
||||
result: Any = None # The actual output
|
||||
error: str | None = None # Error message if failed
|
||||
result: Any = None # The actual output
|
||||
error: str | None = None # Error message if failed
|
||||
|
||||
# Side effects
|
||||
state_changes: dict[str, Any] = Field(default_factory=dict)
|
||||
@@ -67,7 +70,7 @@ class Outcome(BaseModel):
|
||||
latency_ms: int = 0
|
||||
|
||||
# Natural language summary (crucial for Builder)
|
||||
summary: str = "" # "Found 3 contacts matching query"
|
||||
summary: str = "" # "Found 3 contacts matching query"
|
||||
|
||||
timestamp: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
@@ -81,6 +84,7 @@ class DecisionEvaluation(BaseModel):
|
||||
This is computed AFTER the run completes, allowing us to
|
||||
judge decisions in light of their eventual outcomes.
|
||||
"""
|
||||
|
||||
# Did it move toward the goal?
|
||||
goal_aligned: bool = True
|
||||
alignment_score: float = Field(default=1.0, ge=0.0, le=1.0)
|
||||
@@ -109,6 +113,7 @@ class Decision(BaseModel):
|
||||
Every significant choice the agent makes is captured here.
|
||||
This is the core data structure for understanding and improving agents.
|
||||
"""
|
||||
|
||||
id: str
|
||||
timestamp: datetime = Field(default_factory=datetime.now)
|
||||
node_id: str
|
||||
|
||||
@@ -6,8 +6,8 @@ summaries and metrics that Builder needs to understand what happened.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
|
||||
@@ -16,10 +16,11 @@ from framework.schemas.decision import Decision, Outcome
|
||||
|
||||
class RunStatus(str, Enum):
|
||||
"""Status of a run."""
|
||||
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
STUCK = "stuck" # Making no progress
|
||||
STUCK = "stuck" # Making no progress
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@@ -29,6 +30,7 @@ class Problem(BaseModel):
|
||||
|
||||
Problems are surfaced explicitly so Builder can focus on what needs fixing.
|
||||
"""
|
||||
|
||||
id: str
|
||||
severity: str = Field(description="critical, warning, or minor")
|
||||
description: str
|
||||
@@ -42,6 +44,7 @@ class Problem(BaseModel):
|
||||
|
||||
class RunMetrics(BaseModel):
|
||||
"""Quantitative metrics about a run."""
|
||||
|
||||
total_decisions: int = 0
|
||||
successful_decisions: int = 0
|
||||
failed_decisions: int = 0
|
||||
@@ -68,6 +71,7 @@ class Run(BaseModel):
|
||||
|
||||
Contains all decisions, problems, and metrics from a single run.
|
||||
"""
|
||||
|
||||
id: str
|
||||
goal_id: str
|
||||
started_at: datetime = Field(default_factory=datetime.now)
|
||||
@@ -191,6 +195,7 @@ class RunSummary(BaseModel):
|
||||
|
||||
This is what I (Builder) want to see first when analyzing runs.
|
||||
"""
|
||||
|
||||
run_id: str
|
||||
goal_id: str
|
||||
status: RunStatus
|
||||
|
||||
@@ -8,7 +8,7 @@ Uses Pydantic's built-in serialization.
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from framework.schemas.run import Run, RunSummary, RunStatus
|
||||
from framework.schemas.run import Run, RunStatus, RunSummary
|
||||
|
||||
|
||||
class FileStorage:
|
||||
@@ -46,6 +46,40 @@ class FileStorage:
|
||||
for d in dirs:
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _validate_key(self, key: str) -> None:
|
||||
"""
|
||||
Validate key to prevent path traversal attacks.
|
||||
|
||||
Args:
|
||||
key: The key to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If key contains path traversal or dangerous patterns
|
||||
"""
|
||||
if not key or key.strip() == "":
|
||||
raise ValueError("Key cannot be empty")
|
||||
|
||||
# Block path separators
|
||||
if "/" in key or "\\" in key:
|
||||
raise ValueError(f"Invalid key format: path separators not allowed in '{key}'")
|
||||
|
||||
# Block parent directory references
|
||||
if ".." in key or key.startswith("."):
|
||||
raise ValueError(f"Invalid key format: path traversal detected in '{key}'")
|
||||
|
||||
# Block absolute paths
|
||||
if key.startswith("/") or (len(key) > 1 and key[1] == ":"):
|
||||
raise ValueError(f"Invalid key format: absolute paths not allowed in '{key}'")
|
||||
|
||||
# Block null bytes (Unix path injection)
|
||||
if "\x00" in key:
|
||||
raise ValueError("Invalid key format: null bytes not allowed")
|
||||
|
||||
# Block other dangerous special characters
|
||||
dangerous_chars = {"<", ">", "|", "&", "$", "`", "'", '"'}
|
||||
if any(char in key for char in dangerous_chars):
|
||||
raise ValueError(f"Invalid key format: contains dangerous characters in '{key}'")
|
||||
|
||||
# === RUN OPERATIONS ===
|
||||
|
||||
def save_run(self, run: Run) -> None:
|
||||
@@ -140,6 +174,7 @@ class FileStorage:
|
||||
|
||||
def _get_index(self, index_type: str, key: str) -> list[str]:
|
||||
"""Get values from an index."""
|
||||
self._validate_key(key) # Prevent path traversal
|
||||
index_path = self.base_path / "indexes" / index_type / f"{key}.json"
|
||||
if not index_path.exists():
|
||||
return []
|
||||
@@ -148,8 +183,9 @@ class FileStorage:
|
||||
|
||||
def _add_to_index(self, index_type: str, key: str, value: str) -> None:
|
||||
"""Add a value to an index."""
|
||||
self._validate_key(key) # Prevent path traversal
|
||||
index_path = self.base_path / "indexes" / index_type / f"{key}.json"
|
||||
values = self._get_index(index_type, key)
|
||||
values = self._get_index(index_type, key) # Already validated in _get_index
|
||||
if value not in values:
|
||||
values.append(value)
|
||||
with open(index_path, "w") as f:
|
||||
@@ -157,8 +193,9 @@ class FileStorage:
|
||||
|
||||
def _remove_from_index(self, index_type: str, key: str, value: str) -> None:
|
||||
"""Remove a value from an index."""
|
||||
self._validate_key(key) # Prevent path traversal
|
||||
index_path = self.base_path / "indexes" / index_type / f"{key}.json"
|
||||
values = self._get_index(index_type, key)
|
||||
values = self._get_index(index_type, key) # Already validated in _get_index
|
||||
if value in values:
|
||||
values.remove(value)
|
||||
with open(index_path, "w") as f:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user