Compare commits
80 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2cb54595c9 | |||
| 284079d18b | |||
| 1c9b09fb78 | |||
| 9fb14f23d2 | |||
| 4795dc4f68 | |||
| acf0f804c5 | |||
| 4e2951854b | |||
| 80dfb429d7 | |||
| 9c0ba77e22 | |||
| 46b4651073 | |||
| 86dd5246c6 | |||
| a1227c88ee | |||
| 535d7ab568 | |||
| af10494b31 | |||
| 39c1042827 | |||
| 16e7dc11f4 | |||
| 7a27babefd | |||
| d53ae9d51d | |||
| 910cf7727d | |||
| 1698605f15 | |||
| eda124a123 | |||
| 15e9ce8d2f | |||
| c01dd603d7 | |||
| 9d5157d69f | |||
| d78795bdf5 | |||
| ff2b7f473e | |||
| 73c9a91811 | |||
| 27b765d902 | |||
| fddba419be | |||
| f42d6308e8 | |||
| c167002754 | |||
| ea26ee7d0c | |||
| 5280e908b2 | |||
| 1c5dd8c664 | |||
| 3aca153be5 | |||
| 65c8e1653c | |||
| 58e4fa918c | |||
| 3af13d3f90 | |||
| d2eb86e534 | |||
| 03842353e4 | |||
| 48747e20af | |||
| 58af593af6 | |||
| 450575a927 | |||
| eac2bb19b2 | |||
| 756a815bf0 | |||
| 23a7b080eb | |||
| bf39bcdec9 | |||
| 0276632491 | |||
| d14d71f760 | |||
| ef6efc2f55 | |||
| 07e4b593dd | |||
| 497591bf3b | |||
| a2a3e334d6 | |||
| 1ccbfaf800 | |||
| a9afa0555c | |||
| 83b2183cf0 | |||
| f49e7a760e | |||
| 6e0255ebec | |||
| 379d3df46b | |||
| 491a3f24da | |||
| c7d70e0fb1 | |||
| d59f8e99cb | |||
| 0a91b49417 | |||
| ced64541b9 | |||
| 3c30cfe02b | |||
| 0d6267bcf1 | |||
| b47175d1df | |||
| 6f23a30eed | |||
| ff7b5c7e27 | |||
| 19f7ae862e | |||
| 5e9f74744a | |||
| 7787179a5a | |||
| 22bb07f00e | |||
| 660f883197 | |||
| 988de80b66 | |||
| dc6aa226ee | |||
| a7b6b080ab | |||
| 9202cbd4d4 | |||
| 1db8484402 | |||
| cdaec8a837 |
@@ -68,7 +68,6 @@ temp/
|
||||
exports/*
|
||||
|
||||
.claude/settings.local.json
|
||||
.claude/skills/ship-it/
|
||||
|
||||
.venv
|
||||
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
.PHONY: lint format check test install-hooks help frontend-install frontend-dev frontend-build
|
||||
.PHONY: lint format check test test-tools test-live test-all install-hooks help frontend-install frontend-dev frontend-build
|
||||
|
||||
# ── Ensure uv is findable in Git Bash on Windows ──────────────────────────────
|
||||
# uv installs to ~/.local/bin on Windows/Linux/macOS. Git Bash may not include
|
||||
# this in PATH by default, so we prepend it here.
|
||||
export PATH := $(HOME)/.local/bin:$(PATH)
|
||||
|
||||
# ── Targets ───────────────────────────────────────────────────────────────────
|
||||
|
||||
help: ## Show this help
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | \
|
||||
@@ -46,4 +53,4 @@ frontend-dev: ## Start frontend dev server
|
||||
cd core/frontend && npm run dev
|
||||
|
||||
frontend-build: ## Build frontend for production
|
||||
cd core/frontend && npm run build
|
||||
cd core/frontend && npm run build
|
||||
@@ -41,7 +41,8 @@ Generate a swarm of worker agents with a coding agent(queen) that control them.
|
||||
|
||||
Visit [adenhq.com](https://adenhq.com) for complete documentation, examples, and guides.
|
||||
|
||||
[](https://www.youtube.com/watch?v=XDOG9fOaLjU)
|
||||
https://github.com/user-attachments/assets/aad3a035-e7b3-4cac-b13d-4a83c7002c30
|
||||
|
||||
|
||||
## Who Is Hive For?
|
||||
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
perf: reduce subprocess spawning in quickstart scripts (#4427)
|
||||
|
||||
## Problem
|
||||
Windows process creation (CreateProcess) is 10-100x slower than Linux fork/exec.
|
||||
The quickstart scripts were spawning 4+ separate `uv run python -c "import X"`
|
||||
processes to verify imports, adding ~600ms overhead on Windows.
|
||||
|
||||
## Solution
|
||||
Consolidated all import checks into a single batch script that checks multiple
|
||||
modules in one subprocess call, reducing spawn overhead by ~75%.
|
||||
|
||||
## Changes
|
||||
- **New**: `scripts/check_requirements.py` - Batched import checker
|
||||
- **New**: `scripts/test_check_requirements.py` - Test suite
|
||||
- **New**: `scripts/benchmark_quickstart.ps1` - Performance benchmark tool
|
||||
- **Modified**: `quickstart.ps1` - Updated import verification (2 sections)
|
||||
- **Modified**: `quickstart.sh` - Updated import verification
|
||||
|
||||
## Performance Impact
|
||||
**Benchmark results on Windows:**
|
||||
- Before: ~19.8 seconds for import checks
|
||||
- After: ~4.9 seconds for import checks
|
||||
- **Improvement: 14.9 seconds saved (75.2% faster)**
|
||||
|
||||
## Testing
|
||||
- ✅ All functional tests pass (`scripts/test_check_requirements.py`)
|
||||
- ✅ Quickstart scripts work correctly on Windows
|
||||
- ✅ Error handling verified (invalid imports reported correctly)
|
||||
- ✅ Performance benchmark confirms 75%+ improvement
|
||||
|
||||
Fixes #4427
|
||||
@@ -1,740 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
EventLoopNode WebSocket Demo
|
||||
|
||||
Real LLM, real FileConversationStore, real EventBus.
|
||||
Streams EventLoopNode execution to a browser via WebSocket.
|
||||
|
||||
Usage:
|
||||
cd /home/timothy/oss/hive/core
|
||||
python demos/event_loop_wss_demo.py
|
||||
|
||||
Then open http://localhost:8765 in your browser.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import tempfile
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
from bs4 import BeautifulSoup
|
||||
from websockets.http11 import Request, Response
|
||||
|
||||
# Add core, tools, and hive root to path
|
||||
_CORE_DIR = Path(__file__).resolve().parent.parent
|
||||
_HIVE_DIR = _CORE_DIR.parent
|
||||
sys.path.insert(0, str(_CORE_DIR)) # framework.*
|
||||
sys.path.insert(0, str(_HIVE_DIR / "tools" / "src")) # aden_tools.*
|
||||
sys.path.insert(0, str(_HIVE_DIR)) # core.framework.* (for aden_tools imports)
|
||||
|
||||
import os # noqa: E402
|
||||
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS, CredentialStoreAdapter # noqa: E402
|
||||
from core.framework.credentials import CredentialStore # noqa: E402
|
||||
|
||||
from framework.credentials.storage import ( # noqa: E402
|
||||
CompositeStorage,
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
)
|
||||
from framework.graph.event_loop_node import EventLoopNode, LoopConfig # noqa: E402
|
||||
from framework.graph.node import NodeContext, NodeSpec, SharedMemory # noqa: E402
|
||||
from framework.llm.litellm import LiteLLMProvider # noqa: E402
|
||||
from framework.llm.provider import Tool # noqa: E402
|
||||
from framework.runner.tool_registry import ToolRegistry # noqa: E402
|
||||
from framework.runtime.core import Runtime # noqa: E402
|
||||
from framework.runtime.event_bus import EventBus, EventType # noqa: E402
|
||||
from framework.storage.conversation_store import FileConversationStore # noqa: E402
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
|
||||
logger = logging.getLogger("demo")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Persistent state (shared across WebSocket connections)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
STORE_DIR = Path(tempfile.mkdtemp(prefix="hive_demo_"))
|
||||
STORE = FileConversationStore(STORE_DIR / "conversation")
|
||||
RUNTIME = Runtime(STORE_DIR / "runtime")
|
||||
LLM = LiteLLMProvider(model="claude-sonnet-4-5-20250929")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Tool Registry — real tools via ToolRegistry (same pattern as GraphExecutor)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
TOOL_REGISTRY = ToolRegistry()
|
||||
|
||||
# Credential store: Aden sync (OAuth2 tokens) + encrypted files + env var fallback
|
||||
_env_mapping = {name: spec.env_var for name, spec in CREDENTIAL_SPECS.items()}
|
||||
_local_storage = CompositeStorage(
|
||||
primary=EncryptedFileStorage(),
|
||||
fallbacks=[EnvVarStorage(env_mapping=_env_mapping)],
|
||||
)
|
||||
|
||||
if os.environ.get("ADEN_API_KEY"):
|
||||
try:
|
||||
from framework.credentials.aden import ( # noqa: E402
|
||||
AdenCachedStorage,
|
||||
AdenClientConfig,
|
||||
AdenCredentialClient,
|
||||
AdenSyncProvider,
|
||||
)
|
||||
|
||||
_client = AdenCredentialClient(AdenClientConfig(base_url="https://api.adenhq.com"))
|
||||
_provider = AdenSyncProvider(client=_client)
|
||||
_storage = AdenCachedStorage(
|
||||
local_storage=_local_storage,
|
||||
aden_provider=_provider,
|
||||
)
|
||||
_cred_store = CredentialStore(storage=_storage, providers=[_provider], auto_refresh=True)
|
||||
_synced = _provider.sync_all(_cred_store)
|
||||
logger.info("Synced %d credentials from Aden", _synced)
|
||||
except Exception as e:
|
||||
logger.warning("Aden sync unavailable: %s", e)
|
||||
_cred_store = CredentialStore(storage=_local_storage)
|
||||
else:
|
||||
logger.info("ADEN_API_KEY not set, using local credential storage")
|
||||
_cred_store = CredentialStore(storage=_local_storage)
|
||||
|
||||
CREDENTIALS = CredentialStoreAdapter(_cred_store)
|
||||
|
||||
# Debug: log which credentials resolved
|
||||
for _name in ["brave_search", "hubspot", "anthropic"]:
|
||||
_val = CREDENTIALS.get(_name)
|
||||
if _val:
|
||||
logger.debug("credential %s: OK (len=%d)", _name, len(_val))
|
||||
else:
|
||||
logger.debug("credential %s: not found", _name)
|
||||
|
||||
# --- web_search (Brave Search API) ---
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_search",
|
||||
tool=Tool(
|
||||
name="web_search",
|
||||
description=(
|
||||
"Search the web for current information. "
|
||||
"Returns titles, URLs, and snippets from search results."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query (1-500 characters)",
|
||||
},
|
||||
"num_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return (1-20, default 10)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_search(inputs),
|
||||
)
|
||||
|
||||
|
||||
def _exec_web_search(inputs: dict) -> dict:
|
||||
api_key = CREDENTIALS.get("brave_search")
|
||||
if not api_key:
|
||||
return {"error": "brave_search credential not configured"}
|
||||
query = inputs.get("query", "")
|
||||
num_results = min(inputs.get("num_results", 10), 20)
|
||||
resp = httpx.get(
|
||||
"https://api.search.brave.com/res/v1/web/search",
|
||||
params={"q": query, "count": num_results},
|
||||
headers={"X-Subscription-Token": api_key, "Accept": "application/json"},
|
||||
timeout=30.0,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"Brave API HTTP {resp.status_code}"}
|
||||
data = resp.json()
|
||||
results = [
|
||||
{
|
||||
"title": item.get("title", ""),
|
||||
"url": item.get("url", ""),
|
||||
"snippet": item.get("description", ""),
|
||||
}
|
||||
for item in data.get("web", {}).get("results", [])[:num_results]
|
||||
]
|
||||
return {"query": query, "results": results, "total": len(results)}
|
||||
|
||||
|
||||
# --- web_scrape (httpx + BeautifulSoup, no playwright for sync compat) ---
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_scrape",
|
||||
tool=Tool(
|
||||
name="web_scrape",
|
||||
description=(
|
||||
"Scrape and extract text content from a webpage URL. "
|
||||
"Returns the page title and main text content."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL of the webpage to scrape",
|
||||
},
|
||||
"max_length": {
|
||||
"type": "integer",
|
||||
"description": "Maximum text length (default 50000)",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_scrape(inputs),
|
||||
)
|
||||
|
||||
_SCRAPE_HEADERS = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/131.0.0.0 Safari/537.36"
|
||||
),
|
||||
"Accept": "text/html,application/xhtml+xml",
|
||||
}
|
||||
|
||||
|
||||
def _exec_web_scrape(inputs: dict) -> dict:
|
||||
url = inputs.get("url", "")
|
||||
max_length = max(1000, min(inputs.get("max_length", 50000), 500000))
|
||||
if not url.startswith(("http://", "https://")):
|
||||
url = "https://" + url
|
||||
try:
|
||||
resp = httpx.get(url, timeout=30.0, follow_redirects=True, headers=_SCRAPE_HEADERS)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"HTTP {resp.status_code}"}
|
||||
soup = BeautifulSoup(resp.text, "html.parser")
|
||||
for tag in soup(["script", "style", "nav", "footer", "header", "aside", "noscript"]):
|
||||
tag.decompose()
|
||||
title = soup.title.get_text(strip=True) if soup.title else ""
|
||||
main = (
|
||||
soup.find("article")
|
||||
or soup.find("main")
|
||||
or soup.find(attrs={"role": "main"})
|
||||
or soup.find("body")
|
||||
)
|
||||
text = main.get_text(separator=" ", strip=True) if main else ""
|
||||
text = " ".join(text.split())
|
||||
if len(text) > max_length:
|
||||
text = text[:max_length] + "..."
|
||||
return {"url": url, "title": title, "content": text, "length": len(text)}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except Exception as e:
|
||||
return {"error": f"Scrape failed: {e}"}
|
||||
|
||||
|
||||
# --- HubSpot CRM tools (optional, requires HUBSPOT_ACCESS_TOKEN) ---
|
||||
|
||||
_HUBSPOT_API = "https://api.hubapi.com"
|
||||
|
||||
|
||||
def _hubspot_headers() -> dict | None:
|
||||
token = CREDENTIALS.get("hubspot")
|
||||
if token:
|
||||
logger.debug("HubSpot token: %s...%s (len=%d)", token[:8], token[-4:], len(token))
|
||||
else:
|
||||
logger.debug("HubSpot token: not found")
|
||||
if not token:
|
||||
return None
|
||||
return {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
|
||||
def _exec_hubspot_search(inputs: dict) -> dict:
|
||||
headers = _hubspot_headers()
|
||||
if not headers:
|
||||
return {"error": "HUBSPOT_ACCESS_TOKEN not set"}
|
||||
object_type = inputs.get("object_type", "contacts")
|
||||
query = inputs.get("query", "")
|
||||
limit = min(inputs.get("limit", 10), 100)
|
||||
body: dict = {"limit": limit}
|
||||
if query:
|
||||
body["query"] = query
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{_HUBSPOT_API}/crm/v3/objects/{object_type}/search",
|
||||
headers=headers,
|
||||
json=body,
|
||||
timeout=30.0,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"HubSpot API HTTP {resp.status_code}: {resp.text[:200]}"}
|
||||
return resp.json()
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except Exception as e:
|
||||
return {"error": f"HubSpot error: {e}"}
|
||||
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="hubspot_search",
|
||||
tool=Tool(
|
||||
name="hubspot_search",
|
||||
description=(
|
||||
"Search HubSpot CRM objects (contacts, companies, or deals). "
|
||||
"Returns matching records with their properties."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"object_type": {
|
||||
"type": "string",
|
||||
"description": "CRM object type: 'contacts', 'companies', or 'deals'",
|
||||
},
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query (name, email, domain, etc.)",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max results (1-100, default 10)",
|
||||
},
|
||||
},
|
||||
"required": ["object_type"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_hubspot_search(inputs),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"ToolRegistry loaded: %s",
|
||||
", ".join(TOOL_REGISTRY.get_registered_names()),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTML page (embedded)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
HTML_PAGE = ( # noqa: E501
|
||||
"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>EventLoopNode Live Demo</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
body {
|
||||
font-family: 'SF Mono', 'Fira Code', monospace;
|
||||
background: #0d1117; color: #c9d1d9;
|
||||
height: 100vh; display: flex; flex-direction: column;
|
||||
}
|
||||
header {
|
||||
background: #161b22; padding: 12px 20px;
|
||||
border-bottom: 1px solid #30363d;
|
||||
display: flex; align-items: center; gap: 16px;
|
||||
}
|
||||
header h1 { font-size: 16px; color: #58a6ff; font-weight: 600; }
|
||||
.status {
|
||||
font-size: 12px; padding: 3px 10px; border-radius: 12px;
|
||||
background: #21262d; color: #8b949e;
|
||||
}
|
||||
.status.running { background: #1a4b2e; color: #3fb950; }
|
||||
.status.done { background: #1a3a5c; color: #58a6ff; }
|
||||
.status.error { background: #4b1a1a; color: #f85149; }
|
||||
.chat { flex: 1; overflow-y: auto; padding: 16px; }
|
||||
.msg {
|
||||
margin: 8px 0; padding: 10px 14px; border-radius: 8px;
|
||||
line-height: 1.6; white-space: pre-wrap; word-wrap: break-word;
|
||||
}
|
||||
.msg.user { background: #1a3a5c; color: #58a6ff; }
|
||||
.msg.assistant { background: #161b22; color: #c9d1d9; }
|
||||
.msg.event {
|
||||
background: transparent; color: #8b949e; font-size: 11px;
|
||||
padding: 4px 14px; border-left: 3px solid #30363d;
|
||||
}
|
||||
.msg.event.loop { border-left-color: #58a6ff; }
|
||||
.msg.event.tool { border-left-color: #d29922; }
|
||||
.msg.event.stall { border-left-color: #f85149; }
|
||||
.input-bar {
|
||||
padding: 12px 16px; background: #161b22;
|
||||
border-top: 1px solid #30363d; display: flex; gap: 8px;
|
||||
}
|
||||
.input-bar input {
|
||||
flex: 1; background: #0d1117; border: 1px solid #30363d;
|
||||
color: #c9d1d9; padding: 8px 12px; border-radius: 6px;
|
||||
font-family: inherit; font-size: 14px; outline: none;
|
||||
}
|
||||
.input-bar input:focus { border-color: #58a6ff; }
|
||||
.input-bar button {
|
||||
background: #238636; color: #fff; border: none;
|
||||
padding: 8px 20px; border-radius: 6px; cursor: pointer;
|
||||
font-family: inherit; font-weight: 600;
|
||||
}
|
||||
.input-bar button:hover { background: #2ea043; }
|
||||
.input-bar button:disabled {
|
||||
background: #21262d; color: #484f58; cursor: not-allowed;
|
||||
}
|
||||
.input-bar button.clear { background: #da3633; }
|
||||
.input-bar button.clear:hover { background: #f85149; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>EventLoopNode Live</h1>
|
||||
<span id="status" class="status">Idle</span>
|
||||
<span id="iter" class="status" style="display:none">Step 0</span>
|
||||
</header>
|
||||
<div id="chat" class="chat"></div>
|
||||
<div class="input-bar">
|
||||
<input id="input" type="text"
|
||||
placeholder="Ask anything..." autofocus />
|
||||
<button id="go" onclick="run()">Send</button>
|
||||
<button class="clear"
|
||||
onclick="clearConversation()">Clear</button>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let ws = null;
|
||||
let currentAssistantEl = null;
|
||||
let iterCount = 0;
|
||||
const chat = document.getElementById('chat');
|
||||
const status = document.getElementById('status');
|
||||
const iterEl = document.getElementById('iter');
|
||||
const goBtn = document.getElementById('go');
|
||||
const inputEl = document.getElementById('input');
|
||||
|
||||
inputEl.addEventListener('keydown', e => {
|
||||
if (e.key === 'Enter') run();
|
||||
});
|
||||
|
||||
function setStatus(text, cls) {
|
||||
status.textContent = text;
|
||||
status.className = 'status ' + cls;
|
||||
}
|
||||
|
||||
function addMsg(text, cls) {
|
||||
const el = document.createElement('div');
|
||||
el.className = 'msg ' + cls;
|
||||
el.textContent = text;
|
||||
chat.appendChild(el);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
return el;
|
||||
}
|
||||
|
||||
function connect() {
|
||||
ws = new WebSocket('ws://' + location.host + '/ws');
|
||||
ws.onopen = () => {
|
||||
setStatus('Ready', 'done');
|
||||
goBtn.disabled = false;
|
||||
};
|
||||
ws.onmessage = handleEvent;
|
||||
ws.onerror = () => { setStatus('Error', 'error'); };
|
||||
ws.onclose = () => {
|
||||
setStatus('Reconnecting...', '');
|
||||
goBtn.disabled = true;
|
||||
setTimeout(connect, 2000);
|
||||
};
|
||||
}
|
||||
|
||||
function handleEvent(msg) {
|
||||
const evt = JSON.parse(msg.data);
|
||||
|
||||
if (evt.type === 'llm_text_delta') {
|
||||
if (currentAssistantEl) {
|
||||
currentAssistantEl.textContent += evt.content;
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
}
|
||||
else if (evt.type === 'ready') {
|
||||
setStatus('Ready', 'done');
|
||||
if (currentAssistantEl && !currentAssistantEl.textContent)
|
||||
currentAssistantEl.remove();
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
else if (evt.type === 'node_loop_iteration') {
|
||||
iterCount = evt.iteration || (iterCount + 1);
|
||||
iterEl.textContent = 'Step ' + iterCount;
|
||||
iterEl.style.display = '';
|
||||
}
|
||||
else if (evt.type === 'tool_call_started') {
|
||||
var info = evt.tool_name + '('
|
||||
+ JSON.stringify(evt.tool_input).slice(0, 120) + ')';
|
||||
addMsg('TOOL ' + info, 'event tool');
|
||||
}
|
||||
else if (evt.type === 'tool_call_completed') {
|
||||
var preview = (evt.result || '').slice(0, 200);
|
||||
var cls = evt.is_error ? 'stall' : 'tool';
|
||||
addMsg('RESULT ' + evt.tool_name + ': ' + preview,
|
||||
'event ' + cls);
|
||||
currentAssistantEl = addMsg('', 'assistant');
|
||||
}
|
||||
else if (evt.type === 'result') {
|
||||
setStatus('Session ended', evt.success ? 'done' : 'error');
|
||||
if (evt.error) addMsg('ERROR ' + evt.error, 'event stall');
|
||||
if (currentAssistantEl && !currentAssistantEl.textContent)
|
||||
currentAssistantEl.remove();
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
else if (evt.type === 'node_stalled') {
|
||||
addMsg('STALLED ' + evt.reason, 'event stall');
|
||||
}
|
||||
else if (evt.type === 'cleared') {
|
||||
chat.innerHTML = '';
|
||||
iterCount = 0;
|
||||
iterEl.textContent = 'Step 0';
|
||||
iterEl.style.display = 'none';
|
||||
setStatus('Ready', 'done');
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
function run() {
|
||||
const text = inputEl.value.trim();
|
||||
if (!text || !ws || ws.readyState !== 1) return;
|
||||
addMsg(text, 'user');
|
||||
currentAssistantEl = addMsg('', 'assistant');
|
||||
inputEl.value = '';
|
||||
setStatus('Running', 'running');
|
||||
goBtn.disabled = true;
|
||||
ws.send(JSON.stringify({ topic: text }));
|
||||
}
|
||||
|
||||
function clearConversation() {
|
||||
if (ws && ws.readyState === 1) {
|
||||
ws.send(JSON.stringify({ command: 'clear' }));
|
||||
}
|
||||
}
|
||||
|
||||
connect();
|
||||
</script>
|
||||
</body>
|
||||
</html>"""
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# WebSocket handler
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def handle_ws(websocket):
|
||||
"""Persistent WebSocket: long-lived EventLoopNode with client_facing blocking."""
|
||||
global STORE
|
||||
|
||||
# -- Event forwarding (WebSocket ← EventBus) ----------------------------
|
||||
bus = EventBus()
|
||||
|
||||
async def forward_event(event):
|
||||
try:
|
||||
payload = {"type": event.type.value, **event.data}
|
||||
if event.node_id:
|
||||
payload["node_id"] = event.node_id
|
||||
await websocket.send(json.dumps(payload))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[
|
||||
EventType.NODE_LOOP_STARTED,
|
||||
EventType.NODE_LOOP_ITERATION,
|
||||
EventType.NODE_LOOP_COMPLETED,
|
||||
EventType.LLM_TEXT_DELTA,
|
||||
EventType.TOOL_CALL_STARTED,
|
||||
EventType.TOOL_CALL_COMPLETED,
|
||||
EventType.NODE_STALLED,
|
||||
],
|
||||
handler=forward_event,
|
||||
)
|
||||
|
||||
# -- Per-connection state -----------------------------------------------
|
||||
node = None
|
||||
loop_task = None
|
||||
|
||||
tools = list(TOOL_REGISTRY.get_tools().values())
|
||||
tool_executor = TOOL_REGISTRY.get_executor()
|
||||
|
||||
node_spec = NodeSpec(
|
||||
id="assistant",
|
||||
name="Chat Assistant",
|
||||
description="A conversational assistant that remembers context across messages",
|
||||
node_type="event_loop",
|
||||
client_facing=True,
|
||||
system_prompt=(
|
||||
"You are a helpful assistant with access to tools. "
|
||||
"You can search the web, scrape webpages, and query HubSpot CRM. "
|
||||
"Use tools when the user asks for current information or external data. "
|
||||
"You have full conversation history, so you can reference previous messages."
|
||||
),
|
||||
)
|
||||
|
||||
# -- Ready callback: subscribe to CLIENT_INPUT_REQUESTED on the bus ---
|
||||
async def on_input_requested(event):
|
||||
try:
|
||||
await websocket.send(json.dumps({"type": "ready"}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[EventType.CLIENT_INPUT_REQUESTED],
|
||||
handler=on_input_requested,
|
||||
)
|
||||
|
||||
async def start_loop(first_message: str):
|
||||
"""Create an EventLoopNode and run it as a background task."""
|
||||
nonlocal node, loop_task
|
||||
|
||||
memory = SharedMemory()
|
||||
ctx = NodeContext(
|
||||
runtime=RUNTIME,
|
||||
node_id="assistant",
|
||||
node_spec=node_spec,
|
||||
memory=memory,
|
||||
input_data={},
|
||||
llm=LLM,
|
||||
available_tools=tools,
|
||||
)
|
||||
node = EventLoopNode(
|
||||
event_bus=bus,
|
||||
config=LoopConfig(max_iterations=10_000, max_context_tokens=32_000),
|
||||
conversation_store=STORE,
|
||||
tool_executor=tool_executor,
|
||||
)
|
||||
await node.inject_event(first_message)
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
result = await node.execute(ctx)
|
||||
try:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "result",
|
||||
"success": result.success,
|
||||
"output": result.output,
|
||||
"error": result.error,
|
||||
"tokens": result.tokens_used,
|
||||
}
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
logger.info(f"Loop ended: success={result.success}, tokens={result.tokens_used}")
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.info("Loop stopped: WebSocket closed")
|
||||
except Exception as e:
|
||||
logger.exception("Loop error")
|
||||
try:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "result",
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"output": {},
|
||||
}
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
loop_task = asyncio.create_task(_run())
|
||||
|
||||
async def stop_loop():
|
||||
"""Signal the node and wait for the loop task to finish."""
|
||||
nonlocal node, loop_task
|
||||
if loop_task and not loop_task.done():
|
||||
if node:
|
||||
node.signal_shutdown()
|
||||
try:
|
||||
await asyncio.wait_for(loop_task, timeout=5.0)
|
||||
except (TimeoutError, asyncio.CancelledError):
|
||||
loop_task.cancel()
|
||||
node = None
|
||||
loop_task = None
|
||||
|
||||
# -- Message loop (runs for the lifetime of this WebSocket) -------------
|
||||
try:
|
||||
async for raw in websocket:
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Clear command
|
||||
if msg.get("command") == "clear":
|
||||
import shutil
|
||||
|
||||
await stop_loop()
|
||||
await STORE.close()
|
||||
conv_dir = STORE_DIR / "conversation"
|
||||
if conv_dir.exists():
|
||||
shutil.rmtree(conv_dir)
|
||||
STORE = FileConversationStore(conv_dir)
|
||||
await websocket.send(json.dumps({"type": "cleared"}))
|
||||
logger.info("Conversation cleared")
|
||||
continue
|
||||
|
||||
topic = msg.get("topic", "")
|
||||
if not topic:
|
||||
continue
|
||||
|
||||
if node is None:
|
||||
# First message — spin up the loop
|
||||
logger.info(f"Starting persistent loop: {topic}")
|
||||
await start_loop(topic)
|
||||
else:
|
||||
# Subsequent message — inject into the running loop
|
||||
logger.info(f"Injecting message: {topic}")
|
||||
await node.inject_event(topic)
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
finally:
|
||||
await stop_loop()
|
||||
logger.info("WebSocket closed, loop stopped")
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTTP handler for serving the HTML page
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def process_request(connection, request: Request):
|
||||
"""Serve HTML on GET /, upgrade to WebSocket on /ws."""
|
||||
if request.path == "/ws":
|
||||
return None # let websockets handle the upgrade
|
||||
# Serve the HTML page for any other path
|
||||
return Response(
|
||||
HTTPStatus.OK,
|
||||
"OK",
|
||||
websockets.Headers({"Content-Type": "text/html; charset=utf-8"}),
|
||||
HTML_PAGE.encode(),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Main
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def main():
|
||||
port = 8765
|
||||
async with websockets.serve(
|
||||
handle_ws,
|
||||
"0.0.0.0",
|
||||
port,
|
||||
process_request=process_request,
|
||||
):
|
||||
logger.info(f"Demo running at http://localhost:{port}")
|
||||
logger.info("Open in your browser and enter a topic to research.")
|
||||
await asyncio.Future() # run forever
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,930 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Two-Node ContextHandoff Demo
|
||||
|
||||
Demonstrates ContextHandoff between two EventLoopNode instances:
|
||||
Node A (Researcher) → ContextHandoff → Node B (Analyst)
|
||||
|
||||
Real LLM, real FileConversationStore, real EventBus.
|
||||
Streams both nodes to a browser via WebSocket.
|
||||
|
||||
Usage:
|
||||
cd /home/timothy/oss/hive/core
|
||||
python demos/handoff_demo.py
|
||||
|
||||
Then open http://localhost:8766 in your browser.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import tempfile
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
from bs4 import BeautifulSoup
|
||||
from websockets.http11 import Request, Response
|
||||
|
||||
# Add core, tools, and hive root to path
|
||||
_CORE_DIR = Path(__file__).resolve().parent.parent
|
||||
_HIVE_DIR = _CORE_DIR.parent
|
||||
sys.path.insert(0, str(_CORE_DIR)) # framework.*
|
||||
sys.path.insert(0, str(_HIVE_DIR / "tools" / "src")) # aden_tools.*
|
||||
sys.path.insert(0, str(_HIVE_DIR)) # core.framework.* (for aden_tools imports)
|
||||
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS, CredentialStoreAdapter # noqa: E402
|
||||
from core.framework.credentials import CredentialStore # noqa: E402
|
||||
|
||||
from framework.credentials.storage import ( # noqa: E402
|
||||
CompositeStorage,
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
)
|
||||
from framework.graph.context_handoff import ContextHandoff # noqa: E402
|
||||
from framework.graph.conversation import NodeConversation # noqa: E402
|
||||
from framework.graph.event_loop_node import EventLoopNode, LoopConfig # noqa: E402
|
||||
from framework.graph.node import NodeContext, NodeSpec, SharedMemory # noqa: E402
|
||||
from framework.llm.litellm import LiteLLMProvider # noqa: E402
|
||||
from framework.llm.provider import Tool # noqa: E402
|
||||
from framework.runner.tool_registry import ToolRegistry # noqa: E402
|
||||
from framework.runtime.core import Runtime # noqa: E402
|
||||
from framework.runtime.event_bus import EventBus, EventType # noqa: E402
|
||||
from framework.storage.conversation_store import FileConversationStore # noqa: E402
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
|
||||
logger = logging.getLogger("handoff_demo")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Persistent state
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
STORE_DIR = Path(tempfile.mkdtemp(prefix="hive_handoff_"))
|
||||
RUNTIME = Runtime(STORE_DIR / "runtime")
|
||||
LLM = LiteLLMProvider(model="claude-sonnet-4-5-20250929")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Credentials
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# Composite credential store: encrypted files (primary) + env vars (fallback)
|
||||
_env_mapping = {name: spec.env_var for name, spec in CREDENTIAL_SPECS.items()}
|
||||
_composite = CompositeStorage(
|
||||
primary=EncryptedFileStorage(),
|
||||
fallbacks=[EnvVarStorage(env_mapping=_env_mapping)],
|
||||
)
|
||||
CREDENTIALS = CredentialStoreAdapter(CredentialStore(storage=_composite))
|
||||
|
||||
for _name in ["brave_search", "hubspot"]:
|
||||
_val = CREDENTIALS.get(_name)
|
||||
if _val:
|
||||
logger.debug("credential %s: OK (len=%d)", _name, len(_val))
|
||||
else:
|
||||
logger.debug("credential %s: not found", _name)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Tool Registry — web_search + web_scrape for Node A (Researcher)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
TOOL_REGISTRY = ToolRegistry()
|
||||
|
||||
|
||||
def _exec_web_search(inputs: dict) -> dict:
|
||||
api_key = CREDENTIALS.get("brave_search")
|
||||
if not api_key:
|
||||
return {"error": "brave_search credential not configured"}
|
||||
query = inputs.get("query", "")
|
||||
num_results = min(inputs.get("num_results", 10), 20)
|
||||
resp = httpx.get(
|
||||
"https://api.search.brave.com/res/v1/web/search",
|
||||
params={"q": query, "count": num_results},
|
||||
headers={
|
||||
"X-Subscription-Token": api_key,
|
||||
"Accept": "application/json",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"Brave API HTTP {resp.status_code}"}
|
||||
data = resp.json()
|
||||
results = [
|
||||
{
|
||||
"title": item.get("title", ""),
|
||||
"url": item.get("url", ""),
|
||||
"snippet": item.get("description", ""),
|
||||
}
|
||||
for item in data.get("web", {}).get("results", [])[:num_results]
|
||||
]
|
||||
return {"query": query, "results": results, "total": len(results)}
|
||||
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_search",
|
||||
tool=Tool(
|
||||
name="web_search",
|
||||
description=(
|
||||
"Search the web for current information. "
|
||||
"Returns titles, URLs, and snippets from search results."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query (1-500 characters)",
|
||||
},
|
||||
"num_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results (1-20, default 10)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_search(inputs),
|
||||
)
|
||||
|
||||
_SCRAPE_HEADERS = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/131.0.0.0 Safari/537.36"
|
||||
),
|
||||
"Accept": "text/html,application/xhtml+xml",
|
||||
}
|
||||
|
||||
|
||||
def _exec_web_scrape(inputs: dict) -> dict:
|
||||
url = inputs.get("url", "")
|
||||
max_length = max(1000, min(inputs.get("max_length", 50000), 500000))
|
||||
if not url.startswith(("http://", "https://")):
|
||||
url = "https://" + url
|
||||
try:
|
||||
resp = httpx.get(
|
||||
url,
|
||||
timeout=30.0,
|
||||
follow_redirects=True,
|
||||
headers=_SCRAPE_HEADERS,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"HTTP {resp.status_code}"}
|
||||
soup = BeautifulSoup(resp.text, "html.parser")
|
||||
for tag in soup(["script", "style", "nav", "footer", "header", "aside", "noscript"]):
|
||||
tag.decompose()
|
||||
title = soup.title.get_text(strip=True) if soup.title else ""
|
||||
main = (
|
||||
soup.find("article")
|
||||
or soup.find("main")
|
||||
or soup.find(attrs={"role": "main"})
|
||||
or soup.find("body")
|
||||
)
|
||||
text = main.get_text(separator=" ", strip=True) if main else ""
|
||||
text = " ".join(text.split())
|
||||
if len(text) > max_length:
|
||||
text = text[:max_length] + "..."
|
||||
return {
|
||||
"url": url,
|
||||
"title": title,
|
||||
"content": text,
|
||||
"length": len(text),
|
||||
}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except Exception as e:
|
||||
return {"error": f"Scrape failed: {e}"}
|
||||
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_scrape",
|
||||
tool=Tool(
|
||||
name="web_scrape",
|
||||
description=(
|
||||
"Scrape and extract text content from a webpage URL. "
|
||||
"Returns the page title and main text content."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL of the webpage to scrape",
|
||||
},
|
||||
"max_length": {
|
||||
"type": "integer",
|
||||
"description": "Maximum text length (default 50000)",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_scrape(inputs),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"ToolRegistry loaded: %s",
|
||||
", ".join(TOOL_REGISTRY.get_registered_names()),
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Node Specs
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
RESEARCHER_SPEC = NodeSpec(
|
||||
id="researcher",
|
||||
name="Researcher",
|
||||
description="Researches a topic using web search and scraping tools",
|
||||
node_type="event_loop",
|
||||
input_keys=["topic"],
|
||||
output_keys=["research_summary"],
|
||||
system_prompt=(
|
||||
"You are a thorough research assistant. Your job is to research "
|
||||
"the given topic using the web_search and web_scrape tools.\n\n"
|
||||
"1. Search for relevant information on the topic\n"
|
||||
"2. Scrape 1-2 of the most promising URLs for details\n"
|
||||
"3. Synthesize your findings into a comprehensive summary\n"
|
||||
"4. Use set_output with key='research_summary' to save your "
|
||||
"findings\n\n"
|
||||
"Be thorough but efficient. Aim for 2-4 search/scrape calls, "
|
||||
"then summarize and set_output."
|
||||
),
|
||||
)
|
||||
|
||||
ANALYST_SPEC = NodeSpec(
|
||||
id="analyst",
|
||||
name="Analyst",
|
||||
description="Analyzes research findings and provides insights",
|
||||
node_type="event_loop",
|
||||
input_keys=["context"],
|
||||
output_keys=["analysis"],
|
||||
system_prompt=(
|
||||
"You are a strategic analyst. You receive research findings from "
|
||||
"a previous researcher and must:\n\n"
|
||||
"1. Identify key themes and patterns\n"
|
||||
"2. Assess the reliability and significance of the findings\n"
|
||||
"3. Provide actionable insights and recommendations\n"
|
||||
"4. Use set_output with key='analysis' to save your analysis\n\n"
|
||||
"Be concise but insightful. Focus on what matters most."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTML page
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
HTML_PAGE = ( # noqa: E501
|
||||
"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>ContextHandoff Demo</title>
|
||||
<style>
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
body {
|
||||
font-family: 'SF Mono', 'Fira Code', monospace;
|
||||
background: #0d1117;
|
||||
color: #c9d1d9;
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
header {
|
||||
background: #161b22;
|
||||
padding: 12px 20px;
|
||||
border-bottom: 1px solid #30363d;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 16px;
|
||||
}
|
||||
header h1 {
|
||||
font-size: 16px;
|
||||
color: #58a6ff;
|
||||
font-weight: 600;
|
||||
}
|
||||
.badge {
|
||||
font-size: 12px;
|
||||
padding: 3px 10px;
|
||||
border-radius: 12px;
|
||||
background: #21262d;
|
||||
color: #8b949e;
|
||||
}
|
||||
.badge.researcher {
|
||||
background: #1a3a5c;
|
||||
color: #58a6ff;
|
||||
}
|
||||
.badge.analyst {
|
||||
background: #1a4b2e;
|
||||
color: #3fb950;
|
||||
}
|
||||
.badge.handoff {
|
||||
background: #3d1f00;
|
||||
color: #d29922;
|
||||
}
|
||||
.badge.done {
|
||||
background: #21262d;
|
||||
color: #8b949e;
|
||||
}
|
||||
.badge.error {
|
||||
background: #4b1a1a;
|
||||
color: #f85149;
|
||||
}
|
||||
.chat {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: 16px;
|
||||
}
|
||||
.msg {
|
||||
margin: 8px 0;
|
||||
padding: 10px 14px;
|
||||
border-radius: 8px;
|
||||
line-height: 1.6;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
}
|
||||
.msg.user {
|
||||
background: #1a3a5c;
|
||||
color: #58a6ff;
|
||||
}
|
||||
.msg.assistant {
|
||||
background: #161b22;
|
||||
color: #c9d1d9;
|
||||
}
|
||||
.msg.assistant.analyst-msg {
|
||||
border-left: 3px solid #3fb950;
|
||||
}
|
||||
.msg.event {
|
||||
background: transparent;
|
||||
color: #8b949e;
|
||||
font-size: 11px;
|
||||
padding: 4px 14px;
|
||||
border-left: 3px solid #30363d;
|
||||
}
|
||||
.msg.event.loop {
|
||||
border-left-color: #58a6ff;
|
||||
}
|
||||
.msg.event.tool {
|
||||
border-left-color: #d29922;
|
||||
}
|
||||
.msg.event.stall {
|
||||
border-left-color: #f85149;
|
||||
}
|
||||
.handoff-banner {
|
||||
margin: 16px 0;
|
||||
padding: 16px;
|
||||
background: #1c1200;
|
||||
border: 1px solid #d29922;
|
||||
border-radius: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
.handoff-banner h3 {
|
||||
color: #d29922;
|
||||
font-size: 14px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.handoff-banner p, .result-banner p {
|
||||
color: #8b949e;
|
||||
font-size: 12px;
|
||||
line-height: 1.5;
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
white-space: pre-wrap;
|
||||
text-align: left;
|
||||
}
|
||||
.result-banner {
|
||||
margin: 16px 0;
|
||||
padding: 16px;
|
||||
background: #0a2614;
|
||||
border: 1px solid #3fb950;
|
||||
border-radius: 8px;
|
||||
}
|
||||
.result-banner h3 {
|
||||
color: #3fb950;
|
||||
font-size: 14px;
|
||||
margin-bottom: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
.result-banner .label {
|
||||
color: #58a6ff;
|
||||
font-size: 11px;
|
||||
font-weight: 600;
|
||||
margin-top: 10px;
|
||||
margin-bottom: 2px;
|
||||
}
|
||||
.result-banner .tokens {
|
||||
color: #484f58;
|
||||
font-size: 11px;
|
||||
text-align: center;
|
||||
margin-top: 10px;
|
||||
}
|
||||
.input-bar {
|
||||
padding: 12px 16px;
|
||||
background: #161b22;
|
||||
border-top: 1px solid #30363d;
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
}
|
||||
.input-bar input {
|
||||
flex: 1;
|
||||
background: #0d1117;
|
||||
border: 1px solid #30363d;
|
||||
color: #c9d1d9;
|
||||
padding: 8px 12px;
|
||||
border-radius: 6px;
|
||||
font-family: inherit;
|
||||
font-size: 14px;
|
||||
outline: none;
|
||||
}
|
||||
.input-bar input:focus {
|
||||
border-color: #58a6ff;
|
||||
}
|
||||
.input-bar button {
|
||||
background: #238636;
|
||||
color: #fff;
|
||||
border: none;
|
||||
padding: 8px 20px;
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
font-family: inherit;
|
||||
font-weight: 600;
|
||||
}
|
||||
.input-bar button:hover {
|
||||
background: #2ea043;
|
||||
}
|
||||
.input-bar button:disabled {
|
||||
background: #21262d;
|
||||
color: #484f58;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>ContextHandoff Demo</h1>
|
||||
<span id="phase" class="badge">Idle</span>
|
||||
<span id="iter" class="badge" style="display:none">Step 0</span>
|
||||
</header>
|
||||
<div id="chat" class="chat"></div>
|
||||
<div class="input-bar">
|
||||
<input id="input" type="text"
|
||||
placeholder="Enter a research topic..." autofocus />
|
||||
<button id="go" onclick="run()">Research</button>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let ws = null;
|
||||
let currentAssistantEl = null;
|
||||
let iterCount = 0;
|
||||
let currentPhase = 'idle';
|
||||
const chat = document.getElementById('chat');
|
||||
const phase = document.getElementById('phase');
|
||||
const iterEl = document.getElementById('iter');
|
||||
const goBtn = document.getElementById('go');
|
||||
const inputEl = document.getElementById('input');
|
||||
|
||||
inputEl.addEventListener('keydown', e => {
|
||||
if (e.key === 'Enter') run();
|
||||
});
|
||||
|
||||
function setPhase(text, cls) {
|
||||
phase.textContent = text;
|
||||
phase.className = 'badge ' + cls;
|
||||
currentPhase = cls;
|
||||
}
|
||||
|
||||
function addMsg(text, cls) {
|
||||
const el = document.createElement('div');
|
||||
el.className = 'msg ' + cls;
|
||||
el.textContent = text;
|
||||
chat.appendChild(el);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
return el;
|
||||
}
|
||||
|
||||
function addHandoffBanner(summary) {
|
||||
const banner = document.createElement('div');
|
||||
banner.className = 'handoff-banner';
|
||||
const h3 = document.createElement('h3');
|
||||
h3.textContent = 'Context Handoff: Researcher -> Analyst';
|
||||
const p = document.createElement('p');
|
||||
p.textContent = summary || 'Passing research context...';
|
||||
banner.appendChild(h3);
|
||||
banner.appendChild(p);
|
||||
chat.appendChild(banner);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
|
||||
function addResultBanner(researcher, analyst, tokens) {
|
||||
const banner = document.createElement('div');
|
||||
banner.className = 'result-banner';
|
||||
const h3 = document.createElement('h3');
|
||||
h3.textContent = 'Pipeline Complete';
|
||||
banner.appendChild(h3);
|
||||
|
||||
if (researcher && researcher.research_summary) {
|
||||
const lbl = document.createElement('div');
|
||||
lbl.className = 'label';
|
||||
lbl.textContent = 'RESEARCH SUMMARY';
|
||||
banner.appendChild(lbl);
|
||||
const p = document.createElement('p');
|
||||
p.textContent = researcher.research_summary;
|
||||
banner.appendChild(p);
|
||||
}
|
||||
|
||||
if (analyst && analyst.analysis) {
|
||||
const lbl = document.createElement('div');
|
||||
lbl.className = 'label';
|
||||
lbl.textContent = 'ANALYSIS';
|
||||
lbl.style.color = '#3fb950';
|
||||
banner.appendChild(lbl);
|
||||
const p = document.createElement('p');
|
||||
p.textContent = analyst.analysis;
|
||||
banner.appendChild(p);
|
||||
}
|
||||
|
||||
if (tokens) {
|
||||
const t = document.createElement('div');
|
||||
t.className = 'tokens';
|
||||
t.textContent = 'Total tokens: ' + tokens.toLocaleString();
|
||||
banner.appendChild(t);
|
||||
}
|
||||
|
||||
chat.appendChild(banner);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
|
||||
function connect() {
|
||||
ws = new WebSocket('ws://' + location.host + '/ws');
|
||||
ws.onopen = () => {
|
||||
setPhase('Ready', 'done');
|
||||
goBtn.disabled = false;
|
||||
};
|
||||
ws.onmessage = handleEvent;
|
||||
ws.onerror = () => { setPhase('Error', 'error'); };
|
||||
ws.onclose = () => {
|
||||
setPhase('Reconnecting...', '');
|
||||
goBtn.disabled = true;
|
||||
setTimeout(connect, 2000);
|
||||
};
|
||||
}
|
||||
|
||||
function handleEvent(msg) {
|
||||
const evt = JSON.parse(msg.data);
|
||||
|
||||
if (evt.type === 'phase') {
|
||||
if (evt.phase === 'researcher') {
|
||||
setPhase('Researcher', 'researcher');
|
||||
} else if (evt.phase === 'handoff') {
|
||||
setPhase('Handoff', 'handoff');
|
||||
} else if (evt.phase === 'analyst') {
|
||||
setPhase('Analyst', 'analyst');
|
||||
}
|
||||
iterCount = 0;
|
||||
iterEl.style.display = 'none';
|
||||
}
|
||||
else if (evt.type === 'llm_text_delta') {
|
||||
if (currentAssistantEl) {
|
||||
currentAssistantEl.textContent += evt.content;
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
}
|
||||
else if (evt.type === 'node_loop_iteration') {
|
||||
iterCount = evt.iteration || (iterCount + 1);
|
||||
iterEl.textContent = 'Step ' + iterCount;
|
||||
iterEl.style.display = '';
|
||||
}
|
||||
else if (evt.type === 'tool_call_started') {
|
||||
var info = evt.tool_name + '('
|
||||
+ JSON.stringify(evt.tool_input).slice(0, 120) + ')';
|
||||
addMsg('TOOL ' + info, 'event tool');
|
||||
}
|
||||
else if (evt.type === 'tool_call_completed') {
|
||||
var preview = (evt.result || '').slice(0, 200);
|
||||
var cls = evt.is_error ? 'stall' : 'tool';
|
||||
addMsg(
|
||||
'RESULT ' + evt.tool_name + ': ' + preview,
|
||||
'event ' + cls
|
||||
);
|
||||
var assistCls = currentPhase === 'analyst'
|
||||
? 'assistant analyst-msg' : 'assistant';
|
||||
currentAssistantEl = addMsg('', assistCls);
|
||||
}
|
||||
else if (evt.type === 'handoff_context') {
|
||||
addHandoffBanner(evt.summary);
|
||||
var assistCls = 'assistant analyst-msg';
|
||||
currentAssistantEl = addMsg('', assistCls);
|
||||
}
|
||||
else if (evt.type === 'node_result') {
|
||||
if (evt.node_id === 'researcher') {
|
||||
if (currentAssistantEl
|
||||
&& !currentAssistantEl.textContent) {
|
||||
currentAssistantEl.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (evt.type === 'done') {
|
||||
setPhase('Done', 'done');
|
||||
iterEl.style.display = 'none';
|
||||
if (currentAssistantEl
|
||||
&& !currentAssistantEl.textContent) {
|
||||
currentAssistantEl.remove();
|
||||
}
|
||||
currentAssistantEl = null;
|
||||
addResultBanner(
|
||||
evt.researcher, evt.analyst, evt.total_tokens
|
||||
);
|
||||
goBtn.disabled = false;
|
||||
inputEl.placeholder = 'Enter another topic...';
|
||||
}
|
||||
else if (evt.type === 'error') {
|
||||
setPhase('Error', 'error');
|
||||
addMsg('ERROR ' + evt.message, 'event stall');
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
else if (evt.type === 'node_stalled') {
|
||||
addMsg('STALLED ' + evt.reason, 'event stall');
|
||||
}
|
||||
}
|
||||
|
||||
function run() {
|
||||
const text = inputEl.value.trim();
|
||||
if (!text || !ws || ws.readyState !== 1) return;
|
||||
chat.innerHTML = '';
|
||||
addMsg(text, 'user');
|
||||
currentAssistantEl = addMsg('', 'assistant');
|
||||
inputEl.value = '';
|
||||
goBtn.disabled = true;
|
||||
ws.send(JSON.stringify({ topic: text }));
|
||||
}
|
||||
|
||||
connect();
|
||||
</script>
|
||||
</body>
|
||||
</html>"""
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# WebSocket handler — sequential Node A → Handoff → Node B
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def handle_ws(websocket):
|
||||
"""Run the two-node handoff pipeline per user message."""
|
||||
try:
|
||||
async for raw in websocket:
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
topic = msg.get("topic", "")
|
||||
if not topic:
|
||||
continue
|
||||
|
||||
logger.info(f"Starting handoff pipeline for: {topic}")
|
||||
|
||||
try:
|
||||
await _run_pipeline(websocket, topic)
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.info("WebSocket closed during pipeline")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("Pipeline error")
|
||||
try:
|
||||
await websocket.send(json.dumps({"type": "error", "message": str(e)}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
|
||||
|
||||
async def _run_pipeline(websocket, topic: str):
|
||||
"""Execute: Node A (research) → ContextHandoff → Node B (analysis)."""
|
||||
import shutil
|
||||
|
||||
# Fresh stores for each run
|
||||
run_dir = Path(tempfile.mkdtemp(prefix="hive_run_", dir=STORE_DIR))
|
||||
store_a = FileConversationStore(run_dir / "node_a")
|
||||
store_b = FileConversationStore(run_dir / "node_b")
|
||||
|
||||
# Shared event bus
|
||||
bus = EventBus()
|
||||
|
||||
async def forward_event(event):
|
||||
try:
|
||||
payload = {"type": event.type.value, **event.data}
|
||||
if event.node_id:
|
||||
payload["node_id"] = event.node_id
|
||||
await websocket.send(json.dumps(payload))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[
|
||||
EventType.NODE_LOOP_STARTED,
|
||||
EventType.NODE_LOOP_ITERATION,
|
||||
EventType.NODE_LOOP_COMPLETED,
|
||||
EventType.LLM_TEXT_DELTA,
|
||||
EventType.TOOL_CALL_STARTED,
|
||||
EventType.TOOL_CALL_COMPLETED,
|
||||
EventType.NODE_STALLED,
|
||||
],
|
||||
handler=forward_event,
|
||||
)
|
||||
|
||||
tools = list(TOOL_REGISTRY.get_tools().values())
|
||||
tool_executor = TOOL_REGISTRY.get_executor()
|
||||
|
||||
# ---- Phase 1: Researcher ------------------------------------------------
|
||||
await websocket.send(json.dumps({"type": "phase", "phase": "researcher"}))
|
||||
|
||||
node_a = EventLoopNode(
|
||||
event_bus=bus,
|
||||
judge=None, # implicit judge: accept when output_keys filled
|
||||
config=LoopConfig(
|
||||
max_iterations=20,
|
||||
max_tool_calls_per_turn=30,
|
||||
max_context_tokens=32_000,
|
||||
),
|
||||
conversation_store=store_a,
|
||||
tool_executor=tool_executor,
|
||||
)
|
||||
|
||||
ctx_a = NodeContext(
|
||||
runtime=RUNTIME,
|
||||
node_id="researcher",
|
||||
node_spec=RESEARCHER_SPEC,
|
||||
memory=SharedMemory(),
|
||||
input_data={"topic": topic},
|
||||
llm=LLM,
|
||||
available_tools=tools,
|
||||
)
|
||||
|
||||
result_a = await node_a.execute(ctx_a)
|
||||
logger.info(
|
||||
"Researcher done: success=%s, tokens=%s",
|
||||
result_a.success,
|
||||
result_a.tokens_used,
|
||||
)
|
||||
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "node_result",
|
||||
"node_id": "researcher",
|
||||
"success": result_a.success,
|
||||
"output": result_a.output,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if not result_a.success:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"message": f"Researcher failed: {result_a.error}",
|
||||
}
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# ---- Phase 2: Context Handoff -------------------------------------------
|
||||
await websocket.send(json.dumps({"type": "phase", "phase": "handoff"}))
|
||||
|
||||
# Restore the researcher's conversation from store
|
||||
conversation_a = await NodeConversation.restore(store_a)
|
||||
if conversation_a is None:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Failed to restore researcher conversation",
|
||||
}
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
handoff_engine = ContextHandoff(llm=LLM)
|
||||
handoff_context = handoff_engine.summarize_conversation(
|
||||
conversation=conversation_a,
|
||||
node_id="researcher",
|
||||
output_keys=["research_summary"],
|
||||
)
|
||||
|
||||
formatted_handoff = ContextHandoff.format_as_input(handoff_context)
|
||||
logger.info(
|
||||
"Handoff: %d turns, ~%d tokens, keys=%s",
|
||||
handoff_context.turn_count,
|
||||
handoff_context.total_tokens_used,
|
||||
list(handoff_context.key_outputs.keys()),
|
||||
)
|
||||
|
||||
# Send handoff context to browser
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "handoff_context",
|
||||
"summary": handoff_context.summary[:500],
|
||||
"turn_count": handoff_context.turn_count,
|
||||
"tokens": handoff_context.total_tokens_used,
|
||||
"key_outputs": handoff_context.key_outputs,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# ---- Phase 3: Analyst ---------------------------------------------------
|
||||
await websocket.send(json.dumps({"type": "phase", "phase": "analyst"}))
|
||||
|
||||
node_b = EventLoopNode(
|
||||
event_bus=bus,
|
||||
judge=None, # implicit judge
|
||||
config=LoopConfig(
|
||||
max_iterations=10,
|
||||
max_tool_calls_per_turn=30,
|
||||
max_context_tokens=32_000,
|
||||
),
|
||||
conversation_store=store_b,
|
||||
)
|
||||
|
||||
ctx_b = NodeContext(
|
||||
runtime=RUNTIME,
|
||||
node_id="analyst",
|
||||
node_spec=ANALYST_SPEC,
|
||||
memory=SharedMemory(),
|
||||
input_data={"context": formatted_handoff},
|
||||
llm=LLM,
|
||||
available_tools=[],
|
||||
)
|
||||
|
||||
result_b = await node_b.execute(ctx_b)
|
||||
logger.info(
|
||||
"Analyst done: success=%s, tokens=%s",
|
||||
result_b.success,
|
||||
result_b.tokens_used,
|
||||
)
|
||||
|
||||
# ---- Done ---------------------------------------------------------------
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "done",
|
||||
"researcher": result_a.output,
|
||||
"analyst": result_b.output,
|
||||
"total_tokens": ((result_a.tokens_used or 0) + (result_b.tokens_used or 0)),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Clean up temp stores
|
||||
try:
|
||||
shutil.rmtree(run_dir)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTTP handler
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def process_request(connection, request: Request):
|
||||
"""Serve HTML on GET /, upgrade to WebSocket on /ws."""
|
||||
if request.path == "/ws":
|
||||
return None
|
||||
return Response(
|
||||
HTTPStatus.OK,
|
||||
"OK",
|
||||
websockets.Headers({"Content-Type": "text/html; charset=utf-8"}),
|
||||
HTML_PAGE.encode(),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Main
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def main():
|
||||
port = 8766
|
||||
async with websockets.serve(
|
||||
handle_ws,
|
||||
"0.0.0.0",
|
||||
port,
|
||||
process_request=process_request,
|
||||
):
|
||||
logger.info(f"Handoff demo at http://localhost:{port}")
|
||||
logger.info("Enter a research topic to start the pipeline.")
|
||||
await asyncio.Future()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,25 +23,56 @@ class AgentEntry:
|
||||
last_active: str | None = None
|
||||
|
||||
|
||||
def _get_last_active(agent_name: str) -> str | None:
|
||||
"""Return the most recent updated_at timestamp across all sessions."""
|
||||
sessions_dir = Path.home() / ".hive" / "agents" / agent_name / "sessions"
|
||||
if not sessions_dir.exists():
|
||||
return None
|
||||
def _get_last_active(agent_path: Path) -> str | None:
|
||||
"""Return the most recent updated_at timestamp across all sessions.
|
||||
|
||||
Checks both worker sessions (``~/.hive/agents/{name}/sessions/``) and
|
||||
queen sessions (``~/.hive/queen/session/``) whose ``meta.json`` references
|
||||
the same *agent_path*.
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
agent_name = agent_path.name
|
||||
latest: str | None = None
|
||||
for session_dir in sessions_dir.iterdir():
|
||||
if not session_dir.is_dir() or not session_dir.name.startswith("session_"):
|
||||
continue
|
||||
state_file = session_dir / "state.json"
|
||||
if not state_file.exists():
|
||||
continue
|
||||
try:
|
||||
data = json.loads(state_file.read_text(encoding="utf-8"))
|
||||
ts = data.get("timestamps", {}).get("updated_at")
|
||||
if ts and (latest is None or ts > latest):
|
||||
latest = ts
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 1. Worker sessions
|
||||
sessions_dir = Path.home() / ".hive" / "agents" / agent_name / "sessions"
|
||||
if sessions_dir.exists():
|
||||
for session_dir in sessions_dir.iterdir():
|
||||
if not session_dir.is_dir() or not session_dir.name.startswith("session_"):
|
||||
continue
|
||||
state_file = session_dir / "state.json"
|
||||
if not state_file.exists():
|
||||
continue
|
||||
try:
|
||||
data = json.loads(state_file.read_text(encoding="utf-8"))
|
||||
ts = data.get("timestamps", {}).get("updated_at")
|
||||
if ts and (latest is None or ts > latest):
|
||||
latest = ts
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 2. Queen sessions
|
||||
queen_sessions_dir = Path.home() / ".hive" / "queen" / "session"
|
||||
if queen_sessions_dir.exists():
|
||||
resolved = agent_path.resolve()
|
||||
for d in queen_sessions_dir.iterdir():
|
||||
if not d.is_dir():
|
||||
continue
|
||||
meta_file = d / "meta.json"
|
||||
if not meta_file.exists():
|
||||
continue
|
||||
try:
|
||||
meta = json.loads(meta_file.read_text(encoding="utf-8"))
|
||||
stored = meta.get("agent_path")
|
||||
if not stored or Path(stored).resolve() != resolved:
|
||||
continue
|
||||
ts = datetime.fromtimestamp(d.stat().st_mtime).isoformat()
|
||||
if latest is None or ts > latest:
|
||||
latest = ts
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return latest
|
||||
|
||||
|
||||
@@ -169,7 +200,7 @@ def discover_agents() -> dict[str, list[AgentEntry]]:
|
||||
node_count=node_count,
|
||||
tool_count=tool_count,
|
||||
tags=tags,
|
||||
last_active=_get_last_active(path.name),
|
||||
last_active=_get_last_active(path),
|
||||
)
|
||||
)
|
||||
if entries:
|
||||
|
||||
@@ -0,0 +1,286 @@
|
||||
"""Worker per-run digest (run diary).
|
||||
|
||||
Storage layout:
|
||||
~/.hive/agents/{agent_name}/runs/{run_id}/digest.md
|
||||
|
||||
Each completed or failed worker run gets one digest file. The queen reads
|
||||
these via get_worker_status(focus='diary') before digging into live runtime
|
||||
logs — the diary is a cheap, persistent record that survives across sessions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from collections import Counter
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.runtime.event_bus import AgentEvent, EventBus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_DIGEST_SYSTEM = """\
|
||||
You maintain run digests for a worker agent.
|
||||
A run digest is a concise, factual record of a single task execution.
|
||||
|
||||
Write 3-6 sentences covering:
|
||||
- What the worker was asked to do (the task/goal)
|
||||
- What approach it took and what tools it used
|
||||
- What the outcome was (success, partial, or failure — and why if relevant)
|
||||
- Any notable issues, retries, or escalations to the queen
|
||||
|
||||
Write in third person past tense. Be direct and specific.
|
||||
Omit routine tool invocations unless the result matters.
|
||||
Output only the digest prose — no headings, no code fences.
|
||||
"""
|
||||
|
||||
|
||||
def _worker_runs_dir(agent_name: str) -> Path:
|
||||
return Path.home() / ".hive" / "agents" / agent_name / "runs"
|
||||
|
||||
|
||||
def digest_path(agent_name: str, run_id: str) -> Path:
|
||||
return _worker_runs_dir(agent_name) / run_id / "digest.md"
|
||||
|
||||
|
||||
def _collect_run_events(bus: EventBus, run_id: str, limit: int = 2000) -> list[AgentEvent]:
|
||||
"""Collect all events belonging to *run_id* from the bus history.
|
||||
|
||||
Strategy: find the EXECUTION_STARTED event that carries ``run_id``,
|
||||
extract its ``execution_id``, then query the bus by that execution_id.
|
||||
This works because TOOL_CALL_*, EDGE_TRAVERSED, NODE_STALLED etc. carry
|
||||
execution_id but not run_id.
|
||||
|
||||
Falls back to a full-scan run_id filter when EXECUTION_STARTED is not
|
||||
found (e.g. bus was rotated).
|
||||
"""
|
||||
from framework.runtime.event_bus import EventType
|
||||
|
||||
# Pass 1: find execution_id via EXECUTION_STARTED with matching run_id
|
||||
started = bus.get_history(event_type=EventType.EXECUTION_STARTED, limit=limit)
|
||||
exec_id: str | None = None
|
||||
for e in started:
|
||||
if getattr(e, "run_id", None) == run_id and e.execution_id:
|
||||
exec_id = e.execution_id
|
||||
break
|
||||
|
||||
if exec_id:
|
||||
return bus.get_history(execution_id=exec_id, limit=limit)
|
||||
|
||||
# Fallback: scan all events and match by run_id attribute
|
||||
return [e for e in bus.get_history(limit=limit) if getattr(e, "run_id", None) == run_id]
|
||||
|
||||
|
||||
def _build_run_context(
|
||||
events: list[AgentEvent],
|
||||
outcome_event: AgentEvent | None,
|
||||
) -> str:
|
||||
"""Assemble a plain-text run context string for the digest LLM call."""
|
||||
from framework.runtime.event_bus import EventType
|
||||
|
||||
# Reverse so events are in chronological order
|
||||
events_chron = list(reversed(events))
|
||||
|
||||
lines: list[str] = []
|
||||
|
||||
# Task input from EXECUTION_STARTED
|
||||
started = [e for e in events_chron if e.type == EventType.EXECUTION_STARTED]
|
||||
if started:
|
||||
inp = started[0].data.get("input", {})
|
||||
if inp:
|
||||
lines.append(f"Task input: {str(inp)[:400]}")
|
||||
|
||||
# Duration (elapsed so far if no outcome yet)
|
||||
ref_ts = outcome_event.timestamp if outcome_event else datetime.utcnow()
|
||||
if started:
|
||||
elapsed = (ref_ts - started[0].timestamp).total_seconds()
|
||||
m, s = divmod(int(elapsed), 60)
|
||||
lines.append(f"Duration so far: {m}m {s}s" if m else f"Duration so far: {s}s")
|
||||
|
||||
# Outcome
|
||||
if outcome_event is None:
|
||||
lines.append("Status: still running (mid-run snapshot)")
|
||||
elif outcome_event.type == EventType.EXECUTION_COMPLETED:
|
||||
out = outcome_event.data.get("output", {})
|
||||
out_str = f"Outcome: completed. Output: {str(out)[:300]}"
|
||||
lines.append(out_str if out else "Outcome: completed.")
|
||||
else:
|
||||
err = outcome_event.data.get("error", "")
|
||||
lines.append(f"Outcome: failed. Error: {str(err)[:300]}" if err else "Outcome: failed.")
|
||||
|
||||
# Node path (edge traversals)
|
||||
edges = [e for e in events_chron if e.type == EventType.EDGE_TRAVERSED]
|
||||
if edges:
|
||||
parts = [
|
||||
f"{e.data.get('source_node', '?')}->{e.data.get('target_node', '?')}"
|
||||
for e in edges[-20:]
|
||||
]
|
||||
lines.append(f"Node path: {', '.join(parts)}")
|
||||
|
||||
# Tools used
|
||||
tool_events = [e for e in events_chron if e.type == EventType.TOOL_CALL_COMPLETED]
|
||||
if tool_events:
|
||||
names = [e.data.get("tool_name", "?") for e in tool_events]
|
||||
counts = Counter(names)
|
||||
summary = ", ".join(f"{name}×{n}" if n > 1 else name for name, n in counts.most_common())
|
||||
lines.append(f"Tools used: {summary}")
|
||||
# Note any tool errors
|
||||
errors = [e for e in tool_events if e.data.get("is_error")]
|
||||
if errors:
|
||||
err_names = Counter(e.data.get("tool_name", "?") for e in errors)
|
||||
lines.append(f"Tool errors: {dict(err_names)}")
|
||||
|
||||
# Issues
|
||||
issue_map = {
|
||||
EventType.NODE_STALLED: "stall",
|
||||
EventType.NODE_TOOL_DOOM_LOOP: "doom loop",
|
||||
EventType.CONSTRAINT_VIOLATION: "constraint violation",
|
||||
EventType.NODE_RETRY: "retry",
|
||||
}
|
||||
issue_parts: list[str] = []
|
||||
for evt_type, label in issue_map.items():
|
||||
n = sum(1 for e in events_chron if e.type == evt_type)
|
||||
if n:
|
||||
issue_parts.append(f"{n} {label}(s)")
|
||||
if issue_parts:
|
||||
lines.append(f"Issues: {', '.join(issue_parts)}")
|
||||
|
||||
# Escalations to queen
|
||||
escalations = [e for e in events_chron if e.type == EventType.ESCALATION_REQUESTED]
|
||||
if escalations:
|
||||
lines.append(f"Escalations to queen: {len(escalations)}")
|
||||
|
||||
# Final LLM output snippet (last LLM_TEXT_DELTA snapshot)
|
||||
text_events = [e for e in reversed(events_chron) if e.type == EventType.LLM_TEXT_DELTA]
|
||||
if text_events:
|
||||
snapshot = text_events[0].data.get("snapshot", "") or ""
|
||||
if snapshot:
|
||||
lines.append(f"Final LLM output: {snapshot[-400:].strip()}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
async def consolidate_worker_run(
|
||||
agent_name: str,
|
||||
run_id: str,
|
||||
outcome_event: AgentEvent | None,
|
||||
bus: EventBus,
|
||||
llm: Any,
|
||||
) -> None:
|
||||
"""Write (or overwrite) the digest for a worker run.
|
||||
|
||||
Called fire-and-forget either:
|
||||
- After EXECUTION_COMPLETED / EXECUTION_FAILED (outcome_event set, final write)
|
||||
- Periodically during a run on a cooldown timer (outcome_event=None, mid-run snapshot)
|
||||
|
||||
The digest file is always overwritten so each call produces the freshest view.
|
||||
The final completion/failure call supersedes any mid-run snapshot.
|
||||
|
||||
Args:
|
||||
agent_name: Worker agent directory name (determines storage path).
|
||||
run_id: The run ID.
|
||||
outcome_event: EXECUTION_COMPLETED or EXECUTION_FAILED event, or None for
|
||||
a mid-run snapshot.
|
||||
bus: The session EventBus (shared queen + worker).
|
||||
llm: LLMProvider with an acomplete() method.
|
||||
"""
|
||||
try:
|
||||
events = _collect_run_events(bus, run_id)
|
||||
run_context = _build_run_context(events, outcome_event)
|
||||
if not run_context:
|
||||
logger.debug("worker_memory: no events for run %s, skipping digest", run_id)
|
||||
return
|
||||
|
||||
is_final = outcome_event is not None
|
||||
logger.info(
|
||||
"worker_memory: generating %s digest for run %s ...",
|
||||
"final" if is_final else "mid-run",
|
||||
run_id,
|
||||
)
|
||||
|
||||
from framework.agents.queen.config import default_config
|
||||
|
||||
resp = await llm.acomplete(
|
||||
messages=[{"role": "user", "content": run_context}],
|
||||
system=_DIGEST_SYSTEM,
|
||||
max_tokens=min(default_config.max_tokens, 512),
|
||||
)
|
||||
digest_text = (resp.content or "").strip()
|
||||
if not digest_text:
|
||||
logger.warning("worker_memory: LLM returned empty digest for run %s", run_id)
|
||||
return
|
||||
|
||||
path = digest_path(agent_name, run_id)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
from framework.runtime.event_bus import EventType
|
||||
|
||||
ts = (outcome_event.timestamp if outcome_event else datetime.utcnow()).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
if outcome_event is None:
|
||||
status = "running"
|
||||
elif outcome_event.type == EventType.EXECUTION_COMPLETED:
|
||||
status = "completed"
|
||||
else:
|
||||
status = "failed"
|
||||
|
||||
path.write_text(
|
||||
f"# {run_id}\n\n**{ts}** | {status}\n\n{digest_text}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
logger.info(
|
||||
"worker_memory: %s digest written for run %s (%d chars)",
|
||||
status,
|
||||
run_id,
|
||||
len(digest_text),
|
||||
)
|
||||
|
||||
except Exception:
|
||||
tb = traceback.format_exc()
|
||||
logger.exception("worker_memory: digest failed for run %s", run_id)
|
||||
# Persist the error so it's findable without log access
|
||||
error_path = _worker_runs_dir(agent_name) / run_id / "digest_error.txt"
|
||||
try:
|
||||
error_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
error_path.write_text(
|
||||
f"run_id: {run_id}\ntime: {datetime.now().isoformat()}\n\n{tb}",
|
||||
encoding="utf-8",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def read_recent_digests(agent_name: str, max_runs: int = 5) -> list[tuple[str, str]]:
|
||||
"""Return recent run digests as [(run_id, content), ...], newest first.
|
||||
|
||||
Args:
|
||||
agent_name: Worker agent directory name.
|
||||
max_runs: Maximum number of digests to return.
|
||||
|
||||
Returns:
|
||||
List of (run_id, digest_content) tuples, ordered newest first.
|
||||
"""
|
||||
runs_dir = _worker_runs_dir(agent_name)
|
||||
if not runs_dir.exists():
|
||||
return []
|
||||
|
||||
digest_files = sorted(
|
||||
runs_dir.glob("*/digest.md"),
|
||||
key=lambda p: p.stat().st_mtime,
|
||||
reverse=True,
|
||||
)[:max_runs]
|
||||
|
||||
result: list[tuple[str, str]] = []
|
||||
for f in digest_files:
|
||||
try:
|
||||
content = f.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
result.append((f.parent.name, content))
|
||||
except OSError:
|
||||
continue
|
||||
return result
|
||||
@@ -19,6 +19,10 @@ from framework.graph.edge import DEFAULT_MAX_TOKENS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
HIVE_CONFIG_FILE = Path.home() / ".hive" / "configuration.json"
|
||||
|
||||
# Hive LLM router endpoint (Anthropic-compatible).
|
||||
# litellm's Anthropic handler appends /v1/messages, so this is just the base host.
|
||||
HIVE_LLM_ENDPOINT = "https://api.adenhq.com"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -47,7 +51,13 @@ def get_preferred_model() -> str:
|
||||
"""Return the user's preferred LLM model string (e.g. 'anthropic/claude-sonnet-4-20250514')."""
|
||||
llm = get_hive_config().get("llm", {})
|
||||
if llm.get("provider") and llm.get("model"):
|
||||
return f"{llm['provider']}/{llm['model']}"
|
||||
provider = str(llm["provider"])
|
||||
model = str(llm["model"]).strip()
|
||||
# OpenRouter quickstart stores raw model IDs; tolerate pasted "openrouter/<id>" too.
|
||||
if provider.lower() == "openrouter" and model.lower().startswith("openrouter/"):
|
||||
model = model[len("openrouter/") :]
|
||||
if model:
|
||||
return f"{provider}/{model}"
|
||||
return "anthropic/claude-sonnet-4-20250514"
|
||||
|
||||
|
||||
@@ -57,6 +67,7 @@ def get_max_tokens() -> int:
|
||||
|
||||
|
||||
DEFAULT_MAX_CONTEXT_TOKENS = 32_000
|
||||
OPENROUTER_API_BASE = "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
def get_max_context_tokens() -> int:
|
||||
@@ -138,7 +149,11 @@ def get_api_base() -> str | None:
|
||||
if llm.get("use_kimi_code_subscription"):
|
||||
# Kimi Code uses an Anthropic-compatible endpoint (no /v1 suffix).
|
||||
return "https://api.kimi.com/coding"
|
||||
return llm.get("api_base")
|
||||
if llm.get("api_base"):
|
||||
return llm["api_base"]
|
||||
if str(llm.get("provider", "")).lower() == "openrouter":
|
||||
return OPENROUTER_API_BASE
|
||||
return None
|
||||
|
||||
|
||||
def get_llm_extra_kwargs() -> dict[str, Any]:
|
||||
|
||||
@@ -51,6 +51,16 @@ def ensure_credential_key_env() -> None:
|
||||
if found and value:
|
||||
os.environ[var_name] = value
|
||||
logger.debug("Loaded %s from shell config", var_name)
|
||||
# Also load the currently configured LLM env var even if it's not in CREDENTIAL_SPECS.
|
||||
# This keeps quickstart-written keys available to fresh processes on Unix shells.
|
||||
from framework.config import get_hive_config
|
||||
|
||||
llm_env_var = str(get_hive_config().get("llm", {}).get("api_key_env_var", "")).strip()
|
||||
if llm_env_var and not os.environ.get(llm_env_var):
|
||||
found, value = check_env_var_in_shell_config(llm_env_var)
|
||||
if found and value:
|
||||
os.environ[llm_env_var] = value
|
||||
logger.debug("Loaded configured LLM env var %s from shell config", llm_env_var)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -612,6 +612,11 @@ class NodeConversation:
|
||||
continue # never prune errors
|
||||
if msg.content.startswith("[Pruned tool result"):
|
||||
continue # already pruned
|
||||
# Tiny results (set_output acks, confirmations) — pruning
|
||||
# saves negligible space but makes the LLM think the call
|
||||
# failed, causing costly retries.
|
||||
if len(msg.content) < 100:
|
||||
continue
|
||||
|
||||
# Phase-aware: protect current phase messages
|
||||
if self._current_phase and msg.phase_id == self._current_phase:
|
||||
@@ -901,8 +906,7 @@ class NodeConversation:
|
||||
full_path = str((spill_path / conv_filename).resolve())
|
||||
ref_parts.append(
|
||||
f"[Previous conversation saved to '{full_path}'. "
|
||||
f"Use load_data('{conv_filename}'), read_file('{full_path}'), "
|
||||
f"or run_command('cat \"{full_path}\"') to review if needed.]"
|
||||
f"Use load_data('{conv_filename}') to review if needed.]"
|
||||
)
|
||||
elif not collapsed_msgs:
|
||||
ref_parts.append("[Previous freeform messages compacted.]")
|
||||
|
||||
@@ -202,6 +202,14 @@ class LoopConfig:
|
||||
max_tool_result_chars: int = 30_000
|
||||
spillover_dir: str | None = None # Path string; created on first use
|
||||
|
||||
# --- set_output value spilling ---
|
||||
# When a set_output value exceeds this character count it is auto-saved
|
||||
# to a file in *spillover_dir* and the stored value is replaced with a
|
||||
# lightweight file reference. This keeps shared memory / adapt.md /
|
||||
# transition markers small and forces the next node to load the full
|
||||
# data from the file. Set to 0 to disable.
|
||||
max_output_value_chars: int = 2_000
|
||||
|
||||
# --- Stream retry (transient error recovery within EventLoopNode) ---
|
||||
# When _run_single_turn() raises a transient error (network, rate limit,
|
||||
# server error), retry up to this many times with exponential backoff
|
||||
@@ -225,6 +233,18 @@ class LoopConfig:
|
||||
cf_grace_turns: int = 1
|
||||
tool_doom_loop_enabled: bool = True
|
||||
|
||||
# --- Per-tool-call timeout ---
|
||||
# Maximum seconds a single tool call may take before being killed.
|
||||
# Prevents hung MCP servers (especially browser/GCU tools) from
|
||||
# blocking the entire event loop indefinitely. 0 = no timeout.
|
||||
tool_call_timeout_seconds: float = 60.0
|
||||
|
||||
# --- Subagent delegation timeout ---
|
||||
# Maximum seconds a delegate_to_sub_agent call may run before being
|
||||
# killed. Subagents run a full event-loop so they naturally take
|
||||
# longer than a single tool call — default is 10 minutes. 0 = no timeout.
|
||||
subagent_timeout_seconds: float = 600.0
|
||||
|
||||
# --- Lifecycle hooks ---
|
||||
# Hooks are async callables keyed by event name. Supported events:
|
||||
# "session_start" — fires once after the first user message is added,
|
||||
@@ -273,13 +293,26 @@ class OutputAccumulator:
|
||||
|
||||
Values are stored in memory and optionally written through to a
|
||||
ConversationStore's cursor data for crash recovery.
|
||||
|
||||
When *spillover_dir* and *max_value_chars* are set, large values are
|
||||
automatically saved to files and replaced with lightweight file
|
||||
references. This guarantees auto-spill fires on **every** ``set()``
|
||||
call regardless of code path (resume, checkpoint restore, etc.).
|
||||
"""
|
||||
|
||||
values: dict[str, Any] = field(default_factory=dict)
|
||||
store: ConversationStore | None = None
|
||||
spillover_dir: str | None = None
|
||||
max_value_chars: int = 0 # 0 = disabled
|
||||
|
||||
async def set(self, key: str, value: Any) -> None:
|
||||
"""Set a key-value pair, persisting immediately if store is available."""
|
||||
"""Set a key-value pair, auto-spilling large values to files.
|
||||
|
||||
When the serialised value exceeds *max_value_chars*, the data is
|
||||
saved to ``<spillover_dir>/output_<key>.<ext>`` and *value* is
|
||||
replaced with a compact file-reference string.
|
||||
"""
|
||||
value = self._auto_spill(key, value)
|
||||
self.values[key] = value
|
||||
if self.store:
|
||||
cursor = await self.store.read_cursor() or {}
|
||||
@@ -288,6 +321,39 @@ class OutputAccumulator:
|
||||
cursor["outputs"] = outputs
|
||||
await self.store.write_cursor(cursor)
|
||||
|
||||
def _auto_spill(self, key: str, value: Any) -> Any:
|
||||
"""Save large values to a file and return a reference string."""
|
||||
if self.max_value_chars <= 0 or not self.spillover_dir:
|
||||
return value
|
||||
|
||||
val_str = json.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value
|
||||
if len(val_str) <= self.max_value_chars:
|
||||
return value
|
||||
|
||||
spill_path = Path(self.spillover_dir)
|
||||
spill_path.mkdir(parents=True, exist_ok=True)
|
||||
ext = ".json" if isinstance(value, (dict, list)) else ".txt"
|
||||
filename = f"output_{key}{ext}"
|
||||
write_content = (
|
||||
json.dumps(value, indent=2, ensure_ascii=False)
|
||||
if isinstance(value, (dict, list))
|
||||
else str(value)
|
||||
)
|
||||
(spill_path / filename).write_text(write_content, encoding="utf-8")
|
||||
file_size = (spill_path / filename).stat().st_size
|
||||
logger.info(
|
||||
"set_output value auto-spilled: key=%s, %d chars → %s (%d bytes)",
|
||||
key,
|
||||
len(val_str),
|
||||
filename,
|
||||
file_size,
|
||||
)
|
||||
return (
|
||||
f"[Saved to '{filename}' ({file_size:,} bytes). "
|
||||
f"Use load_data(filename='{filename}') "
|
||||
f"to access full data.]"
|
||||
)
|
||||
|
||||
def get(self, key: str) -> Any | None:
|
||||
"""Get a value by key, or None if not present."""
|
||||
return self.values.get(key)
|
||||
@@ -447,7 +513,11 @@ class EventLoopNode(NodeProtocol):
|
||||
conversation._output_keys = (
|
||||
ctx.cumulative_output_keys or ctx.node_spec.output_keys or None
|
||||
)
|
||||
accumulator = OutputAccumulator(store=self._conversation_store)
|
||||
accumulator = OutputAccumulator(
|
||||
store=self._conversation_store,
|
||||
spillover_dir=self._config.spillover_dir,
|
||||
max_value_chars=self._config.max_output_value_chars,
|
||||
)
|
||||
start_iteration = 0
|
||||
_restored_recent_responses: list[str] = []
|
||||
_restored_tool_fingerprints: list[list[tuple[str, str]]] = []
|
||||
@@ -473,6 +543,8 @@ class EventLoopNode(NodeProtocol):
|
||||
focus_prompt=ctx.node_spec.system_prompt,
|
||||
narrative=ctx.narrative or None,
|
||||
accounts_prompt=ctx.accounts_prompt or None,
|
||||
skills_catalog_prompt=ctx.skills_catalog_prompt or None,
|
||||
protocols_prompt=ctx.protocols_prompt or None,
|
||||
)
|
||||
if conversation.system_prompt != _current_prompt:
|
||||
conversation.update_system_prompt(_current_prompt)
|
||||
@@ -482,9 +554,21 @@ class EventLoopNode(NodeProtocol):
|
||||
_restored_tool_fingerprints = []
|
||||
|
||||
# Fresh conversation: either isolated mode or first node in continuous mode.
|
||||
from framework.graph.prompt_composer import _with_datetime
|
||||
from framework.graph.prompt_composer import (
|
||||
EXECUTION_SCOPE_PREAMBLE,
|
||||
_with_datetime,
|
||||
)
|
||||
|
||||
system_prompt = _with_datetime(ctx.node_spec.system_prompt or "")
|
||||
# Prepend execution-scope preamble for worker nodes so the
|
||||
# LLM knows it is one step in a pipeline and should not try
|
||||
# to perform work that belongs to other nodes.
|
||||
if (
|
||||
not ctx.is_subagent_mode
|
||||
and ctx.node_spec.node_type in ("event_loop", "gcu")
|
||||
and ctx.node_spec.output_keys
|
||||
):
|
||||
system_prompt = f"{EXECUTION_SCOPE_PREAMBLE}\n\n{system_prompt}"
|
||||
# Prepend GCU browser best-practices prompt for gcu nodes
|
||||
if ctx.node_spec.node_type == "gcu":
|
||||
from framework.graph.gcu import GCU_BROWSER_SYSTEM_PROMPT
|
||||
@@ -494,6 +578,22 @@ class EventLoopNode(NodeProtocol):
|
||||
if ctx.accounts_prompt:
|
||||
system_prompt = f"{system_prompt}\n\n{ctx.accounts_prompt}"
|
||||
|
||||
# Append skill catalog and operational protocols
|
||||
if ctx.skills_catalog_prompt:
|
||||
system_prompt = f"{system_prompt}\n\n{ctx.skills_catalog_prompt}"
|
||||
logger.info(
|
||||
"[%s] Injected skills catalog (%d chars)",
|
||||
node_id,
|
||||
len(ctx.skills_catalog_prompt),
|
||||
)
|
||||
if ctx.protocols_prompt:
|
||||
system_prompt = f"{system_prompt}\n\n{ctx.protocols_prompt}"
|
||||
logger.info(
|
||||
"[%s] Injected operational protocols (%d chars)",
|
||||
node_id,
|
||||
len(ctx.protocols_prompt),
|
||||
)
|
||||
|
||||
# Inject agent working memory (adapt.md).
|
||||
# If it doesn't exist yet, seed it with available context.
|
||||
if self._config.spillover_dir:
|
||||
@@ -535,7 +635,11 @@ class EventLoopNode(NodeProtocol):
|
||||
# Stamp phase for first node in continuous mode
|
||||
if _is_continuous:
|
||||
conversation.set_current_phase(ctx.node_id)
|
||||
accumulator = OutputAccumulator(store=self._conversation_store)
|
||||
accumulator = OutputAccumulator(
|
||||
store=self._conversation_store,
|
||||
spillover_dir=self._config.spillover_dir,
|
||||
max_value_chars=self._config.max_output_value_chars,
|
||||
)
|
||||
start_iteration = 0
|
||||
|
||||
# Add initial user message from input data
|
||||
@@ -575,10 +679,24 @@ class EventLoopNode(NodeProtocol):
|
||||
# - Node has sub_agents defined
|
||||
# - We are NOT in subagent mode (prevents nested delegation)
|
||||
if not ctx.is_subagent_mode:
|
||||
sub_agents = getattr(ctx.node_spec, "sub_agents", [])
|
||||
delegate_tool = self._build_delegate_tool(sub_agents, ctx.node_registry)
|
||||
if delegate_tool:
|
||||
tools.append(delegate_tool)
|
||||
sub_agents = getattr(ctx.node_spec, "sub_agents", None) or []
|
||||
if sub_agents:
|
||||
delegate_tool = self._build_delegate_tool(sub_agents, ctx.node_registry)
|
||||
if delegate_tool:
|
||||
tools.append(delegate_tool)
|
||||
logger.info(
|
||||
"[%s] delegate_to_sub_agent injected (sub_agents=%s)",
|
||||
node_id,
|
||||
sub_agents,
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
"[%s] _build_delegate_tool returned None for sub_agents=%s",
|
||||
node_id,
|
||||
sub_agents,
|
||||
)
|
||||
else:
|
||||
logger.debug("[%s] Skipped delegate tool (is_subagent_mode=True)", node_id)
|
||||
|
||||
# Add report_to_parent tool for sub-agents with a report callback
|
||||
if ctx.is_subagent_mode and ctx.report_callback is not None:
|
||||
@@ -704,6 +822,7 @@ class EventLoopNode(NodeProtocol):
|
||||
)
|
||||
_stream_retry_count = 0
|
||||
_turn_cancelled = False
|
||||
_llm_turn_failed_waiting_input = False
|
||||
while True:
|
||||
try:
|
||||
(
|
||||
@@ -823,6 +942,16 @@ class EventLoopNode(NodeProtocol):
|
||||
# can retry or adjust the request.
|
||||
if ctx.node_spec.client_facing:
|
||||
error_msg = f"LLM call failed: {e}"
|
||||
_guardrail_phrase = (
|
||||
"no endpoints available matching your guardrail restrictions "
|
||||
"and data policy"
|
||||
)
|
||||
if _guardrail_phrase in str(e).lower():
|
||||
error_msg += (
|
||||
" OpenRouter blocked this model under current privacy settings. "
|
||||
"Update https://openrouter.ai/settings/privacy or choose another "
|
||||
"OpenRouter model."
|
||||
)
|
||||
logger.error(
|
||||
"[%s] iter=%d: %s — waiting for user input",
|
||||
node_id,
|
||||
@@ -844,6 +973,7 @@ class EventLoopNode(NodeProtocol):
|
||||
f"[Error: {error_msg}. Please try again.]"
|
||||
)
|
||||
await self._await_user_input(ctx, prompt="")
|
||||
_llm_turn_failed_waiting_input = True
|
||||
break # exit retry loop, continue outer iteration
|
||||
|
||||
# Non-client-facing: crash as before
|
||||
@@ -894,6 +1024,11 @@ class EventLoopNode(NodeProtocol):
|
||||
await self._await_user_input(ctx, prompt="")
|
||||
continue # back to top of for-iteration loop
|
||||
|
||||
# Client-facing non-transient LLM failures wait for user input and then
|
||||
# continue the outer loop without touching per-turn token vars.
|
||||
if _llm_turn_failed_waiting_input:
|
||||
continue
|
||||
|
||||
# 6e'. Feed actual API token count back for accurate estimation
|
||||
turn_input = turn_tokens.get("input", 0)
|
||||
if turn_input > 0:
|
||||
@@ -2144,8 +2279,25 @@ class EventLoopNode(NodeProtocol):
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
key = tc.tool_input.get("key", "")
|
||||
|
||||
# Auto-spill happens inside accumulator.set()
|
||||
# — it fires on every code path (fresh, resume,
|
||||
# restore) and prevents overwrite regression.
|
||||
await accumulator.set(key, value)
|
||||
self._record_learning(key, value)
|
||||
stored = accumulator.get(key)
|
||||
# If the accumulator spilled, update the tool
|
||||
# result so the LLM knows data was saved to a file.
|
||||
if isinstance(stored, str) and stored.startswith("[Saved to '"):
|
||||
result = ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=(
|
||||
f"Output '{key}' auto-saved to file "
|
||||
f"(value was too large for inline). "
|
||||
f"{stored}"
|
||||
),
|
||||
is_error=False,
|
||||
)
|
||||
self._record_learning(key, stored)
|
||||
outputs_set_this_turn.append(key)
|
||||
await self._publish_output_key_set(stream_id, node_id, key, execution_id)
|
||||
logged_tool_calls.append(
|
||||
@@ -2163,7 +2315,6 @@ class EventLoopNode(NodeProtocol):
|
||||
|
||||
elif tc.tool_name == "ask_user":
|
||||
# --- Framework-level ask_user handling ---
|
||||
user_input_requested = True
|
||||
ask_user_prompt = tc.tool_input.get("question", "")
|
||||
raw_options = tc.tool_input.get("options", None)
|
||||
# Defensive: ensure options is a list of strings.
|
||||
@@ -2200,6 +2351,8 @@ class EventLoopNode(NodeProtocol):
|
||||
user_input_requested = False
|
||||
continue
|
||||
|
||||
user_input_requested = True
|
||||
|
||||
# Free-form ask_user (no options): stream the question
|
||||
# text as a chat message so the user can see it. When
|
||||
# options are present the QuestionWidget shows the
|
||||
@@ -2225,7 +2378,6 @@ class EventLoopNode(NodeProtocol):
|
||||
|
||||
elif tc.tool_name == "ask_user_multiple":
|
||||
# --- Framework-level ask_user_multiple ---
|
||||
user_input_requested = True
|
||||
raw_questions = tc.tool_input.get("questions", [])
|
||||
if not isinstance(raw_questions, list) or len(raw_questions) < 2:
|
||||
result = ToolResult(
|
||||
@@ -2263,6 +2415,8 @@ class EventLoopNode(NodeProtocol):
|
||||
}
|
||||
)
|
||||
|
||||
user_input_requested = True
|
||||
|
||||
# Store as multi-question prompt/options for
|
||||
# the event emission path
|
||||
ask_user_prompt = ""
|
||||
@@ -2447,21 +2601,44 @@ class EventLoopNode(NodeProtocol):
|
||||
|
||||
# Phase 2b: execute subagent delegations in parallel.
|
||||
if pending_subagent:
|
||||
_subagent_timeout = self._config.subagent_timeout_seconds
|
||||
|
||||
async def _timed_subagent(
|
||||
_ctx: NodeContext,
|
||||
_tc: ToolCallEvent,
|
||||
_acc: OutputAccumulator = accumulator,
|
||||
_timeout: float = _subagent_timeout,
|
||||
) -> tuple[ToolResult | BaseException, str, float]:
|
||||
_s = time.time()
|
||||
_iso = datetime.now(UTC).isoformat()
|
||||
try:
|
||||
_r = await self._execute_subagent(
|
||||
_coro = self._execute_subagent(
|
||||
_ctx,
|
||||
_tc.tool_input.get("agent_id", ""),
|
||||
_tc.tool_input.get("task", ""),
|
||||
accumulator=_acc,
|
||||
)
|
||||
if _timeout > 0:
|
||||
_r = await asyncio.wait_for(_coro, timeout=_timeout)
|
||||
else:
|
||||
_r = await _coro
|
||||
except TimeoutError:
|
||||
_agent_id = _tc.tool_input.get("agent_id", "unknown")
|
||||
logger.warning(
|
||||
"Subagent '%s' timed out after %.0fs",
|
||||
_agent_id,
|
||||
_timeout,
|
||||
)
|
||||
_r = ToolResult(
|
||||
tool_use_id=_tc.tool_use_id,
|
||||
content=(
|
||||
f"Subagent '{_agent_id}' timed out after "
|
||||
f"{_timeout:.0f}s. The delegation took "
|
||||
"too long and was cancelled. Try a simpler task "
|
||||
"or break it into smaller pieces."
|
||||
),
|
||||
is_error=True,
|
||||
)
|
||||
except BaseException as _exc:
|
||||
_r = _exc
|
||||
_dur = round(time.time() - _s, 3)
|
||||
@@ -2501,6 +2678,11 @@ class EventLoopNode(NodeProtocol):
|
||||
content=raw.content,
|
||||
is_error=raw.is_error,
|
||||
)
|
||||
# Route through _truncate_tool_result so large
|
||||
# subagent results are saved to spillover files
|
||||
# and survive pruning (instead of being "cleared
|
||||
# from context" with no recovery path).
|
||||
result = self._truncate_tool_result(result, "delegate_to_sub_agent")
|
||||
results_by_id[tc.tool_use_id] = result
|
||||
logged_tool_calls.append(
|
||||
{
|
||||
@@ -2545,7 +2727,11 @@ class EventLoopNode(NodeProtocol):
|
||||
content=result.content,
|
||||
is_error=result.is_error,
|
||||
)
|
||||
if tc.tool_name in ("ask_user", "ask_user_multiple"):
|
||||
if (
|
||||
tc.tool_name in ("ask_user", "ask_user_multiple")
|
||||
and user_input_requested
|
||||
and not result.is_error
|
||||
):
|
||||
# Defer tool_call_completed until after user responds
|
||||
self._deferred_tool_complete = {
|
||||
"stream_id": stream_id,
|
||||
@@ -2804,6 +2990,12 @@ class EventLoopNode(NodeProtocol):
|
||||
name="set_output",
|
||||
description=(
|
||||
"Set an output value for this node. Call once per output key. "
|
||||
"Use this for brief notes, counts, status, and file references — "
|
||||
"NOT for large data payloads. When a tool result was saved to a "
|
||||
"data file, pass the filename as the value "
|
||||
"(e.g. 'google_sheets_get_values_1.txt') so the next phase can "
|
||||
"load the full data. Values exceeding ~2000 characters are "
|
||||
"auto-saved to data files. "
|
||||
f"Valid keys: {output_keys}"
|
||||
),
|
||||
parameters={
|
||||
@@ -2816,7 +3008,10 @@ class EventLoopNode(NodeProtocol):
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": "The output value to store.",
|
||||
"description": (
|
||||
"The output value — a brief note, count, status, "
|
||||
"or data filename reference."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["key", "value"],
|
||||
@@ -3340,7 +3535,14 @@ class EventLoopNode(NodeProtocol):
|
||||
return False, ""
|
||||
|
||||
async def _execute_tool(self, tc: ToolCallEvent) -> ToolResult:
|
||||
"""Execute a tool call, handling both sync and async executors."""
|
||||
"""Execute a tool call, handling both sync and async executors.
|
||||
|
||||
Applies ``tool_call_timeout_seconds`` from LoopConfig to prevent
|
||||
hung MCP servers from blocking the event loop indefinitely.
|
||||
The initial executor call is offloaded to a thread pool so that
|
||||
sync executors (MCP STDIO tools that block on ``future.result()``)
|
||||
don't freeze the event loop.
|
||||
"""
|
||||
if self._tool_executor is None:
|
||||
return ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
@@ -3348,9 +3550,35 @@ class EventLoopNode(NodeProtocol):
|
||||
is_error=True,
|
||||
)
|
||||
tool_use = ToolUse(id=tc.tool_use_id, name=tc.tool_name, input=tc.tool_input)
|
||||
result = self._tool_executor(tool_use)
|
||||
if asyncio.iscoroutine(result) or asyncio.isfuture(result):
|
||||
result = await result
|
||||
timeout = self._config.tool_call_timeout_seconds
|
||||
|
||||
async def _run() -> ToolResult:
|
||||
# Offload the executor call to a thread. Sync MCP executors
|
||||
# block on future.result() — running in a thread keeps the
|
||||
# event loop free so asyncio.wait_for can fire the timeout.
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(None, self._tool_executor, tool_use)
|
||||
# Async executors return a coroutine — await it on the loop
|
||||
if asyncio.iscoroutine(result) or asyncio.isfuture(result):
|
||||
result = await result
|
||||
return result
|
||||
|
||||
try:
|
||||
if timeout > 0:
|
||||
result = await asyncio.wait_for(_run(), timeout=timeout)
|
||||
else:
|
||||
result = await _run()
|
||||
except TimeoutError:
|
||||
logger.warning("Tool '%s' timed out after %.0fs", tc.tool_name, timeout)
|
||||
return ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=(
|
||||
f"Tool '{tc.tool_name}' timed out after {timeout:.0f}s. "
|
||||
"The operation took too long and was cancelled. "
|
||||
"Try a simpler request or a different approach."
|
||||
),
|
||||
is_error=True,
|
||||
)
|
||||
return result
|
||||
|
||||
def _record_learning(self, key: str, value: Any) -> None:
|
||||
@@ -3421,6 +3649,125 @@ class EventLoopNode(NodeProtocol):
|
||||
self._spill_counter = max_n
|
||||
logger.info("Restored spill counter to %d from existing files", max_n)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# JSON metadata / smart preview helpers for truncation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_metadata(parsed: Any, *, _depth: int = 0, _max_depth: int = 3) -> str:
|
||||
"""Return a concise structural summary of parsed JSON.
|
||||
|
||||
Reports key names, value types, and — crucially — array lengths so
|
||||
the LLM knows how much data exists beyond the preview.
|
||||
|
||||
Returns an empty string for simple scalars.
|
||||
"""
|
||||
if _depth >= _max_depth:
|
||||
if isinstance(parsed, dict):
|
||||
return f"dict with {len(parsed)} keys"
|
||||
if isinstance(parsed, list):
|
||||
return f"list of {len(parsed)} items"
|
||||
return type(parsed).__name__
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
if not parsed:
|
||||
return "empty dict"
|
||||
lines: list[str] = []
|
||||
indent = " " * (_depth + 1)
|
||||
for key, value in list(parsed.items())[:20]:
|
||||
if isinstance(value, list):
|
||||
line = f'{indent}"{key}": list of {len(value)} items'
|
||||
if value:
|
||||
first = value[0]
|
||||
if isinstance(first, dict):
|
||||
sample_keys = list(first.keys())[:10]
|
||||
line += f" (each item: dict with keys {sample_keys})"
|
||||
elif isinstance(first, list):
|
||||
line += f" (each item: list of {len(first)} elements)"
|
||||
lines.append(line)
|
||||
elif isinstance(value, dict):
|
||||
child = EventLoopNode._extract_json_metadata(
|
||||
value, _depth=_depth + 1, _max_depth=_max_depth
|
||||
)
|
||||
lines.append(f'{indent}"{key}": {child}')
|
||||
else:
|
||||
lines.append(f'{indent}"{key}": {type(value).__name__}')
|
||||
if len(parsed) > 20:
|
||||
lines.append(f"{indent}... and {len(parsed) - 20} more keys")
|
||||
return "\n".join(lines)
|
||||
|
||||
if isinstance(parsed, list):
|
||||
if not parsed:
|
||||
return "empty list"
|
||||
desc = f"list of {len(parsed)} items"
|
||||
first = parsed[0]
|
||||
if isinstance(first, dict):
|
||||
sample_keys = list(first.keys())[:10]
|
||||
desc += f" (each item: dict with keys {sample_keys})"
|
||||
elif isinstance(first, list):
|
||||
desc += f" (each item: list of {len(first)} elements)"
|
||||
return desc
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _build_json_preview(parsed: Any, *, max_chars: int = 5000) -> str | None:
|
||||
"""Build a smart preview of parsed JSON, truncating large arrays.
|
||||
|
||||
Shows first 3 + last 1 items of large arrays with explicit count
|
||||
markers so the LLM cannot mistake the preview for the full dataset.
|
||||
|
||||
Returns ``None`` if no truncation was needed (no large arrays).
|
||||
"""
|
||||
_LARGE_ARRAY_THRESHOLD = 10
|
||||
|
||||
def _truncate_arrays(obj: Any) -> tuple[Any, bool]:
|
||||
"""Return (truncated_copy, was_truncated)."""
|
||||
if isinstance(obj, list) and len(obj) > _LARGE_ARRAY_THRESHOLD:
|
||||
n = len(obj)
|
||||
head = obj[:3]
|
||||
tail = obj[-1:]
|
||||
marker = f"... ({n - 4} more items omitted, {n} total) ..."
|
||||
return head + [marker] + tail, True
|
||||
if isinstance(obj, dict):
|
||||
changed = False
|
||||
out: dict[str, Any] = {}
|
||||
for k, v in obj.items():
|
||||
new_v, did = _truncate_arrays(v)
|
||||
out[k] = new_v
|
||||
changed = changed or did
|
||||
return (out, True) if changed else (obj, False)
|
||||
return obj, False
|
||||
|
||||
preview_obj, was_truncated = _truncate_arrays(parsed)
|
||||
if not was_truncated:
|
||||
return None # No large arrays — caller should use raw slicing
|
||||
|
||||
try:
|
||||
result = json.dumps(preview_obj, indent=2, ensure_ascii=False)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
if len(result) > max_chars:
|
||||
# Even 3+1 items too big — try just 1 item
|
||||
def _minimal_arrays(obj: Any) -> Any:
|
||||
if isinstance(obj, list) and len(obj) > _LARGE_ARRAY_THRESHOLD:
|
||||
n = len(obj)
|
||||
return obj[:1] + [f"... ({n - 1} more items omitted, {n} total) ..."]
|
||||
if isinstance(obj, dict):
|
||||
return {k: _minimal_arrays(v) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
preview_obj = _minimal_arrays(parsed)
|
||||
try:
|
||||
result = json.dumps(preview_obj, indent=2, ensure_ascii=False)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
if len(result) > max_chars:
|
||||
result = result[:max_chars] + "…"
|
||||
|
||||
return result
|
||||
|
||||
def _truncate_tool_result(
|
||||
self,
|
||||
result: ToolResult,
|
||||
@@ -3449,15 +3796,36 @@ class EventLoopNode(NodeProtocol):
|
||||
if tool_name == "load_data":
|
||||
if limit <= 0 or len(result.content) <= limit:
|
||||
return result # Small load_data result — pass through as-is
|
||||
# Large load_data result — truncate with pagination hint
|
||||
preview_chars = max(limit - 300, limit // 2)
|
||||
preview = result.content[:preview_chars]
|
||||
truncated = (
|
||||
f"[{tool_name} result: {len(result.content)} chars — "
|
||||
f"too large for context. Use offset/limit parameters "
|
||||
f"to read smaller chunks.]\n\n"
|
||||
f"Preview:\n{preview}…"
|
||||
# Large load_data result — truncate with smart preview
|
||||
PREVIEW_CAP = min(5000, max(limit - 500, limit // 2))
|
||||
|
||||
metadata_str = ""
|
||||
smart_preview: str | None = None
|
||||
try:
|
||||
parsed_ld = json.loads(result.content)
|
||||
metadata_str = self._extract_json_metadata(parsed_ld)
|
||||
smart_preview = self._build_json_preview(parsed_ld, max_chars=PREVIEW_CAP)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
if smart_preview is not None:
|
||||
preview_block = smart_preview
|
||||
else:
|
||||
preview_block = result.content[:PREVIEW_CAP] + "…"
|
||||
|
||||
header = (
|
||||
f"[{tool_name} result: {len(result.content):,} chars — "
|
||||
f"too large for context. Use offset_bytes/limit_bytes "
|
||||
f"parameters to read smaller chunks.]"
|
||||
)
|
||||
if metadata_str:
|
||||
header += f"\n\nData structure:\n{metadata_str}"
|
||||
header += (
|
||||
"\n\nWARNING: This is an INCOMPLETE preview. "
|
||||
"Do NOT draw conclusions or counts from it."
|
||||
)
|
||||
|
||||
truncated = f"{header}\n\nPreview (small sample only):\n{preview_block}"
|
||||
logger.info(
|
||||
"%s result truncated: %d → %d chars (use offset/limit to paginate)",
|
||||
tool_name,
|
||||
@@ -3479,25 +3847,47 @@ class EventLoopNode(NodeProtocol):
|
||||
# Pretty-print JSON content so load_data's line-based
|
||||
# pagination works correctly.
|
||||
write_content = result.content
|
||||
parsed_json: Any = None # track for metadata extraction
|
||||
try:
|
||||
parsed = json.loads(result.content)
|
||||
write_content = json.dumps(parsed, indent=2, ensure_ascii=False)
|
||||
parsed_json = json.loads(result.content)
|
||||
write_content = json.dumps(parsed_json, indent=2, ensure_ascii=False)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass # Not JSON — write as-is
|
||||
|
||||
(spill_path / filename).write_text(write_content, encoding="utf-8")
|
||||
|
||||
if limit > 0 and len(result.content) > limit:
|
||||
# Large result: preview + file reference
|
||||
preview_chars = max(limit - 300, limit // 2)
|
||||
preview = result.content[:preview_chars]
|
||||
content = (
|
||||
f"[Result from {tool_name}: {len(result.content)} chars — "
|
||||
f"too large for context, saved to '{filename}'. "
|
||||
f"Use load_data(filename='{filename}') "
|
||||
f"to read the full result.]\n\n"
|
||||
f"Preview:\n{preview}…"
|
||||
# Large result: build a small, metadata-rich preview so the
|
||||
# LLM cannot mistake it for the complete dataset.
|
||||
PREVIEW_CAP = 5000
|
||||
|
||||
# Extract structural metadata (array lengths, key names)
|
||||
metadata_str = ""
|
||||
smart_preview: str | None = None
|
||||
if parsed_json is not None:
|
||||
metadata_str = self._extract_json_metadata(parsed_json)
|
||||
smart_preview = self._build_json_preview(parsed_json, max_chars=PREVIEW_CAP)
|
||||
|
||||
if smart_preview is not None:
|
||||
preview_block = smart_preview
|
||||
else:
|
||||
preview_block = result.content[:PREVIEW_CAP] + "…"
|
||||
|
||||
# Assemble header with structural info + warning
|
||||
header = (
|
||||
f"[Result from {tool_name}: {len(result.content):,} chars — "
|
||||
f"too large for context, saved to '{filename}'.]"
|
||||
)
|
||||
if metadata_str:
|
||||
header += f"\n\nData structure:\n{metadata_str}"
|
||||
header += (
|
||||
f"\n\nWARNING: The preview below is INCOMPLETE. "
|
||||
f"Do NOT draw conclusions or counts from it. "
|
||||
f"Use load_data(filename='{filename}') to read the "
|
||||
f"full data before analysis."
|
||||
)
|
||||
|
||||
content = f"{header}\n\nPreview (small sample only):\n{preview_block}"
|
||||
logger.info(
|
||||
"Tool result spilled to file: %s (%d chars → %s)",
|
||||
tool_name,
|
||||
@@ -3522,13 +3912,34 @@ class EventLoopNode(NodeProtocol):
|
||||
|
||||
# No spillover_dir — truncate in-place if needed
|
||||
if limit > 0 and len(result.content) > limit:
|
||||
preview_chars = max(limit - 300, limit // 2)
|
||||
preview = result.content[:preview_chars]
|
||||
truncated = (
|
||||
f"[Result from {tool_name}: {len(result.content)} chars — "
|
||||
f"truncated to fit context budget. Only the first "
|
||||
f"{preview_chars} chars are shown.]\n\n{preview}…"
|
||||
PREVIEW_CAP = min(5000, max(limit - 500, limit // 2))
|
||||
|
||||
metadata_str = ""
|
||||
smart_preview: str | None = None
|
||||
try:
|
||||
parsed_inline = json.loads(result.content)
|
||||
metadata_str = self._extract_json_metadata(parsed_inline)
|
||||
smart_preview = self._build_json_preview(parsed_inline, max_chars=PREVIEW_CAP)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
if smart_preview is not None:
|
||||
preview_block = smart_preview
|
||||
else:
|
||||
preview_block = result.content[:PREVIEW_CAP] + "…"
|
||||
|
||||
header = (
|
||||
f"[Result from {tool_name}: {len(result.content):,} chars — "
|
||||
f"truncated to fit context budget.]"
|
||||
)
|
||||
if metadata_str:
|
||||
header += f"\n\nData structure:\n{metadata_str}"
|
||||
header += (
|
||||
"\n\nWARNING: This is an INCOMPLETE preview. "
|
||||
"Do NOT draw conclusions or counts from the preview alone."
|
||||
)
|
||||
|
||||
truncated = f"{header}\n\n{preview_block}"
|
||||
logger.info(
|
||||
"Tool result truncated in-place: %s (%d → %d chars)",
|
||||
tool_name,
|
||||
@@ -3569,6 +3980,68 @@ class EventLoopNode(NodeProtocol):
|
||||
ratio_before = conversation.usage_ratio()
|
||||
phase_grad = getattr(ctx, "continuous_mode", False)
|
||||
|
||||
# Debug snapshot helper
|
||||
def _snap(name: str, **extra: Any) -> dict[str, Any]:
|
||||
roles: dict[str, int] = {}
|
||||
for m in conversation.messages:
|
||||
roles[m.role] = roles.get(m.role, 0) + 1
|
||||
return {
|
||||
"name": name,
|
||||
"message_count": conversation.message_count,
|
||||
"estimated_tokens": conversation.estimate_tokens(),
|
||||
"usage_ratio": f"{conversation.usage_ratio():.2%}",
|
||||
"max_context_tokens": self._config.max_context_tokens,
|
||||
"messages_by_role": roles,
|
||||
**extra,
|
||||
}
|
||||
|
||||
initial = _snap("initial")
|
||||
|
||||
# When over budget, attach a full message inventory so the log
|
||||
# shows exactly what is consuming the context window.
|
||||
if ratio_before >= 1.0:
|
||||
inventory: list[dict[str, Any]] = []
|
||||
for m in conversation.messages:
|
||||
content_chars = len(m.content)
|
||||
tc_chars = 0
|
||||
tool_name = None
|
||||
if m.tool_calls:
|
||||
for tc in m.tool_calls:
|
||||
args = tc.get("function", {}).get("arguments", "")
|
||||
tc_chars += len(args) if isinstance(args, str) else len(json.dumps(args))
|
||||
names = [tc.get("function", {}).get("name", "?") for tc in m.tool_calls]
|
||||
tool_name = ", ".join(names)
|
||||
elif m.role == "tool" and m.tool_use_id:
|
||||
# Try to find the tool name from the preceding assistant message
|
||||
for prev in conversation.messages:
|
||||
if prev.tool_calls:
|
||||
for tc in prev.tool_calls:
|
||||
if tc.get("id") == m.tool_use_id:
|
||||
tool_name = tc.get("function", {}).get("name", "?")
|
||||
break
|
||||
if tool_name:
|
||||
break
|
||||
entry: dict[str, Any] = {
|
||||
"seq": m.seq,
|
||||
"role": m.role,
|
||||
"content_chars": content_chars,
|
||||
}
|
||||
if tc_chars:
|
||||
entry["tool_call_args_chars"] = tc_chars
|
||||
if tool_name:
|
||||
entry["tool"] = tool_name
|
||||
if m.is_error:
|
||||
entry["is_error"] = True
|
||||
if m.phase_id:
|
||||
entry["phase"] = m.phase_id
|
||||
# Content preview for the biggest messages
|
||||
if content_chars > 2000:
|
||||
entry["preview"] = m.content[:200] + "…"
|
||||
inventory.append(entry)
|
||||
initial["message_inventory"] = inventory
|
||||
|
||||
debug_steps: list[dict[str, Any]] = [initial]
|
||||
|
||||
# --- Step 1: Prune old tool results (free, no LLM) ---
|
||||
protect = max(2000, self._config.max_context_tokens // 12)
|
||||
pruned = await conversation.prune_old_tool_results(
|
||||
@@ -3582,8 +4055,10 @@ class EventLoopNode(NodeProtocol):
|
||||
ratio_before * 100,
|
||||
conversation.usage_ratio() * 100,
|
||||
)
|
||||
debug_steps.append(_snap("after_prune", messages_pruned=pruned))
|
||||
if not conversation.needs_compaction():
|
||||
await self._log_compaction(ctx, conversation, ratio_before)
|
||||
self._write_compaction_debug_log(ctx, debug_steps)
|
||||
return
|
||||
|
||||
# --- Step 2: Standard structure-preserving compaction (free, no LLM) ---
|
||||
@@ -3595,8 +4070,14 @@ class EventLoopNode(NodeProtocol):
|
||||
keep_recent=4,
|
||||
phase_graduated=phase_grad,
|
||||
)
|
||||
debug_steps.append(_snap(
|
||||
"after_structural",
|
||||
spillover_dir=spill_dir,
|
||||
keep_recent=4,
|
||||
))
|
||||
if not conversation.needs_compaction():
|
||||
await self._log_compaction(ctx, conversation, ratio_before)
|
||||
self._write_compaction_debug_log(ctx, debug_steps)
|
||||
return
|
||||
|
||||
# --- Step 3: LLM summary compaction ---
|
||||
@@ -3619,11 +4100,20 @@ class EventLoopNode(NodeProtocol):
|
||||
keep_recent=2,
|
||||
phase_graduated=phase_grad,
|
||||
)
|
||||
debug_steps.append(_snap(
|
||||
"after_llm_compact",
|
||||
summary_chars=len(summary),
|
||||
))
|
||||
except Exception as e:
|
||||
logger.warning("LLM compaction failed: %s", e)
|
||||
debug_steps.append(_snap(
|
||||
"llm_compact_failed",
|
||||
error=str(e),
|
||||
))
|
||||
|
||||
if not conversation.needs_compaction():
|
||||
await self._log_compaction(ctx, conversation, ratio_before)
|
||||
self._write_compaction_debug_log(ctx, debug_steps)
|
||||
return
|
||||
|
||||
# --- Step 4: Emergency deterministic summary (LLM failed/unavailable) ---
|
||||
@@ -3637,7 +4127,12 @@ class EventLoopNode(NodeProtocol):
|
||||
keep_recent=1,
|
||||
phase_graduated=phase_grad,
|
||||
)
|
||||
debug_steps.append(_snap(
|
||||
"after_emergency",
|
||||
summary_chars=len(summary),
|
||||
))
|
||||
await self._log_compaction(ctx, conversation, ratio_before)
|
||||
self._write_compaction_debug_log(ctx, debug_steps)
|
||||
|
||||
# --- LLM compaction with binary-search splitting ----------------------
|
||||
|
||||
@@ -3851,6 +4346,91 @@ class EventLoopNode(NodeProtocol):
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _write_compaction_debug_log(
|
||||
ctx: NodeContext,
|
||||
steps: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Write detailed compaction analysis to ~/.hive/compaction_log/.
|
||||
|
||||
Only runs when HIVE_COMPACTION_DEBUG is set in the environment.
|
||||
Each compaction produces a timestamped markdown file.
|
||||
"""
|
||||
import os
|
||||
|
||||
if not os.environ.get("HIVE_COMPACTION_DEBUG"):
|
||||
return
|
||||
|
||||
log_dir = Path.home() / ".hive" / "compaction_log"
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ts = datetime.now(UTC).strftime("%Y%m%dT%H%M%S_%f")
|
||||
node_label = ctx.node_id.replace("/", "_")
|
||||
log_path = log_dir / f"{ts}_{node_label}.md"
|
||||
|
||||
lines: list[str] = []
|
||||
lines.append(f"# Compaction Debug — {ctx.node_id}")
|
||||
lines.append(f"**Time:** {datetime.now(UTC).isoformat()}")
|
||||
lines.append(f"**Node:** {ctx.node_spec.name} (`{ctx.node_id}`)")
|
||||
if ctx.stream_id:
|
||||
lines.append(f"**Stream:** {ctx.stream_id}")
|
||||
lines.append("")
|
||||
|
||||
for step in steps:
|
||||
name = step.get("name", "unknown")
|
||||
lines.append(f"## Step: {name}")
|
||||
for key, val in step.items():
|
||||
if key == "name":
|
||||
continue
|
||||
if key == "messages_by_role":
|
||||
lines.append(f"- **{key}:**")
|
||||
for role, count in val.items():
|
||||
lines.append(f" - {role}: {count}")
|
||||
elif key == "message_inventory":
|
||||
total_chars = sum(e.get("content_chars", 0) + e.get("tool_call_args_chars", 0) for e in val)
|
||||
lines.append(f"### Message Inventory ({len(val)} messages, {total_chars:,} total chars)")
|
||||
lines.append("")
|
||||
# Sort descending by size for the table
|
||||
ranked = sorted(val, key=lambda e: e.get("content_chars", 0) + e.get("tool_call_args_chars", 0), reverse=True)
|
||||
lines.append("| # | seq | role | tool | chars | % of total | flags |")
|
||||
lines.append("|---|-----|------|------|------:|------------|-------|")
|
||||
for i, entry in enumerate(ranked, 1):
|
||||
chars = entry.get("content_chars", 0) + entry.get("tool_call_args_chars", 0)
|
||||
pct = (chars / total_chars * 100) if total_chars else 0
|
||||
tool = entry.get("tool", "")
|
||||
flags = []
|
||||
if entry.get("is_error"):
|
||||
flags.append("error")
|
||||
if entry.get("phase"):
|
||||
flags.append(f"phase={entry['phase']}")
|
||||
lines.append(
|
||||
f"| {i} | {entry['seq']} | {entry['role']} | {tool} "
|
||||
f"| {chars:,} | {pct:.1f}% | {', '.join(flags)} |"
|
||||
)
|
||||
# Previews for large messages
|
||||
large = [e for e in ranked if e.get("preview")]
|
||||
if large:
|
||||
lines.append("")
|
||||
lines.append("#### Large message previews")
|
||||
for entry in large:
|
||||
lines.append(f"\n**seq={entry['seq']}** ({entry['role']}, {entry.get('tool', '')}):")
|
||||
lines.append(f"```\n{entry['preview']}\n```")
|
||||
elif key == "discarded_messages":
|
||||
lines.append(f"- **{key}:** ({len(val)} messages)")
|
||||
for msg_info in val[:50]: # cap at 50
|
||||
lines.append(f" - seq={msg_info['seq']} role={msg_info['role']} chars={msg_info['chars']}")
|
||||
if len(val) > 50:
|
||||
lines.append(f" - ... and {len(val) - 50} more")
|
||||
else:
|
||||
lines.append(f"- **{key}:** {val}")
|
||||
lines.append("")
|
||||
|
||||
try:
|
||||
log_path.write_text("\n".join(lines), encoding="utf-8")
|
||||
logger.debug("Compaction debug log written to %s", log_path)
|
||||
except OSError:
|
||||
logger.debug("Failed to write compaction debug log to %s", log_path)
|
||||
|
||||
def _build_emergency_summary(
|
||||
self,
|
||||
ctx: NodeContext,
|
||||
@@ -3936,17 +4516,14 @@ class EventLoopNode(NodeProtocol):
|
||||
)
|
||||
parts.append(
|
||||
"CONVERSATION HISTORY (freeform messages saved during compaction — "
|
||||
"use load_data('<filename>'), read_file('<full_path>'), "
|
||||
"or run_command('cat \"<full_path>\"') to review earlier dialogue):\n"
|
||||
+ conv_list
|
||||
"use load_data('<filename>') to review earlier dialogue):\n" + conv_list
|
||||
)
|
||||
if data_files:
|
||||
file_list = "\n".join(
|
||||
f" - {f} (full path: {data_dir / f})" for f in data_files[:30]
|
||||
)
|
||||
parts.append(
|
||||
"DATA FILES (use load_data('<filename>'), read_file('<full_path>'), "
|
||||
"or run_command('cat \"<full_path>\"') to read):\n" + file_list
|
||||
"DATA FILES (use load_data('<filename>') to read):\n" + file_list
|
||||
)
|
||||
if not all_files:
|
||||
parts.append(
|
||||
@@ -4012,6 +4589,8 @@ class EventLoopNode(NodeProtocol):
|
||||
return None
|
||||
|
||||
accumulator = await OutputAccumulator.restore(self._conversation_store)
|
||||
accumulator.spillover_dir = self._config.spillover_dir
|
||||
accumulator.max_value_chars = self._config.max_output_value_chars
|
||||
|
||||
cursor = await self._conversation_store.read_cursor()
|
||||
start_iteration = cursor.get("iteration", 0) + 1 if cursor else 0
|
||||
@@ -4603,11 +5182,19 @@ class EventLoopNode(NodeProtocol):
|
||||
subagent_tool_names = set(subagent_spec.tools or [])
|
||||
tool_source = ctx.all_tools if ctx.all_tools else ctx.available_tools
|
||||
|
||||
subagent_tools = [
|
||||
t
|
||||
for t in tool_source
|
||||
if t.name in subagent_tool_names and t.name != "delegate_to_sub_agent"
|
||||
]
|
||||
# GCU auto-population: GCU nodes declare tools=[] because the runner
|
||||
# auto-populates them at setup time. But that expansion doesn't reach
|
||||
# subagents invoked via delegate_to_sub_agent — the subagent spec still
|
||||
# has the original empty list. When a GCU subagent has no declared
|
||||
# tools, include all catalog tools so browser tools are available.
|
||||
if subagent_spec.node_type == "gcu" and not subagent_tool_names:
|
||||
subagent_tools = [t for t in tool_source if t.name != "delegate_to_sub_agent"]
|
||||
else:
|
||||
subagent_tools = [
|
||||
t
|
||||
for t in tool_source
|
||||
if t.name in subagent_tool_names and t.name != "delegate_to_sub_agent"
|
||||
]
|
||||
|
||||
missing = subagent_tool_names - {t.name for t in subagent_tools}
|
||||
if missing:
|
||||
|
||||
@@ -152,6 +152,8 @@ class GraphExecutor:
|
||||
dynamic_tools_provider: Callable | None = None,
|
||||
dynamic_prompt_provider: Callable | None = None,
|
||||
iteration_metadata_provider: Callable | None = None,
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
):
|
||||
"""
|
||||
Initialize the executor.
|
||||
@@ -177,6 +179,8 @@ class GraphExecutor:
|
||||
tool list (for mode switching)
|
||||
dynamic_prompt_provider: Optional callback returning current
|
||||
system prompt (for phase switching)
|
||||
skills_catalog_prompt: Available skills catalog for system prompt
|
||||
protocols_prompt: Default skill operational protocols for system prompt
|
||||
"""
|
||||
self.runtime = runtime
|
||||
self.llm = llm
|
||||
@@ -198,6 +202,20 @@ class GraphExecutor:
|
||||
self.dynamic_tools_provider = dynamic_tools_provider
|
||||
self.dynamic_prompt_provider = dynamic_prompt_provider
|
||||
self.iteration_metadata_provider = iteration_metadata_provider
|
||||
self.skills_catalog_prompt = skills_catalog_prompt
|
||||
self.protocols_prompt = protocols_prompt
|
||||
|
||||
if protocols_prompt:
|
||||
self.logger.info(
|
||||
"GraphExecutor[%s] received protocols_prompt (%d chars)",
|
||||
stream_id,
|
||||
len(protocols_prompt),
|
||||
)
|
||||
else:
|
||||
self.logger.warning(
|
||||
"GraphExecutor[%s] received EMPTY protocols_prompt",
|
||||
stream_id,
|
||||
)
|
||||
|
||||
# Parallel execution settings
|
||||
self.enable_parallel_execution = enable_parallel_execution
|
||||
@@ -1402,6 +1420,7 @@ class GraphExecutor:
|
||||
next_spec = graph.get_node(current_node_id)
|
||||
if next_spec and next_spec.node_type == "event_loop":
|
||||
from framework.graph.prompt_composer import (
|
||||
EXECUTION_SCOPE_PREAMBLE,
|
||||
build_accounts_prompt,
|
||||
build_narrative,
|
||||
build_transition_marker,
|
||||
@@ -1441,9 +1460,14 @@ class GraphExecutor:
|
||||
)
|
||||
|
||||
# Compose new system prompt (Layer 1 + 2 + 3 + accounts)
|
||||
# Prepend scope preamble to focus so the LLM stays
|
||||
# within this node's responsibility.
|
||||
_focus = next_spec.system_prompt
|
||||
if next_spec.output_keys and _focus:
|
||||
_focus = f"{EXECUTION_SCOPE_PREAMBLE}\n\n{_focus}"
|
||||
new_system = compose_system_prompt(
|
||||
identity_prompt=getattr(graph, "identity_prompt", None),
|
||||
focus_prompt=next_spec.system_prompt,
|
||||
focus_prompt=_focus,
|
||||
narrative=narrative,
|
||||
accounts_prompt=_node_accounts,
|
||||
)
|
||||
@@ -1805,10 +1829,31 @@ class GraphExecutor:
|
||||
if node_spec.tools:
|
||||
available_tools = [t for t in self.tools if t.name in node_spec.tools]
|
||||
|
||||
# Create scoped memory view
|
||||
# Create scoped memory view.
|
||||
# When permissions are restricted (non-empty key lists), auto-include
|
||||
# _-prefixed keys used by default skill protocols so agents can read/write
|
||||
# operational state (e.g. _working_notes, _batch_ledger) regardless of
|
||||
# what the node declares. When key lists are empty (unrestricted), leave
|
||||
# unchanged — empty means "allow all".
|
||||
read_keys = list(node_spec.input_keys)
|
||||
write_keys = list(node_spec.output_keys)
|
||||
# Only extend lists that were already restricted (non-empty).
|
||||
# Empty means "allow all" — adding keys would accidentally
|
||||
# activate the permission check and block legitimate reads/writes.
|
||||
if read_keys or write_keys:
|
||||
from framework.skills.defaults import SHARED_MEMORY_KEYS as _skill_keys
|
||||
|
||||
existing_underscore = [k for k in memory._data if k.startswith("_")]
|
||||
extra_keys = set(_skill_keys) | set(existing_underscore)
|
||||
for k in extra_keys:
|
||||
if read_keys and k not in read_keys:
|
||||
read_keys.append(k)
|
||||
if write_keys and k not in write_keys:
|
||||
write_keys.append(k)
|
||||
|
||||
scoped_memory = memory.with_permissions(
|
||||
read_keys=node_spec.input_keys,
|
||||
write_keys=node_spec.output_keys,
|
||||
read_keys=read_keys,
|
||||
write_keys=write_keys,
|
||||
)
|
||||
|
||||
# Build per-node accounts prompt (filtered to this node's tools)
|
||||
@@ -1852,6 +1897,8 @@ class GraphExecutor:
|
||||
dynamic_tools_provider=self.dynamic_tools_provider,
|
||||
dynamic_prompt_provider=self.dynamic_prompt_provider,
|
||||
iteration_metadata_provider=self.iteration_metadata_provider,
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
)
|
||||
|
||||
VALID_NODE_TYPES = {
|
||||
|
||||
@@ -565,6 +565,10 @@ class NodeContext:
|
||||
# staging / running) without restarting the conversation.
|
||||
dynamic_prompt_provider: Any = None # Callable[[], str] | None
|
||||
|
||||
# Skill system prompts — injected by the skill discovery pipeline
|
||||
skills_catalog_prompt: str = "" # Available skills XML catalog
|
||||
protocols_prompt: str = "" # Default skill operational protocols
|
||||
|
||||
# Per-iteration metadata provider — when set, EventLoopNode merges
|
||||
# the returned dict into node_loop_iteration event data. Used by
|
||||
# the queen to record the current phase per iteration.
|
||||
|
||||
@@ -26,6 +26,16 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Injected into every worker node's system prompt so the LLM understands
|
||||
# it is one step in a multi-node pipeline and should not overreach.
|
||||
EXECUTION_SCOPE_PREAMBLE = (
|
||||
"EXECUTION SCOPE: You are one node in a multi-step workflow graph. "
|
||||
"Focus ONLY on the task described in your instructions below. "
|
||||
"Call set_output() for each of your declared output keys, then stop. "
|
||||
"Do NOT attempt work that belongs to other nodes — the framework "
|
||||
"routes data between nodes automatically."
|
||||
)
|
||||
|
||||
|
||||
def _with_datetime(prompt: str) -> str:
|
||||
"""Append current datetime with local timezone to a system prompt."""
|
||||
@@ -140,14 +150,18 @@ def compose_system_prompt(
|
||||
focus_prompt: str | None,
|
||||
narrative: str | None = None,
|
||||
accounts_prompt: str | None = None,
|
||||
skills_catalog_prompt: str | None = None,
|
||||
protocols_prompt: str | None = None,
|
||||
) -> str:
|
||||
"""Compose the three-layer system prompt.
|
||||
"""Compose the multi-layer system prompt.
|
||||
|
||||
Args:
|
||||
identity_prompt: Layer 1 — static agent identity (from GraphSpec).
|
||||
focus_prompt: Layer 3 — per-node focus directive (from NodeSpec.system_prompt).
|
||||
narrative: Layer 2 — auto-generated from conversation state.
|
||||
accounts_prompt: Connected accounts block (sits between identity and narrative).
|
||||
skills_catalog_prompt: Available skills catalog XML (Agent Skills standard).
|
||||
protocols_prompt: Default skill operational protocols section.
|
||||
|
||||
Returns:
|
||||
Composed system prompt with all layers present, plus current datetime.
|
||||
@@ -162,6 +176,14 @@ def compose_system_prompt(
|
||||
if accounts_prompt:
|
||||
parts.append(f"\n{accounts_prompt}")
|
||||
|
||||
# Skills catalog (discovered skills available for activation)
|
||||
if skills_catalog_prompt:
|
||||
parts.append(f"\n{skills_catalog_prompt}")
|
||||
|
||||
# Operational protocols (default skill behavioral guidance)
|
||||
if protocols_prompt:
|
||||
parts.append(f"\n{protocols_prompt}")
|
||||
|
||||
# Layer 2: Narrative (what's happened so far)
|
||||
if narrative:
|
||||
parts.append(f"\n--- Context (what has happened so far) ---\n{narrative}")
|
||||
@@ -255,7 +277,9 @@ def build_transition_marker(
|
||||
sections.append(f"\nCompleted: {previous_node.name}")
|
||||
sections.append(f" {previous_node.description}")
|
||||
|
||||
# Outputs in memory
|
||||
# Outputs in memory — use file references for large values so the
|
||||
# next node loads full data from disk instead of seeing truncated
|
||||
# inline previews that look deceptively complete.
|
||||
all_memory = memory.read_all()
|
||||
if all_memory:
|
||||
memory_lines: list[str] = []
|
||||
@@ -263,7 +287,29 @@ def build_transition_marker(
|
||||
if value is None:
|
||||
continue
|
||||
val_str = str(value)
|
||||
if len(val_str) > 300:
|
||||
if len(val_str) > 300 and data_dir:
|
||||
# Auto-spill large transition values to data files
|
||||
import json as _json
|
||||
|
||||
data_path = Path(data_dir)
|
||||
data_path.mkdir(parents=True, exist_ok=True)
|
||||
ext = ".json" if isinstance(value, (dict, list)) else ".txt"
|
||||
filename = f"output_{key}{ext}"
|
||||
try:
|
||||
write_content = (
|
||||
_json.dumps(value, indent=2, ensure_ascii=False)
|
||||
if isinstance(value, (dict, list))
|
||||
else str(value)
|
||||
)
|
||||
(data_path / filename).write_text(write_content, encoding="utf-8")
|
||||
file_size = (data_path / filename).stat().st_size
|
||||
val_str = (
|
||||
f"[Saved to '{filename}' ({file_size:,} bytes). "
|
||||
f"Use load_data(filename='{filename}') to access.]"
|
||||
)
|
||||
except Exception:
|
||||
val_str = val_str[:300] + "..."
|
||||
elif len(val_str) > 300:
|
||||
val_str = val_str[:300] + "..."
|
||||
memory_lines.append(f" {key}: {val_str}")
|
||||
if memory_lines:
|
||||
@@ -280,7 +326,7 @@ def build_transition_marker(
|
||||
]
|
||||
if file_lines:
|
||||
sections.append(
|
||||
"\nData files (use read_file to access):\n" + "\n".join(file_lines)
|
||||
"\nData files (use load_data to access):\n" + "\n".join(file_lines)
|
||||
)
|
||||
|
||||
# Agent working memory
|
||||
@@ -294,6 +340,12 @@ def build_transition_marker(
|
||||
# Next phase
|
||||
sections.append(f"\nNow entering: {next_node.name}")
|
||||
sections.append(f" {next_node.description}")
|
||||
if next_node.output_keys:
|
||||
sections.append(
|
||||
f"\nYour ONLY job in this phase: complete the task above and call "
|
||||
f"set_output() for {next_node.output_keys}. Do NOT do work that "
|
||||
f"belongs to later phases."
|
||||
)
|
||||
|
||||
# Reflection prompt (engineered metacognition)
|
||||
sections.append(
|
||||
|
||||
@@ -115,11 +115,23 @@ class SafeEvalVisitor(ast.NodeVisitor):
|
||||
return True
|
||||
|
||||
def visit_BoolOp(self, node: ast.BoolOp) -> Any:
|
||||
values = [self.visit(v) for v in node.values]
|
||||
# Short-circuit evaluation to match Python semantics.
|
||||
# Previously all operands were eagerly evaluated, which broke
|
||||
# guard patterns like: ``x is not None and x.get("key")``
|
||||
if isinstance(node.op, ast.And):
|
||||
return all(values)
|
||||
result = True
|
||||
for v in node.values:
|
||||
result = self.visit(v)
|
||||
if not result:
|
||||
return result
|
||||
return result
|
||||
elif isinstance(node.op, ast.Or):
|
||||
return any(values)
|
||||
result = False
|
||||
for v in node.values:
|
||||
result = self.visit(v)
|
||||
if result:
|
||||
return result
|
||||
return result
|
||||
raise ValueError(f"Boolean operator {type(node.op).__name__} is not allowed")
|
||||
|
||||
def visit_IfExp(self, node: ast.IfExp) -> Any:
|
||||
|
||||
@@ -7,9 +7,11 @@ Groq, and local models.
|
||||
See: https://docs.litellm.ai/docs/providers
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime
|
||||
@@ -23,6 +25,7 @@ except ImportError:
|
||||
litellm = None # type: ignore[assignment]
|
||||
RateLimitError = Exception # type: ignore[assignment, misc]
|
||||
|
||||
from framework.config import HIVE_LLM_ENDPOINT as HIVE_API_BASE
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
from framework.llm.stream_events import StreamEvent
|
||||
|
||||
@@ -134,6 +137,7 @@ RATE_LIMIT_MAX_RETRIES = 10
|
||||
RATE_LIMIT_BACKOFF_BASE = 2 # seconds
|
||||
RATE_LIMIT_MAX_DELAY = 120 # seconds - cap to prevent absurd waits
|
||||
MINIMAX_API_BASE = "https://api.minimax.io/v1"
|
||||
OPENROUTER_API_BASE = "https://openrouter.ai/api/v1"
|
||||
|
||||
# Providers that accept cache_control on message content blocks.
|
||||
# Anthropic: native ephemeral caching. MiniMax & Z-AI/GLM: pass-through to their APIs.
|
||||
@@ -162,6 +166,18 @@ KIMI_API_BASE = "https://api.kimi.com/coding"
|
||||
# Conversation-structure issues are deterministic — long waits don't help.
|
||||
EMPTY_STREAM_MAX_RETRIES = 3
|
||||
EMPTY_STREAM_RETRY_DELAY = 1.0 # seconds
|
||||
OPENROUTER_TOOL_COMPAT_ERROR_SNIPPETS = (
|
||||
"no endpoints found that support tool use",
|
||||
"no endpoints available that support tool use",
|
||||
"provider routing",
|
||||
)
|
||||
OPENROUTER_TOOL_CALL_RE = re.compile(
|
||||
r"<\|tool_call_start\|>\s*(.*?)\s*<\|tool_call_end\|>",
|
||||
re.DOTALL,
|
||||
)
|
||||
OPENROUTER_TOOL_COMPAT_CACHE_TTL_SECONDS = 3600
|
||||
# OpenRouter routing can change over time, so tool-compat caching must expire.
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE: dict[str, float] = {}
|
||||
|
||||
# Directory for dumping failed requests
|
||||
FAILED_REQUESTS_DIR = Path.home() / ".hive" / "failed_requests"
|
||||
@@ -204,6 +220,24 @@ def _prune_failed_request_dumps(max_files: int = MAX_FAILED_REQUEST_DUMPS) -> No
|
||||
pass # Best-effort — never block the caller
|
||||
|
||||
|
||||
def _remember_openrouter_tool_compat_model(model: str) -> None:
|
||||
"""Cache OpenRouter tool-compat fallback for a bounded time window."""
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE[model] = (
|
||||
time.monotonic() + OPENROUTER_TOOL_COMPAT_CACHE_TTL_SECONDS
|
||||
)
|
||||
|
||||
|
||||
def _is_openrouter_tool_compat_cached(model: str) -> bool:
|
||||
"""Return True when the cached OpenRouter compat entry is still fresh."""
|
||||
expires_at = OPENROUTER_TOOL_COMPAT_MODEL_CACHE.get(model)
|
||||
if expires_at is None:
|
||||
return False
|
||||
if expires_at <= time.monotonic():
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE.pop(model, None)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _dump_failed_request(
|
||||
model: str,
|
||||
kwargs: dict[str, Any],
|
||||
@@ -399,6 +433,10 @@ class LiteLLMProvider(LLMProvider):
|
||||
# Strip a trailing /v1 in case the user's saved config has the old value.
|
||||
if api_base and api_base.rstrip("/").endswith("/v1"):
|
||||
api_base = api_base.rstrip("/")[:-3]
|
||||
elif model.lower().startswith("hive/"):
|
||||
model = "anthropic/" + model[len("hive/") :]
|
||||
if api_base and api_base.rstrip("/").endswith("/v1"):
|
||||
api_base = api_base.rstrip("/")[:-3]
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or self._default_api_base_for_model(_original_model)
|
||||
@@ -426,8 +464,12 @@ class LiteLLMProvider(LLMProvider):
|
||||
model_lower = model.lower()
|
||||
if model_lower.startswith("minimax/") or model_lower.startswith("minimax-"):
|
||||
return MINIMAX_API_BASE
|
||||
if model_lower.startswith("openrouter/"):
|
||||
return OPENROUTER_API_BASE
|
||||
if model_lower.startswith("kimi/"):
|
||||
return KIMI_API_BASE
|
||||
if model_lower.startswith("hive/"):
|
||||
return HIVE_API_BASE
|
||||
return None
|
||||
|
||||
def _completion_with_rate_limit_retry(
|
||||
@@ -832,6 +874,494 @@ class LiteLLMProvider(LLMProvider):
|
||||
model = (self.model or "").lower()
|
||||
return model.startswith("minimax/") or model.startswith("minimax-")
|
||||
|
||||
def _is_openrouter_model(self) -> bool:
|
||||
"""Return True when the configured model targets OpenRouter."""
|
||||
model = (self.model or "").lower()
|
||||
if model.startswith("openrouter/"):
|
||||
return True
|
||||
api_base = (self.api_base or "").lower()
|
||||
return "openrouter.ai/api/v1" in api_base
|
||||
|
||||
def _should_use_openrouter_tool_compat(
|
||||
self,
|
||||
error: BaseException,
|
||||
tools: list[Tool] | None,
|
||||
) -> bool:
|
||||
"""Return True when OpenRouter rejects native tool use for the model."""
|
||||
if not tools or not self._is_openrouter_model():
|
||||
return False
|
||||
error_text = str(error).lower()
|
||||
return "openrouter" in error_text and any(
|
||||
snippet in error_text for snippet in OPENROUTER_TOOL_COMPAT_ERROR_SNIPPETS
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_object(text: str) -> dict[str, Any] | None:
|
||||
"""Extract the first JSON object from a model response."""
|
||||
candidates = [text.strip()]
|
||||
|
||||
stripped = text.strip()
|
||||
if stripped.startswith("```"):
|
||||
fence_lines = stripped.splitlines()
|
||||
if len(fence_lines) >= 3:
|
||||
candidates.append("\n".join(fence_lines[1:-1]).strip())
|
||||
|
||||
decoder = json.JSONDecoder()
|
||||
for candidate in candidates:
|
||||
if not candidate:
|
||||
continue
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
except json.JSONDecodeError:
|
||||
parsed = None
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
|
||||
for start_idx, char in enumerate(candidate):
|
||||
if char != "{":
|
||||
continue
|
||||
try:
|
||||
parsed, _ = decoder.raw_decode(candidate[start_idx:])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return None
|
||||
|
||||
def _parse_openrouter_tool_compat_response(
|
||||
self,
|
||||
content: str,
|
||||
tools: list[Tool],
|
||||
) -> tuple[str, list[dict[str, Any]]]:
|
||||
"""Parse JSON tool-compat output into assistant text and tool calls."""
|
||||
payload = self._extract_json_object(content)
|
||||
if payload is None:
|
||||
text_tool_content, text_tool_calls = self._parse_openrouter_text_tool_calls(
|
||||
content,
|
||||
tools,
|
||||
)
|
||||
if text_tool_calls:
|
||||
logger.info(
|
||||
"[openrouter-tool-compat] Parsed textual tool-call markers for %s",
|
||||
self.model,
|
||||
)
|
||||
return text_tool_content, text_tool_calls
|
||||
logger.info(
|
||||
"[openrouter-tool-compat] %s returned non-JSON fallback content; "
|
||||
"treating it as plain text.",
|
||||
self.model,
|
||||
)
|
||||
return content.strip(), []
|
||||
|
||||
assistant_text = payload.get("assistant_response")
|
||||
if not isinstance(assistant_text, str):
|
||||
assistant_text = payload.get("content")
|
||||
if not isinstance(assistant_text, str):
|
||||
assistant_text = payload.get("response")
|
||||
if not isinstance(assistant_text, str):
|
||||
assistant_text = ""
|
||||
|
||||
tool_calls_raw = payload.get("tool_calls")
|
||||
if not tool_calls_raw and {"name", "arguments"} <= payload.keys():
|
||||
tool_calls_raw = [payload]
|
||||
elif isinstance(payload.get("tool_call"), dict):
|
||||
tool_calls_raw = [payload["tool_call"]]
|
||||
|
||||
if not isinstance(tool_calls_raw, list):
|
||||
tool_calls_raw = []
|
||||
|
||||
allowed_tool_names = {tool.name for tool in tools}
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
compat_prefix = f"openrouter_compat_{time.time_ns()}"
|
||||
|
||||
for idx, raw_call in enumerate(tool_calls_raw):
|
||||
if not isinstance(raw_call, dict):
|
||||
continue
|
||||
|
||||
function_block = raw_call.get("function")
|
||||
function_name = (
|
||||
raw_call.get("name")
|
||||
or raw_call.get("tool_name")
|
||||
or (function_block.get("name") if isinstance(function_block, dict) else None)
|
||||
)
|
||||
if not isinstance(function_name, str) or function_name not in allowed_tool_names:
|
||||
if function_name:
|
||||
logger.warning(
|
||||
"[openrouter-tool-compat] Ignoring unknown tool '%s' for model %s",
|
||||
function_name,
|
||||
self.model,
|
||||
)
|
||||
continue
|
||||
|
||||
arguments = raw_call.get("arguments")
|
||||
if arguments is None:
|
||||
arguments = raw_call.get("tool_input")
|
||||
if arguments is None:
|
||||
arguments = raw_call.get("input")
|
||||
if arguments is None and isinstance(function_block, dict):
|
||||
arguments = function_block.get("arguments")
|
||||
if arguments is None:
|
||||
arguments = {}
|
||||
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {"_raw": arguments}
|
||||
elif not isinstance(arguments, dict):
|
||||
arguments = {"value": arguments}
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"{compat_prefix}_{idx}",
|
||||
"name": function_name,
|
||||
"input": arguments,
|
||||
}
|
||||
)
|
||||
|
||||
return assistant_text.strip(), tool_calls
|
||||
|
||||
@staticmethod
|
||||
def _close_truncated_json_fragment(fragment: str) -> str:
|
||||
"""Close a truncated JSON fragment by balancing quotes/brackets."""
|
||||
stack: list[str] = []
|
||||
in_string = False
|
||||
escaped = False
|
||||
normalized = fragment.rstrip()
|
||||
|
||||
while normalized and normalized[-1] in ",:{[":
|
||||
normalized = normalized[:-1].rstrip()
|
||||
|
||||
for char in normalized:
|
||||
if in_string:
|
||||
if escaped:
|
||||
escaped = False
|
||||
elif char == "\\":
|
||||
escaped = True
|
||||
elif char == '"':
|
||||
in_string = False
|
||||
continue
|
||||
|
||||
if char == '"':
|
||||
in_string = True
|
||||
elif char in "{[":
|
||||
stack.append(char)
|
||||
elif char == "}" and stack and stack[-1] == "{":
|
||||
stack.pop()
|
||||
elif char == "]" and stack and stack[-1] == "[":
|
||||
stack.pop()
|
||||
|
||||
if in_string:
|
||||
if escaped:
|
||||
normalized = normalized[:-1]
|
||||
normalized += '"'
|
||||
|
||||
for opener in reversed(stack):
|
||||
normalized += "}" if opener == "{" else "]"
|
||||
|
||||
return normalized
|
||||
|
||||
def _repair_truncated_tool_arguments(self, raw_arguments: str) -> dict[str, Any] | None:
|
||||
"""Try to recover a truncated JSON object from tool-call arguments."""
|
||||
stripped = raw_arguments.strip()
|
||||
if not stripped or stripped[0] != "{":
|
||||
return None
|
||||
|
||||
max_trim = min(len(stripped), 256)
|
||||
for trim in range(max_trim + 1):
|
||||
candidate = stripped[: len(stripped) - trim].rstrip()
|
||||
if not candidate:
|
||||
break
|
||||
candidate = self._close_truncated_json_fragment(candidate)
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return None
|
||||
|
||||
def _parse_tool_call_arguments(self, raw_arguments: str, tool_name: str) -> dict[str, Any]:
|
||||
"""Parse streamed tool arguments, repairing truncation when possible."""
|
||||
try:
|
||||
parsed = json.loads(raw_arguments) if raw_arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
parsed = None
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
|
||||
repaired = self._repair_truncated_tool_arguments(raw_arguments)
|
||||
if repaired is not None:
|
||||
logger.warning(
|
||||
"[tool-args] Recovered truncated arguments for %s on %s",
|
||||
tool_name,
|
||||
self.model,
|
||||
)
|
||||
return repaired
|
||||
|
||||
raise ValueError(
|
||||
f"Failed to parse tool call arguments for '{tool_name}' (likely truncated JSON)."
|
||||
)
|
||||
|
||||
def _parse_openrouter_text_tool_calls(
|
||||
self,
|
||||
content: str,
|
||||
tools: list[Tool],
|
||||
) -> tuple[str, list[dict[str, Any]]]:
|
||||
"""Parse textual OpenRouter tool calls into synthetic tool calls.
|
||||
|
||||
Supports both:
|
||||
- Marker wrapped payloads: <|tool_call_start|>...<|tool_call_end|>
|
||||
- Plain one-line tool calls: ask_user("...", ["..."])
|
||||
"""
|
||||
tools_by_name = {tool.name: tool for tool in tools}
|
||||
compat_prefix = f"openrouter_compat_{time.time_ns()}"
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
segment_index = 0
|
||||
|
||||
for match in OPENROUTER_TOOL_CALL_RE.finditer(content):
|
||||
parsed_calls = self._parse_openrouter_text_tool_call_block(
|
||||
block=match.group(1),
|
||||
tools_by_name=tools_by_name,
|
||||
compat_prefix=f"{compat_prefix}_{segment_index}",
|
||||
)
|
||||
if parsed_calls:
|
||||
segment_index += 1
|
||||
tool_calls.extend(parsed_calls)
|
||||
|
||||
stripped_content = OPENROUTER_TOOL_CALL_RE.sub("", content)
|
||||
retained_lines: list[str] = []
|
||||
for line in stripped_content.splitlines():
|
||||
stripped_line = line.strip()
|
||||
if not stripped_line:
|
||||
retained_lines.append(line)
|
||||
continue
|
||||
|
||||
candidate = stripped_line
|
||||
if candidate.startswith("`") and candidate.endswith("`") and len(candidate) > 1:
|
||||
candidate = candidate[1:-1].strip()
|
||||
|
||||
parsed_calls = self._parse_openrouter_text_tool_call_block(
|
||||
block=candidate,
|
||||
tools_by_name=tools_by_name,
|
||||
compat_prefix=f"{compat_prefix}_{segment_index}",
|
||||
)
|
||||
if parsed_calls:
|
||||
segment_index += 1
|
||||
tool_calls.extend(parsed_calls)
|
||||
continue
|
||||
|
||||
retained_lines.append(line)
|
||||
|
||||
stripped_text = "\n".join(retained_lines).strip()
|
||||
return stripped_text, tool_calls
|
||||
|
||||
def _parse_openrouter_text_tool_call_block(
|
||||
self,
|
||||
block: str,
|
||||
tools_by_name: dict[str, Tool],
|
||||
compat_prefix: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Parse a single textual tool-call block like [tool(arg='x')]."""
|
||||
try:
|
||||
parsed = ast.parse(block.strip(), mode="eval").body
|
||||
except SyntaxError:
|
||||
return []
|
||||
|
||||
call_nodes = parsed.elts if isinstance(parsed, ast.List) else [parsed]
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
for call_index, call_node in enumerate(call_nodes):
|
||||
if not isinstance(call_node, ast.Call) or not isinstance(call_node.func, ast.Name):
|
||||
continue
|
||||
|
||||
tool_name = call_node.func.id
|
||||
tool = tools_by_name.get(tool_name)
|
||||
if tool is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
tool_input = self._parse_openrouter_text_tool_call_arguments(
|
||||
call_node=call_node,
|
||||
tool=tool,
|
||||
)
|
||||
except (ValueError, SyntaxError):
|
||||
continue
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"{compat_prefix}_{call_index}",
|
||||
"name": tool_name,
|
||||
"input": tool_input,
|
||||
}
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
@staticmethod
|
||||
def _parse_openrouter_text_tool_call_arguments(
|
||||
call_node: ast.Call,
|
||||
tool: Tool,
|
||||
) -> dict[str, Any]:
|
||||
"""Parse positional/keyword args from a textual tool call."""
|
||||
properties = tool.parameters.get("properties", {})
|
||||
positional_keys = list(properties.keys())
|
||||
tool_input: dict[str, Any] = {}
|
||||
|
||||
if len(call_node.args) > len(positional_keys):
|
||||
raise ValueError("Too many positional args for textual tool call")
|
||||
|
||||
for idx, arg_node in enumerate(call_node.args):
|
||||
tool_input[positional_keys[idx]] = ast.literal_eval(arg_node)
|
||||
|
||||
for kwarg in call_node.keywords:
|
||||
if kwarg.arg is None:
|
||||
raise ValueError("Star args are not supported in textual tool calls")
|
||||
tool_input[kwarg.arg] = ast.literal_eval(kwarg.value)
|
||||
|
||||
return tool_input
|
||||
|
||||
def _build_openrouter_tool_compat_messages(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build a JSON-only prompt for models without native tool support."""
|
||||
tool_specs = [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
compat_instruction = (
|
||||
"Tool compatibility mode is active because this OpenRouter model does not support "
|
||||
"native function calling on the routed provider.\n"
|
||||
"Return exactly one JSON object and nothing else.\n"
|
||||
'Schema: {"assistant_response": string, '
|
||||
'"tool_calls": [{"name": string, "arguments": object}]}\n'
|
||||
"Rules:\n"
|
||||
"- If a tool is required, put one or more entries in tool_calls "
|
||||
"and do not invent tool results.\n"
|
||||
"- If no tool is required, set tool_calls to [] and put the full "
|
||||
"answer in assistant_response.\n"
|
||||
"- Only use tool names from the allowed tool list.\n"
|
||||
"- arguments must always be valid JSON objects.\n"
|
||||
f"Allowed tools:\n{json.dumps(tool_specs, ensure_ascii=True)}"
|
||||
)
|
||||
compat_system = compat_instruction if not system else f"{system}\n\n{compat_instruction}"
|
||||
|
||||
full_messages: list[dict[str, Any]] = [{"role": "system", "content": compat_system}]
|
||||
full_messages.extend(messages)
|
||||
return [
|
||||
message
|
||||
for message in full_messages
|
||||
if not (
|
||||
message.get("role") == "assistant"
|
||||
and not message.get("content")
|
||||
and not message.get("tool_calls")
|
||||
)
|
||||
]
|
||||
|
||||
async def _acomplete_via_openrouter_tool_compat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
max_tokens: int,
|
||||
) -> LLMResponse:
|
||||
"""Emulate tool calling via JSON when OpenRouter rejects native tools."""
|
||||
full_messages = self._build_openrouter_tool_compat_messages(messages, system, tools)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": full_messages,
|
||||
"max_tokens": max_tokens,
|
||||
**self.extra_kwargs,
|
||||
}
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
|
||||
response = await self._acompletion_with_rate_limit_retry(**kwargs)
|
||||
raw_content = response.choices[0].message.content or ""
|
||||
assistant_text, tool_calls = self._parse_openrouter_tool_compat_response(
|
||||
raw_content,
|
||||
tools,
|
||||
)
|
||||
usage = response.usage
|
||||
input_tokens = usage.prompt_tokens if usage else 0
|
||||
output_tokens = usage.completion_tokens if usage else 0
|
||||
stop_reason = "tool_calls" if tool_calls else (response.choices[0].finish_reason or "stop")
|
||||
|
||||
return LLMResponse(
|
||||
content=assistant_text,
|
||||
model=response.model or self.model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
stop_reason=stop_reason,
|
||||
raw_response={
|
||||
"compat_mode": "openrouter_tool_emulation",
|
||||
"tool_calls": tool_calls,
|
||||
"response": response,
|
||||
},
|
||||
)
|
||||
|
||||
async def _stream_via_openrouter_tool_compat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
max_tokens: int,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Fallback stream for OpenRouter models without native tool support."""
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[openrouter-tool-compat] Using compatibility mode for %s",
|
||||
self.model,
|
||||
)
|
||||
try:
|
||||
response = await self._acomplete_via_openrouter_tool_compat(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamErrorEvent(error=str(e), recoverable=False)
|
||||
return
|
||||
|
||||
raw_response = response.raw_response if isinstance(response.raw_response, dict) else {}
|
||||
tool_calls = raw_response.get("tool_calls", [])
|
||||
|
||||
if response.content:
|
||||
yield TextDeltaEvent(content=response.content, snapshot=response.content)
|
||||
yield TextEndEvent(full_text=response.content)
|
||||
|
||||
for tool_call in tool_calls:
|
||||
yield ToolCallEvent(
|
||||
tool_use_id=tool_call["id"],
|
||||
tool_name=tool_call["name"],
|
||||
tool_input=tool_call["input"],
|
||||
)
|
||||
|
||||
yield FinishEvent(
|
||||
stop_reason=response.stop_reason,
|
||||
input_tokens=response.input_tokens,
|
||||
output_tokens=response.output_tokens,
|
||||
model=response.model,
|
||||
)
|
||||
|
||||
async def _stream_via_nonstream_completion(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
@@ -875,12 +1405,11 @@ class LiteLLMProvider(LLMProvider):
|
||||
tool_calls = msg.tool_calls or []
|
||||
|
||||
for tc in tool_calls:
|
||||
parsed_args: Any
|
||||
args = tc.function.arguments if tc.function else ""
|
||||
try:
|
||||
parsed_args = json.loads(args) if args else {}
|
||||
except json.JSONDecodeError:
|
||||
parsed_args = {"_raw": args}
|
||||
parsed_args = self._parse_tool_call_arguments(
|
||||
args,
|
||||
tc.function.name if tc.function else "",
|
||||
)
|
||||
yield ToolCallEvent(
|
||||
tool_use_id=getattr(tc, "id", ""),
|
||||
tool_name=tc.function.name if tc.function else "",
|
||||
@@ -939,6 +1468,16 @@ class LiteLLMProvider(LLMProvider):
|
||||
yield event
|
||||
return
|
||||
|
||||
if tools and self._is_openrouter_model() and _is_openrouter_tool_compat_cached(self.model):
|
||||
async for event in self._stream_via_openrouter_tool_compat(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
|
||||
full_messages: list[dict[str, Any]] = []
|
||||
if system:
|
||||
sys_msg: dict[str, Any] = {"role": "system", "content": system}
|
||||
@@ -1085,10 +1624,10 @@ class LiteLLMProvider(LLMProvider):
|
||||
if choice.finish_reason:
|
||||
stream_finish_reason = choice.finish_reason
|
||||
for _idx, tc_data in sorted(tool_calls_acc.items()):
|
||||
try:
|
||||
parsed_args = json.loads(tc_data["arguments"])
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
parsed_args = {"_raw": tc_data.get("arguments", "")}
|
||||
parsed_args = self._parse_tool_call_arguments(
|
||||
tc_data.get("arguments", ""),
|
||||
tc_data.get("name", ""),
|
||||
)
|
||||
tail_events.append(
|
||||
ToolCallEvent(
|
||||
tool_use_id=tc_data["id"],
|
||||
@@ -1269,6 +1808,16 @@ class LiteLLMProvider(LLMProvider):
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
if self._should_use_openrouter_tool_compat(e, tools):
|
||||
_remember_openrouter_tool_compat_model(self.model)
|
||||
async for event in self._stream_via_openrouter_tool_compat(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools or [],
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
if _is_stream_transient_error(e) and attempt < RATE_LIMIT_MAX_RETRIES:
|
||||
wait = _compute_retry_delay(attempt, exception=e)
|
||||
logger.warning(
|
||||
|
||||
@@ -206,6 +206,20 @@ def configure_logging(
|
||||
root_logger.addHandler(handler)
|
||||
root_logger.setLevel(level.upper())
|
||||
|
||||
# Suppress noisy LiteLLM INFO logs (model/provider line + Provider List URL
|
||||
# printed on every single completion call). Warnings and errors still show.
|
||||
logging.getLogger("LiteLLM").setLevel(logging.WARNING)
|
||||
|
||||
# Suppress the "Provider List: ..." banner litellm prints to stdout via
|
||||
# print() on every completion call. This is independent of log format.
|
||||
try:
|
||||
import litellm as _litellm
|
||||
|
||||
if hasattr(_litellm, "suppress_debug_info"):
|
||||
_litellm.suppress_debug_info = True # type: ignore[attr-defined]
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
|
||||
# When in JSON mode, configure known third-party loggers to use JSON formatter
|
||||
# This ensures libraries like LiteLLM, httpcore also output clean JSON
|
||||
if format == "json":
|
||||
@@ -228,16 +242,6 @@ def _disable_third_party_colors() -> None:
|
||||
os.environ["NO_COLOR"] = "1"
|
||||
os.environ["FORCE_COLOR"] = "0"
|
||||
|
||||
# Disable LiteLLM debug/verbose output colors if available
|
||||
try:
|
||||
import litellm
|
||||
|
||||
# LiteLLM respects NO_COLOR, but we can also suppress debug info
|
||||
if hasattr(litellm, "suppress_debug_info"):
|
||||
litellm.suppress_debug_info = True # type: ignore[attr-defined]
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
|
||||
|
||||
def set_trace_context(**kwargs: Any) -> None:
|
||||
"""
|
||||
|
||||
@@ -962,6 +962,9 @@ class AgentRunner:
|
||||
|
||||
# Generate flowchart.json if missing (for template/legacy agents)
|
||||
generate_fallback_flowchart(graph, goal, agent_path)
|
||||
# Read skill configuration from agent module
|
||||
agent_default_skills = getattr(agent_module, "default_skills", None)
|
||||
agent_skills = getattr(agent_module, "skills", None)
|
||||
|
||||
# Read runtime config (webhook settings, etc.) if defined
|
||||
agent_runtime_config = getattr(agent_module, "runtime_config", None)
|
||||
@@ -974,7 +977,7 @@ class AgentRunner:
|
||||
configure_fn = getattr(agent_module, "configure_for_account", None)
|
||||
list_accts_fn = getattr(agent_module, "list_connected_accounts", None)
|
||||
|
||||
return cls(
|
||||
runner = cls(
|
||||
agent_path=agent_path,
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
@@ -990,6 +993,10 @@ class AgentRunner:
|
||||
list_accounts=list_accts_fn,
|
||||
credential_store=credential_store,
|
||||
)
|
||||
# Stash skill config for use in _setup()
|
||||
runner._agent_default_skills = agent_default_skills
|
||||
runner._agent_skills = agent_skills
|
||||
return runner
|
||||
|
||||
# Fallback: load from agent.json (legacy JSON-based agents)
|
||||
agent_json_path = agent_path / "agent.json"
|
||||
@@ -1010,7 +1017,7 @@ class AgentRunner:
|
||||
# Generate flowchart.json if missing (for legacy JSON-based agents)
|
||||
generate_fallback_flowchart(graph, goal, agent_path)
|
||||
|
||||
return cls(
|
||||
runner = cls(
|
||||
agent_path=agent_path,
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
@@ -1021,6 +1028,9 @@ class AgentRunner:
|
||||
skip_credential_validation=skip_credential_validation or False,
|
||||
credential_store=credential_store,
|
||||
)
|
||||
runner._agent_default_skills = None
|
||||
runner._agent_skills = None
|
||||
return runner
|
||||
|
||||
def register_tool(
|
||||
self,
|
||||
@@ -1330,6 +1340,19 @@ class AgentRunner:
|
||||
except Exception:
|
||||
pass # Best-effort — agent works without account info
|
||||
|
||||
# Skill configuration — the runtime handles discovery, loading, and
|
||||
# prompt rasterization. The runner just builds the config.
|
||||
from framework.skills.config import SkillsConfig
|
||||
from framework.skills.manager import SkillsManagerConfig
|
||||
|
||||
skills_manager_config = SkillsManagerConfig(
|
||||
skills_config=SkillsConfig.from_agent_vars(
|
||||
default_skills=getattr(self, "_agent_default_skills", None),
|
||||
skills=getattr(self, "_agent_skills", None),
|
||||
),
|
||||
project_root=self.agent_path,
|
||||
)
|
||||
|
||||
self._setup_agent_runtime(
|
||||
tools,
|
||||
tool_executor,
|
||||
@@ -1337,6 +1360,7 @@ class AgentRunner:
|
||||
accounts_data=accounts_data,
|
||||
tool_provider_map=tool_provider_map,
|
||||
event_bus=event_bus,
|
||||
skills_manager_config=skills_manager_config,
|
||||
)
|
||||
|
||||
def _get_api_key_env_var(self, model: str) -> str | None:
|
||||
@@ -1357,6 +1381,8 @@ class AgentRunner:
|
||||
return "MISTRAL_API_KEY"
|
||||
elif model_lower.startswith("groq/"):
|
||||
return "GROQ_API_KEY"
|
||||
elif model_lower.startswith("openrouter/"):
|
||||
return "OPENROUTER_API_KEY"
|
||||
elif self._is_local_model(model_lower):
|
||||
return None # Local models don't need an API key
|
||||
elif model_lower.startswith("azure/"):
|
||||
@@ -1371,6 +1397,8 @@ class AgentRunner:
|
||||
return "MINIMAX_API_KEY"
|
||||
elif model_lower.startswith("kimi/"):
|
||||
return "KIMI_API_KEY"
|
||||
elif model_lower.startswith("hive/"):
|
||||
return "HIVE_API_KEY"
|
||||
else:
|
||||
# Default: assume OpenAI-compatible
|
||||
return "OPENAI_API_KEY"
|
||||
@@ -1393,6 +1421,8 @@ class AgentRunner:
|
||||
cred_id = "minimax"
|
||||
elif model_lower.startswith("kimi/"):
|
||||
cred_id = "kimi"
|
||||
elif model_lower.startswith("hive/"):
|
||||
cred_id = "hive"
|
||||
# Add more mappings as providers are added to LLM_CREDENTIALS
|
||||
|
||||
if cred_id is None:
|
||||
@@ -1432,6 +1462,7 @@ class AgentRunner:
|
||||
accounts_data: list[dict] | None = None,
|
||||
tool_provider_map: dict[str, str] | None = None,
|
||||
event_bus=None,
|
||||
skills_manager_config=None,
|
||||
) -> None:
|
||||
"""Set up multi-entry-point execution using AgentRuntime."""
|
||||
entry_points = []
|
||||
@@ -1491,6 +1522,7 @@ class AgentRunner:
|
||||
accounts_data=accounts_data,
|
||||
tool_provider_map=tool_provider_map,
|
||||
event_bus=event_bus,
|
||||
skills_manager_config=skills_manager_config,
|
||||
)
|
||||
|
||||
# Pass intro_message through for TUI display
|
||||
|
||||
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.goal import Goal
|
||||
from framework.llm.provider import LLMProvider, Tool
|
||||
from framework.skills.manager import SkillsManagerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -132,6 +133,10 @@ class AgentRuntime:
|
||||
accounts_data: list[dict] | None = None,
|
||||
tool_provider_map: dict[str, str] | None = None,
|
||||
event_bus: "EventBus | None" = None,
|
||||
skills_manager_config: "SkillsManagerConfig | None" = None,
|
||||
# Deprecated — pass skills_manager_config instead.
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
):
|
||||
"""
|
||||
Initialize agent runtime.
|
||||
@@ -153,7 +158,13 @@ class AgentRuntime:
|
||||
event_bus: Optional external EventBus. If provided, the runtime shares
|
||||
this bus instead of creating its own. Used by SessionManager to
|
||||
share a single bus between queen, worker, and judge.
|
||||
skills_manager_config: Skill configuration — the runtime owns
|
||||
discovery, loading, and prompt renderation internally.
|
||||
skills_catalog_prompt: Deprecated. Pre-rendered skills catalog.
|
||||
protocols_prompt: Deprecated. Pre-rendered operational protocols.
|
||||
"""
|
||||
from framework.skills.manager import SkillsManager
|
||||
|
||||
self.graph = graph
|
||||
self.goal = goal
|
||||
self._config = config or AgentRuntimeConfig()
|
||||
@@ -161,6 +172,29 @@ class AgentRuntime:
|
||||
self._checkpoint_config = checkpoint_config
|
||||
self.accounts_prompt = accounts_prompt
|
||||
|
||||
# --- Skill lifecycle: runtime owns the SkillsManager ---
|
||||
if skills_manager_config is not None:
|
||||
# New path: config-driven, runtime handles loading
|
||||
self._skills_manager = SkillsManager(skills_manager_config)
|
||||
self._skills_manager.load()
|
||||
elif skills_catalog_prompt or protocols_prompt:
|
||||
# Legacy path: caller passed pre-rendered strings
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"Passing pre-rendered skills_catalog_prompt/protocols_prompt "
|
||||
"is deprecated. Pass skills_manager_config instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._skills_manager = SkillsManager.from_precomputed(
|
||||
skills_catalog_prompt, protocols_prompt
|
||||
)
|
||||
else:
|
||||
# Bare constructor: auto-load defaults
|
||||
self._skills_manager = SkillsManager()
|
||||
self._skills_manager.load()
|
||||
|
||||
# Primary graph identity
|
||||
self._graph_id: str = graph_id or "primary"
|
||||
|
||||
@@ -216,6 +250,18 @@ class AgentRuntime:
|
||||
# Optional greeting shown to user on TUI load (set by AgentRunner)
|
||||
self.intro_message: str = ""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Skill prompt accessors (read by ExecutionStream constructors)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def skills_catalog_prompt(self) -> str:
|
||||
return self._skills_manager.skills_catalog_prompt
|
||||
|
||||
@property
|
||||
def protocols_prompt(self) -> str:
|
||||
return self._skills_manager.protocols_prompt
|
||||
|
||||
def register_entry_point(self, spec: EntryPointSpec) -> None:
|
||||
"""
|
||||
Register a named entry point for the agent.
|
||||
@@ -293,6 +339,8 @@ class AgentRuntime:
|
||||
accounts_prompt=self._accounts_prompt,
|
||||
accounts_data=self._accounts_data,
|
||||
tool_provider_map=self._tool_provider_map,
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
)
|
||||
await stream.start()
|
||||
self._streams[ep_id] = stream
|
||||
@@ -393,7 +441,8 @@ class AgentRuntime:
|
||||
|
||||
tc = spec.trigger_config
|
||||
cron_expr = tc.get("cron")
|
||||
interval = tc.get("interval_minutes")
|
||||
_raw_interval = tc.get("interval_minutes")
|
||||
interval = float(_raw_interval) if _raw_interval is not None else None
|
||||
run_immediately = tc.get("run_immediately", False)
|
||||
|
||||
if cron_expr:
|
||||
@@ -549,7 +598,7 @@ class AgentRuntime:
|
||||
ep_id,
|
||||
cron_expr,
|
||||
run_immediately,
|
||||
idle_timeout=tc.get("idle_timeout_seconds", 300),
|
||||
idle_timeout=float(tc.get("idle_timeout_seconds", 300)),
|
||||
)()
|
||||
)
|
||||
self._timer_tasks.append(task)
|
||||
@@ -679,7 +728,7 @@ class AgentRuntime:
|
||||
ep_id,
|
||||
interval,
|
||||
run_immediately,
|
||||
idle_timeout=tc.get("idle_timeout_seconds", 300),
|
||||
idle_timeout=float(tc.get("idle_timeout_seconds", 300)),
|
||||
)()
|
||||
)
|
||||
self._timer_tasks.append(task)
|
||||
@@ -926,6 +975,8 @@ class AgentRuntime:
|
||||
accounts_prompt=self._accounts_prompt,
|
||||
accounts_data=self._accounts_data,
|
||||
tool_provider_map=self._tool_provider_map,
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
)
|
||||
if self._running:
|
||||
await stream.start()
|
||||
@@ -1004,7 +1055,8 @@ class AgentRuntime:
|
||||
if spec.trigger_type != "timer":
|
||||
continue
|
||||
tc = spec.trigger_config
|
||||
interval = tc.get("interval_minutes")
|
||||
_raw_interval = tc.get("interval_minutes")
|
||||
interval = float(_raw_interval) if _raw_interval is not None else None
|
||||
run_immediately = tc.get("run_immediately", False)
|
||||
|
||||
if interval and interval > 0 and self._running:
|
||||
@@ -1149,7 +1201,7 @@ class AgentRuntime:
|
||||
ep_id,
|
||||
interval,
|
||||
run_immediately,
|
||||
idle_timeout=tc.get("idle_timeout_seconds", 300),
|
||||
idle_timeout=float(tc.get("idle_timeout_seconds", 300)),
|
||||
)()
|
||||
)
|
||||
timer_tasks.append(task)
|
||||
@@ -1704,6 +1756,10 @@ def create_agent_runtime(
|
||||
accounts_data: list[dict] | None = None,
|
||||
tool_provider_map: dict[str, str] | None = None,
|
||||
event_bus: "EventBus | None" = None,
|
||||
skills_manager_config: "SkillsManagerConfig | None" = None,
|
||||
# Deprecated — pass skills_manager_config instead.
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
) -> AgentRuntime:
|
||||
"""
|
||||
Create and configure an AgentRuntime with entry points.
|
||||
@@ -1730,6 +1786,10 @@ def create_agent_runtime(
|
||||
accounts_data: Raw account data for per-node prompt generation.
|
||||
tool_provider_map: Tool name to provider name mapping for account routing.
|
||||
event_bus: Optional external EventBus to share with other components.
|
||||
skills_manager_config: Skill configuration — the runtime owns
|
||||
discovery, loading, and prompt renderation internally.
|
||||
skills_catalog_prompt: Deprecated. Pre-rendered skills catalog.
|
||||
protocols_prompt: Deprecated. Pre-rendered operational protocols.
|
||||
|
||||
Returns:
|
||||
Configured AgentRuntime (not yet started)
|
||||
@@ -1756,6 +1816,9 @@ def create_agent_runtime(
|
||||
accounts_data=accounts_data,
|
||||
tool_provider_map=tool_provider_map,
|
||||
event_bus=event_bus,
|
||||
skills_manager_config=skills_manager_config,
|
||||
skills_catalog_prompt=skills_catalog_prompt,
|
||||
protocols_prompt=protocols_prompt,
|
||||
)
|
||||
|
||||
for spec in entry_points:
|
||||
|
||||
@@ -186,6 +186,8 @@ class ExecutionStream:
|
||||
accounts_prompt: str = "",
|
||||
accounts_data: list[dict] | None = None,
|
||||
tool_provider_map: dict[str, str] | None = None,
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
):
|
||||
"""
|
||||
Initialize execution stream.
|
||||
@@ -209,6 +211,8 @@ class ExecutionStream:
|
||||
accounts_prompt: Connected accounts block for system prompt injection
|
||||
accounts_data: Raw account data for per-node prompt generation
|
||||
tool_provider_map: Tool name to provider name mapping for account routing
|
||||
skills_catalog_prompt: Available skills catalog for system prompt
|
||||
protocols_prompt: Default skill operational protocols for system prompt
|
||||
"""
|
||||
self.stream_id = stream_id
|
||||
self.entry_spec = entry_spec
|
||||
@@ -230,6 +234,21 @@ class ExecutionStream:
|
||||
self._accounts_prompt = accounts_prompt
|
||||
self._accounts_data = accounts_data
|
||||
self._tool_provider_map = tool_provider_map
|
||||
self._skills_catalog_prompt = skills_catalog_prompt
|
||||
self._protocols_prompt = protocols_prompt
|
||||
|
||||
_es_logger = logging.getLogger(__name__)
|
||||
if protocols_prompt:
|
||||
_es_logger.info(
|
||||
"ExecutionStream[%s] received protocols_prompt (%d chars)",
|
||||
stream_id,
|
||||
len(protocols_prompt),
|
||||
)
|
||||
else:
|
||||
_es_logger.warning(
|
||||
"ExecutionStream[%s] received EMPTY protocols_prompt",
|
||||
stream_id,
|
||||
)
|
||||
|
||||
# Create stream-scoped runtime
|
||||
self._runtime = StreamRuntime(
|
||||
@@ -675,6 +694,8 @@ class ExecutionStream:
|
||||
accounts_prompt=self._accounts_prompt,
|
||||
accounts_data=self._accounts_data,
|
||||
tool_provider_map=self._tool_provider_map,
|
||||
skills_catalog_prompt=self._skills_catalog_prompt,
|
||||
protocols_prompt=self._protocols_prompt,
|
||||
)
|
||||
# Track executor so inject_input() can reach EventLoopNode instances
|
||||
self._active_executors[execution_id] = executor
|
||||
|
||||
@@ -69,6 +69,7 @@ async def create_queen(
|
||||
QueenPhaseState,
|
||||
register_queen_lifecycle_tools,
|
||||
)
|
||||
from framework.tools.queen_memory_tools import register_queen_memory_tools
|
||||
|
||||
hive_home = Path.home() / ".hive"
|
||||
|
||||
@@ -122,6 +123,9 @@ async def create_queen(
|
||||
phase_state=phase_state,
|
||||
)
|
||||
|
||||
# ---- Episodic memory tools (always registered) ---------------------
|
||||
register_queen_memory_tools(queen_registry)
|
||||
|
||||
# ---- Monitoring tools (only when worker is loaded) ----------------
|
||||
if session.worker_runtime:
|
||||
from framework.tools.worker_monitoring_tools import register_worker_monitoring_tools
|
||||
@@ -216,6 +220,16 @@ async def create_queen(
|
||||
+ worker_identity
|
||||
)
|
||||
|
||||
# ---- Default skill protocols -------------------------------------
|
||||
try:
|
||||
from framework.skills.manager import SkillsManager
|
||||
|
||||
_queen_skills_mgr = SkillsManager()
|
||||
_queen_skills_mgr.load()
|
||||
phase_state.protocols_prompt = _queen_skills_mgr.protocols_prompt
|
||||
except Exception:
|
||||
logger.debug("Queen skill loading failed (non-fatal)", exc_info=True)
|
||||
|
||||
# ---- Persona hook ------------------------------------------------
|
||||
_session_llm = session.llm
|
||||
_session_event_bus = session.event_bus
|
||||
|
||||
@@ -47,6 +47,8 @@ class Session:
|
||||
worker_handoff_sub: str | None = None
|
||||
# Memory consolidation subscription (fires on CONTEXT_COMPACTED)
|
||||
memory_consolidation_sub: str | None = None
|
||||
# Worker run digest subscription (fires on EXECUTION_COMPLETED / EXECUTION_FAILED)
|
||||
worker_digest_sub: str | None = None
|
||||
# Trigger definitions loaded from agent's triggers.json (available but inactive)
|
||||
available_triggers: dict[str, TriggerDefinition] = field(default_factory=dict)
|
||||
# Active trigger tracking (IDs currently firing + their asyncio tasks)
|
||||
@@ -177,6 +179,31 @@ class SessionManager:
|
||||
agent_path = Path(agent_path)
|
||||
resolved_worker_id = agent_id or agent_path.name
|
||||
|
||||
# When cold-restoring, check meta.json for the phase — if the agent
|
||||
# was still being built we must NOT try to load the worker (the code
|
||||
# is incomplete and will fail to import).
|
||||
if queen_resume_from:
|
||||
_resume_phase = None
|
||||
_meta_path = (
|
||||
Path.home() / ".hive" / "queen" / "session" / queen_resume_from / "meta.json"
|
||||
)
|
||||
if _meta_path.exists():
|
||||
try:
|
||||
_meta = json.loads(_meta_path.read_text(encoding="utf-8"))
|
||||
_resume_phase = _meta.get("phase")
|
||||
except (json.JSONDecodeError, OSError):
|
||||
pass
|
||||
if _resume_phase in ("building", "planning"):
|
||||
# Fall back to queen-only session — cold resume handler in
|
||||
# _start_queen will set phase_state.agent_path and switch to
|
||||
# the correct phase.
|
||||
return await self.create_session(
|
||||
session_id=session_id,
|
||||
model=model,
|
||||
initial_prompt=initial_prompt,
|
||||
queen_resume_from=queen_resume_from,
|
||||
)
|
||||
|
||||
# Reuse the original session ID when cold-restoring so the frontend
|
||||
# sees one continuous session instead of a new one each time.
|
||||
session = await self._create_session_core(
|
||||
@@ -193,6 +220,9 @@ class SessionManager:
|
||||
model=model,
|
||||
)
|
||||
|
||||
# Restore active triggers from persisted state (cold restore)
|
||||
await self._restore_active_triggers(session, session.id)
|
||||
|
||||
# Start queen with worker profile + lifecycle + monitoring tools
|
||||
worker_identity = (
|
||||
build_worker_profile(session.worker_runtime, agent_path=agent_path)
|
||||
@@ -204,7 +234,23 @@ class SessionManager:
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# If anything fails, tear down the session
|
||||
if queen_resume_from:
|
||||
# Cold restore: worker load failed (e.g. incomplete code from a
|
||||
# building session). Fall back to queen-only so the user can
|
||||
# continue the conversation and fix / rebuild the agent.
|
||||
logger.warning(
|
||||
"Cold restore: worker load failed for '%s', falling back to queen-only",
|
||||
agent_path,
|
||||
exc_info=True,
|
||||
)
|
||||
await self.stop_session(session.id)
|
||||
return await self.create_session(
|
||||
session_id=session_id,
|
||||
model=model,
|
||||
initial_prompt=initial_prompt,
|
||||
queen_resume_from=queen_resume_from,
|
||||
)
|
||||
# If anything fails (non-cold-restore), tear down the session
|
||||
await self.stop_session(session.id)
|
||||
raise
|
||||
return session
|
||||
@@ -297,6 +343,9 @@ class SessionManager:
|
||||
session.worker_runtime = runtime
|
||||
session.worker_info = info
|
||||
|
||||
# Subscribe to execution completion for per-run digest generation
|
||||
self._subscribe_worker_digest(session)
|
||||
|
||||
async with self._lock:
|
||||
self._loading.discard(session.id)
|
||||
|
||||
@@ -399,6 +448,51 @@ class SessionManager:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _restore_active_triggers(self, session: "Session", session_id: str) -> None:
|
||||
"""Restore previously active triggers from persisted session state.
|
||||
|
||||
Called after worker loading to restart any timer/webhook triggers
|
||||
that were active before a server restart.
|
||||
"""
|
||||
if not session.available_triggers or not session.worker_runtime:
|
||||
return
|
||||
try:
|
||||
store = session.worker_runtime._session_store
|
||||
state = await store.read_state(session_id)
|
||||
if state and state.active_triggers:
|
||||
from framework.tools.queen_lifecycle_tools import (
|
||||
_start_trigger_timer,
|
||||
_start_trigger_webhook,
|
||||
)
|
||||
|
||||
saved_tasks = getattr(state, "trigger_tasks", {}) or {}
|
||||
for tid in state.active_triggers:
|
||||
tdef = session.available_triggers.get(tid)
|
||||
if tdef:
|
||||
# Restore user-configured task override
|
||||
saved_task = saved_tasks.get(tid, "")
|
||||
if saved_task:
|
||||
tdef.task = saved_task
|
||||
tdef.active = True
|
||||
session.active_trigger_ids.add(tid)
|
||||
if tdef.trigger_type == "timer":
|
||||
await _start_trigger_timer(session, tid, tdef)
|
||||
logger.info("Restored trigger timer '%s'", tid)
|
||||
elif tdef.trigger_type == "webhook":
|
||||
await _start_trigger_webhook(session, tid, tdef)
|
||||
logger.info("Restored webhook trigger '%s'", tid)
|
||||
else:
|
||||
logger.warning(
|
||||
"Saved trigger '%s' not found in worker entry points, skipping",
|
||||
tid,
|
||||
)
|
||||
|
||||
# Restore worker_configured flag
|
||||
if state and getattr(state, "worker_configured", False):
|
||||
session.worker_configured = True
|
||||
except Exception as e:
|
||||
logger.warning("Failed to restore active triggers: %s", e)
|
||||
|
||||
async def load_worker(
|
||||
self,
|
||||
session_id: str,
|
||||
@@ -447,44 +541,7 @@ class SessionManager:
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Restore previously active triggers from persisted session state
|
||||
if session.available_triggers and session.worker_runtime:
|
||||
try:
|
||||
store = session.worker_runtime._session_store
|
||||
state = await store.read_state(session_id)
|
||||
if state and state.active_triggers:
|
||||
from framework.tools.queen_lifecycle_tools import (
|
||||
_start_trigger_timer,
|
||||
_start_trigger_webhook,
|
||||
)
|
||||
|
||||
saved_tasks = getattr(state, "trigger_tasks", {}) or {}
|
||||
for tid in state.active_triggers:
|
||||
tdef = session.available_triggers.get(tid)
|
||||
if tdef:
|
||||
# Restore user-configured task override
|
||||
saved_task = saved_tasks.get(tid, "")
|
||||
if saved_task:
|
||||
tdef.task = saved_task
|
||||
tdef.active = True
|
||||
session.active_trigger_ids.add(tid)
|
||||
if tdef.trigger_type == "timer":
|
||||
await _start_trigger_timer(session, tid, tdef)
|
||||
logger.info("Restored trigger timer '%s'", tid)
|
||||
elif tdef.trigger_type == "webhook":
|
||||
await _start_trigger_webhook(session, tid, tdef)
|
||||
logger.info("Restored webhook trigger '%s'", tid)
|
||||
else:
|
||||
logger.warning(
|
||||
"Saved trigger '%s' not found in worker entry points, skipping",
|
||||
tid,
|
||||
)
|
||||
|
||||
# Restore worker_configured flag
|
||||
if state and getattr(state, "worker_configured", False):
|
||||
session.worker_configured = True
|
||||
except Exception as e:
|
||||
logger.warning("Failed to restore active triggers: %s", e)
|
||||
await self._restore_active_triggers(session, session_id)
|
||||
|
||||
# Emit SSE event so the frontend can update UI
|
||||
await self._emit_worker_loaded(session)
|
||||
@@ -526,6 +583,13 @@ class SessionManager:
|
||||
await self._emit_trigger_events(session, "removed", session.available_triggers)
|
||||
session.available_triggers.clear()
|
||||
|
||||
if session.worker_digest_sub is not None:
|
||||
try:
|
||||
session.event_bus.unsubscribe(session.worker_digest_sub)
|
||||
except Exception:
|
||||
pass
|
||||
session.worker_digest_sub = None
|
||||
|
||||
worker_id = session.worker_id
|
||||
session.worker_id = None
|
||||
session.worker_path = None
|
||||
@@ -563,6 +627,13 @@ class SessionManager:
|
||||
pass
|
||||
session.worker_handoff_sub = None
|
||||
|
||||
if session.worker_digest_sub is not None:
|
||||
try:
|
||||
session.event_bus.unsubscribe(session.worker_digest_sub)
|
||||
except Exception:
|
||||
pass
|
||||
session.worker_digest_sub = None
|
||||
|
||||
# Stop queen and memory consolidation subscription
|
||||
if session.memory_consolidation_sub is not None:
|
||||
try:
|
||||
@@ -647,6 +718,134 @@ class SessionManager:
|
||||
else:
|
||||
logger.warning("Worker handoff received but queen node not ready")
|
||||
|
||||
def _subscribe_worker_digest(self, session: Session) -> None:
|
||||
"""Subscribe to worker events to write per-run digests.
|
||||
|
||||
Three triggers:
|
||||
- NODE_LOOP_ITERATION: write a mid-run snapshot, throttled to at most
|
||||
once every _DIGEST_COOLDOWN seconds per execution.
|
||||
- TOOL_CALL_COMPLETED for delegate_to_sub_agent: same throttled snapshot.
|
||||
Orchestrator nodes often run all subagent calls in a single LLM turn,
|
||||
so NODE_LOOP_ITERATION only fires once at the end. Subagent
|
||||
completions provide intermediate checkpoints.
|
||||
- EXECUTION_COMPLETED / EXECUTION_FAILED: always write the final digest,
|
||||
bypassing the cooldown.
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
from framework.runtime.event_bus import EventType as _ET
|
||||
|
||||
_DIGEST_COOLDOWN = 300.0 # seconds between mid-run snapshots
|
||||
|
||||
if session.worker_digest_sub is not None:
|
||||
try:
|
||||
session.event_bus.unsubscribe(session.worker_digest_sub)
|
||||
except Exception:
|
||||
pass
|
||||
session.worker_digest_sub = None
|
||||
|
||||
agent_name = session.worker_path.name if session.worker_path else None
|
||||
if not agent_name:
|
||||
return
|
||||
|
||||
_agent_name = agent_name
|
||||
_llm = session.llm
|
||||
_bus = session.event_bus
|
||||
# per-execution_id monotonic timestamp of last mid-run digest
|
||||
_last_digest: dict[str, float] = {}
|
||||
|
||||
def _resolve_run_id(exec_id: str) -> str | None:
|
||||
"""Look up the run_id for a given execution_id via EXECUTION_STARTED history."""
|
||||
for e in _bus.get_history(event_type=_ET.EXECUTION_STARTED, limit=200):
|
||||
if e.execution_id == exec_id and getattr(e, "run_id", None):
|
||||
return e.run_id
|
||||
return None
|
||||
|
||||
async def _inject_digest_to_queen(run_id: str) -> None:
|
||||
"""Read the written digest and push it into the queen's conversation."""
|
||||
from framework.agents.worker_memory import digest_path
|
||||
|
||||
try:
|
||||
content = digest_path(_agent_name, run_id).read_text(encoding="utf-8").strip()
|
||||
except OSError:
|
||||
return
|
||||
if not content:
|
||||
return
|
||||
executor = session.queen_executor
|
||||
if executor is None:
|
||||
return
|
||||
node = executor.node_registry.get("queen")
|
||||
if node is None or not hasattr(node, "inject_event"):
|
||||
return
|
||||
await node.inject_event(f"[WORKER_DIGEST]\n{content}")
|
||||
|
||||
async def _consolidate_and_notify(run_id: str, outcome_event: Any) -> None:
|
||||
"""Write the digest then push it to the queen."""
|
||||
from framework.agents.worker_memory import consolidate_worker_run
|
||||
|
||||
await consolidate_worker_run(_agent_name, run_id, outcome_event, _bus, _llm)
|
||||
await _inject_digest_to_queen(run_id)
|
||||
|
||||
async def _on_worker_event(event: Any) -> None:
|
||||
if event.stream_id == "queen":
|
||||
return
|
||||
|
||||
exec_id = event.execution_id
|
||||
|
||||
if event.type == _ET.EXECUTION_STARTED:
|
||||
# New run on this execution_id — reset cooldown so the first
|
||||
# iteration always produces a mid-run snapshot.
|
||||
if exec_id:
|
||||
_last_digest.pop(exec_id, None)
|
||||
|
||||
elif event.type in (
|
||||
_ET.EXECUTION_COMPLETED,
|
||||
_ET.EXECUTION_FAILED,
|
||||
_ET.EXECUTION_PAUSED,
|
||||
):
|
||||
# Final digest — always fire, ignore cooldown.
|
||||
# EXECUTION_PAUSED covers cancellation (queen re-triggering the
|
||||
# worker cancels the previous execution, emitting paused).
|
||||
run_id = getattr(event, "run_id", None) or _resolve_run_id(exec_id)
|
||||
if run_id:
|
||||
asyncio.create_task(
|
||||
_consolidate_and_notify(run_id, event),
|
||||
name=f"worker-digest-final-{run_id}",
|
||||
)
|
||||
|
||||
elif event.type in (_ET.NODE_LOOP_ITERATION, _ET.TOOL_CALL_COMPLETED):
|
||||
# Mid-run snapshot — respect 300 s cooldown per execution.
|
||||
# TOOL_CALL_COMPLETED is only interesting for subagent calls;
|
||||
# regular tool completions are too frequent and too cheap.
|
||||
if event.type == _ET.TOOL_CALL_COMPLETED:
|
||||
tool_name = (event.data or {}).get("tool_name", "")
|
||||
if tool_name != "delegate_to_sub_agent":
|
||||
return
|
||||
if not exec_id:
|
||||
return
|
||||
now = _time.monotonic()
|
||||
if now - _last_digest.get(exec_id, 0.0) < _DIGEST_COOLDOWN:
|
||||
return
|
||||
run_id = _resolve_run_id(exec_id)
|
||||
if run_id:
|
||||
_last_digest[exec_id] = now
|
||||
asyncio.create_task(
|
||||
_consolidate_and_notify(run_id, None),
|
||||
name=f"worker-digest-{run_id}",
|
||||
)
|
||||
|
||||
session.worker_digest_sub = session.event_bus.subscribe(
|
||||
event_types=[
|
||||
_ET.EXECUTION_STARTED,
|
||||
_ET.NODE_LOOP_ITERATION,
|
||||
_ET.TOOL_CALL_COMPLETED,
|
||||
_ET.EXECUTION_COMPLETED,
|
||||
_ET.EXECUTION_FAILED,
|
||||
_ET.EXECUTION_PAUSED,
|
||||
],
|
||||
handler=_on_worker_event,
|
||||
)
|
||||
|
||||
def _subscribe_worker_handoffs(self, session: Session, executor: Any) -> None:
|
||||
"""Subscribe queen to worker/subagent escalation handoff events."""
|
||||
from framework.runtime.event_bus import EventType as _ET
|
||||
@@ -700,16 +899,21 @@ class SessionManager:
|
||||
else None
|
||||
)
|
||||
)
|
||||
_meta_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"agent_name": _agent_name,
|
||||
"agent_path": str(session.worker_path) if session.worker_path else None,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
# Merge into existing meta.json to preserve fields written by
|
||||
# _update_meta_json (e.g. phase, agent_path set during building).
|
||||
_existing_meta: dict = {}
|
||||
if _meta_path.exists():
|
||||
try:
|
||||
_existing_meta = json.loads(_meta_path.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, OSError):
|
||||
pass
|
||||
_new_meta: dict = {"created_at": time.time()}
|
||||
if _agent_name is not None:
|
||||
_new_meta["agent_name"] = _agent_name
|
||||
if session.worker_path is not None:
|
||||
_new_meta["agent_path"] = str(session.worker_path)
|
||||
_existing_meta.update(_new_meta)
|
||||
_meta_path.write_text(json.dumps(_existing_meta), encoding="utf-8")
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@@ -762,11 +966,27 @@ class SessionManager:
|
||||
try:
|
||||
_meta = json.loads(meta_path.read_text(encoding="utf-8"))
|
||||
_agent_path = _meta.get("agent_path")
|
||||
_phase = _meta.get("phase")
|
||||
|
||||
if _agent_path and Path(_agent_path).exists():
|
||||
await self.load_worker(session.id, _agent_path)
|
||||
if session.phase_state:
|
||||
await session.phase_state.switch_to_staging(source="auto")
|
||||
logger.info("Cold restore: auto-loaded worker from %s", _agent_path)
|
||||
if _phase in ("staging", "running", None):
|
||||
# Agent fully built — load worker and resume
|
||||
await self.load_worker(session.id, _agent_path)
|
||||
if session.phase_state:
|
||||
await session.phase_state.switch_to_staging(source="auto")
|
||||
# Emit flowchart overlay so frontend can display it
|
||||
await self._emit_flowchart_on_restore(session, _agent_path)
|
||||
logger.info("Cold restore: auto-loaded worker from %s", _agent_path)
|
||||
elif _phase == "building":
|
||||
# Agent folder exists but incomplete — resume building
|
||||
if session.phase_state:
|
||||
session.phase_state.agent_path = _agent_path
|
||||
await session.phase_state.switch_to_building(source="auto")
|
||||
logger.info("Cold restore: resumed BUILDING phase for %s", _agent_path)
|
||||
elif _phase == "planning":
|
||||
if session.phase_state:
|
||||
session.phase_state.agent_path = _agent_path
|
||||
logger.info("Cold restore: PLANNING phase for %s", _agent_path)
|
||||
except Exception:
|
||||
logger.warning("Cold restore: failed to auto-load worker", exc_info=True)
|
||||
|
||||
@@ -841,6 +1061,29 @@ class SessionManager:
|
||||
)
|
||||
)
|
||||
|
||||
async def _emit_flowchart_on_restore(self, session: Session, agent_path: str | Path) -> None:
|
||||
"""Emit FLOWCHART_MAP_UPDATED from persisted flowchart file on cold restore."""
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
from framework.tools.flowchart_utils import load_flowchart_file
|
||||
|
||||
original_draft, flowchart_map = load_flowchart_file(agent_path)
|
||||
if original_draft is None:
|
||||
return
|
||||
# Cache in phase_state so the REST endpoint also returns it
|
||||
if session.phase_state:
|
||||
session.phase_state.original_draft_graph = original_draft
|
||||
session.phase_state.flowchart_map = flowchart_map
|
||||
await session.event_bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.FLOWCHART_MAP_UPDATED,
|
||||
stream_id="queen",
|
||||
data={
|
||||
"map": flowchart_map,
|
||||
"original_draft": original_draft,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def _notify_queen_worker_unloaded(self, session: Session) -> None:
|
||||
"""Notify the queen that the worker has been unloaded."""
|
||||
executor = session.queen_executor
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Hive Agent Skills — discovery, parsing, and injection of SKILL.md packages.
|
||||
|
||||
Implements the open Agent Skills standard (agentskills.io) for portable
|
||||
skill discovery and activation, plus built-in default skills for runtime
|
||||
operational discipline.
|
||||
"""
|
||||
|
||||
from framework.skills.catalog import SkillCatalog
|
||||
from framework.skills.config import DefaultSkillConfig, SkillsConfig
|
||||
from framework.skills.defaults import DefaultSkillManager
|
||||
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
|
||||
from framework.skills.manager import SkillsManager, SkillsManagerConfig
|
||||
from framework.skills.parser import ParsedSkill, parse_skill_md
|
||||
|
||||
__all__ = [
|
||||
"DefaultSkillConfig",
|
||||
"DefaultSkillManager",
|
||||
"DiscoveryConfig",
|
||||
"ParsedSkill",
|
||||
"SkillCatalog",
|
||||
"SkillDiscovery",
|
||||
"SkillsConfig",
|
||||
"SkillsManager",
|
||||
"SkillsManagerConfig",
|
||||
"parse_skill_md",
|
||||
]
|
||||
@@ -0,0 +1,24 @@
|
||||
---
|
||||
name: hive.batch-ledger
|
||||
description: Track per-item status when processing collections to prevent skipped or duplicated items.
|
||||
metadata:
|
||||
author: hive
|
||||
type: default-skill
|
||||
---
|
||||
|
||||
## Operational Protocol: Batch Progress Ledger
|
||||
|
||||
When processing a collection of items, maintain a batch ledger in `_batch_ledger`.
|
||||
|
||||
Initialize when you identify the batch:
|
||||
- `_batch_total`: total item count
|
||||
- `_batch_ledger`: JSON with per-item status
|
||||
|
||||
Per-item statuses: pending → in_progress → completed|failed|skipped
|
||||
|
||||
- Set `in_progress` BEFORE processing
|
||||
- Set final status AFTER processing with 1-line result_summary
|
||||
- Include error reason for failed/skipped items
|
||||
- Update aggregate counts after each item
|
||||
- NEVER remove items from the ledger
|
||||
- If resuming, skip items already marked completed
|
||||
@@ -0,0 +1,22 @@
|
||||
---
|
||||
name: hive.context-preservation
|
||||
description: Proactively preserve critical information before automatic context pruning destroys it.
|
||||
metadata:
|
||||
author: hive
|
||||
type: default-skill
|
||||
---
|
||||
|
||||
## Operational Protocol: Context Preservation
|
||||
|
||||
You operate under a finite context window. Important information WILL be pruned.
|
||||
|
||||
Save-As-You-Go: After any tool call producing information you'll need later,
|
||||
immediately extract key data into `_working_notes` or `_preserved_data`.
|
||||
Do NOT rely on referring back to old tool results.
|
||||
|
||||
What to extract: URLs and key snippets (not full pages), relevant API fields
|
||||
(not raw JSON), specific lines/values (not entire files), analysis results
|
||||
(not raw data).
|
||||
|
||||
Before transitioning to the next phase/node, write a handoff summary to
|
||||
`_handoff_context` with everything the next phase needs to know.
|
||||
@@ -0,0 +1,18 @@
|
||||
---
|
||||
name: hive.error-recovery
|
||||
description: Follow a structured recovery protocol when tool calls fail instead of blindly retrying or giving up.
|
||||
metadata:
|
||||
author: hive
|
||||
type: default-skill
|
||||
---
|
||||
|
||||
## Operational Protocol: Error Recovery
|
||||
|
||||
When a tool call fails:
|
||||
|
||||
1. Diagnose — record error in notes, classify as transient or structural
|
||||
2. Decide — transient: retry once. Structural fixable: fix and retry.
|
||||
Structural unfixable: record as failed, move to next item.
|
||||
Blocking all progress: record escalation note.
|
||||
3. Adapt — if same tool failed 3+ times, stop using it and find alternative.
|
||||
Update plan in notes. Never silently drop the failed item.
|
||||
@@ -0,0 +1,27 @@
|
||||
---
|
||||
name: hive.note-taking
|
||||
description: Maintain structured working notes throughout execution to prevent information loss during context pruning.
|
||||
metadata:
|
||||
author: hive
|
||||
type: default-skill
|
||||
---
|
||||
|
||||
## Operational Protocol: Structured Note-Taking
|
||||
|
||||
Maintain structured working notes in shared memory key `_working_notes`.
|
||||
Update at these checkpoints:
|
||||
|
||||
- After completing each discrete subtask or batch item
|
||||
- After receiving new information that changes your plan
|
||||
- Before any tool call that will produce substantial output
|
||||
|
||||
Structure:
|
||||
|
||||
### Objective — restate the goal
|
||||
### Current Plan — numbered steps, mark completed with ✓
|
||||
### Key Decisions — decisions made and WHY
|
||||
### Working Data — intermediate results, extracted values
|
||||
### Open Questions — uncertainties to verify
|
||||
### Blockers — anything preventing progress
|
||||
|
||||
Update incrementally — do not rewrite from scratch each time.
|
||||
@@ -0,0 +1,20 @@
|
||||
---
|
||||
name: hive.quality-monitor
|
||||
description: Periodically self-assess output quality to catch degradation before the judge does.
|
||||
metadata:
|
||||
author: hive
|
||||
type: default-skill
|
||||
---
|
||||
|
||||
## Operational Protocol: Quality Self-Assessment
|
||||
|
||||
Every 5 iterations, self-assess:
|
||||
|
||||
1. On-task? Still working toward the stated objective?
|
||||
2. Thorough? Cutting corners compared to earlier?
|
||||
3. Non-repetitive? Producing new value or rehashing?
|
||||
4. Consistent? Latest output contradict earlier decisions?
|
||||
5. Complete? Tracking all items, or silently dropped some?
|
||||
|
||||
If degrading: write assessment to `_quality_log`, re-read `_working_notes`,
|
||||
change approach explicitly. If acceptable: brief note in `_quality_log`.
|
||||
@@ -0,0 +1,17 @@
|
||||
---
|
||||
name: hive.task-decomposition
|
||||
description: Decompose complex tasks into explicit subtasks before diving in.
|
||||
metadata:
|
||||
author: hive
|
||||
type: default-skill
|
||||
---
|
||||
|
||||
## Operational Protocol: Task Decomposition
|
||||
|
||||
Before starting a complex task:
|
||||
|
||||
1. Decompose — break into numbered subtasks in `_working_notes` Current Plan
|
||||
2. Estimate — relative effort per subtask (small/medium/large)
|
||||
3. Execute — work through in order, mark ✓ when complete
|
||||
4. Budget — if running low on iterations, prioritize by impact
|
||||
5. Verify — before declaring done, every subtask must be ✓, skipped (with reason), or blocked
|
||||
@@ -0,0 +1,107 @@
|
||||
"""Skill catalog — in-memory index with system prompt generation.
|
||||
|
||||
Builds the XML catalog injected into the system prompt for model-driven
|
||||
skill activation per the Agent Skills standard.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from xml.sax.saxutils import escape
|
||||
|
||||
from framework.skills.parser import ParsedSkill
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BEHAVIORAL_INSTRUCTION = (
|
||||
"The following skills provide specialized instructions for specific tasks.\n"
|
||||
"When a task matches a skill's description, read the SKILL.md at the listed\n"
|
||||
"location to load the full instructions before proceeding.\n"
|
||||
"When a skill references relative paths, resolve them against the skill's\n"
|
||||
"directory (the parent of SKILL.md) and use absolute paths in tool calls."
|
||||
)
|
||||
|
||||
|
||||
class SkillCatalog:
|
||||
"""In-memory catalog of discovered skills."""
|
||||
|
||||
def __init__(self, skills: list[ParsedSkill] | None = None):
|
||||
self._skills: dict[str, ParsedSkill] = {}
|
||||
self._activated: set[str] = set()
|
||||
if skills:
|
||||
for skill in skills:
|
||||
self.add(skill)
|
||||
|
||||
def add(self, skill: ParsedSkill) -> None:
|
||||
"""Add a skill to the catalog."""
|
||||
self._skills[skill.name] = skill
|
||||
|
||||
def get(self, name: str) -> ParsedSkill | None:
|
||||
"""Look up a skill by name."""
|
||||
return self._skills.get(name)
|
||||
|
||||
def mark_activated(self, name: str) -> None:
|
||||
"""Mark a skill as activated in the current session."""
|
||||
self._activated.add(name)
|
||||
|
||||
def is_activated(self, name: str) -> bool:
|
||||
"""Check if a skill has been activated."""
|
||||
return name in self._activated
|
||||
|
||||
@property
|
||||
def skill_count(self) -> int:
|
||||
return len(self._skills)
|
||||
|
||||
@property
|
||||
def allowlisted_dirs(self) -> list[str]:
|
||||
"""All skill base directories for file access allowlisting."""
|
||||
return [skill.base_dir for skill in self._skills.values()]
|
||||
|
||||
def to_prompt(self) -> str:
|
||||
"""Generate the catalog prompt for system prompt injection.
|
||||
|
||||
Returns empty string if no community/user skills are discovered
|
||||
(default skills are handled separately by DefaultSkillManager).
|
||||
"""
|
||||
# Filter out framework-scope skills (default skills) — they're
|
||||
# injected via the protocols prompt, not the catalog
|
||||
community_skills = [s for s in self._skills.values() if s.source_scope != "framework"]
|
||||
|
||||
if not community_skills:
|
||||
return ""
|
||||
|
||||
lines = ["<available_skills>"]
|
||||
for skill in sorted(community_skills, key=lambda s: s.name):
|
||||
lines.append(" <skill>")
|
||||
lines.append(f" <name>{escape(skill.name)}</name>")
|
||||
lines.append(f" <description>{escape(skill.description)}</description>")
|
||||
lines.append(f" <location>{escape(skill.location)}</location>")
|
||||
lines.append(" </skill>")
|
||||
lines.append("</available_skills>")
|
||||
|
||||
xml_block = "\n".join(lines)
|
||||
return f"{_BEHAVIORAL_INSTRUCTION}\n\n{xml_block}"
|
||||
|
||||
def build_pre_activated_prompt(self, skill_names: list[str]) -> str:
|
||||
"""Build prompt content for pre-activated skills.
|
||||
|
||||
Pre-activated skills get their full SKILL.md body loaded into
|
||||
the system prompt at startup (tier 2), bypassing model-driven
|
||||
activation.
|
||||
|
||||
Returns empty string if no skills match.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
|
||||
for name in skill_names:
|
||||
skill = self.get(name)
|
||||
if skill is None:
|
||||
logger.warning("Pre-activated skill '%s' not found in catalog", name)
|
||||
continue
|
||||
if self.is_activated(name):
|
||||
continue # Already activated, skip duplicate
|
||||
|
||||
self.mark_activated(name)
|
||||
parts.append(f"--- Pre-Activated Skill: {skill.name} ---\n{skill.body}")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
@@ -0,0 +1,100 @@
|
||||
"""Skill configuration dataclasses.
|
||||
|
||||
Handles agent-level skill configuration from module-level variables
|
||||
(``default_skills`` and ``skills``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class DefaultSkillConfig:
|
||||
"""Configuration for a single default skill."""
|
||||
|
||||
enabled: bool = True
|
||||
overrides: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> DefaultSkillConfig:
|
||||
enabled = data.get("enabled", True)
|
||||
overrides = {k: v for k, v in data.items() if k != "enabled"}
|
||||
return cls(enabled=enabled, overrides=overrides)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillsConfig:
|
||||
"""Agent-level skill configuration.
|
||||
|
||||
Built from module-level variables in agent.py::
|
||||
|
||||
# Pre-activated community skills
|
||||
skills = ["deep-research", "code-review"]
|
||||
|
||||
# Default skill configuration
|
||||
default_skills = {
|
||||
"hive.note-taking": {"enabled": True},
|
||||
"hive.batch-ledger": {"enabled": True, "checkpoint_every_n": 10},
|
||||
"hive.quality-monitor": {"enabled": False},
|
||||
}
|
||||
"""
|
||||
|
||||
# Per-default-skill config, keyed by skill name (e.g. "hive.note-taking")
|
||||
default_skills: dict[str, DefaultSkillConfig] = field(default_factory=dict)
|
||||
|
||||
# Pre-activated community skills (by name)
|
||||
skills: list[str] = field(default_factory=list)
|
||||
|
||||
# Master switch: disable all default skills at once
|
||||
all_defaults_disabled: bool = False
|
||||
|
||||
def is_default_enabled(self, skill_name: str) -> bool:
|
||||
"""Check if a specific default skill is enabled."""
|
||||
if self.all_defaults_disabled:
|
||||
return False
|
||||
config = self.default_skills.get(skill_name)
|
||||
if config is None:
|
||||
return True # enabled by default
|
||||
return config.enabled
|
||||
|
||||
def get_default_overrides(self, skill_name: str) -> dict[str, Any]:
|
||||
"""Get skill-specific configuration overrides."""
|
||||
config = self.default_skills.get(skill_name)
|
||||
if config is None:
|
||||
return {}
|
||||
return config.overrides
|
||||
|
||||
@classmethod
|
||||
def from_agent_vars(
|
||||
cls,
|
||||
default_skills: dict[str, Any] | None = None,
|
||||
skills: list[str] | None = None,
|
||||
) -> SkillsConfig:
|
||||
"""Build config from agent module-level variables.
|
||||
|
||||
Args:
|
||||
default_skills: Dict from agent module, e.g.
|
||||
``{"hive.note-taking": {"enabled": True}}``
|
||||
skills: List of pre-activated skill names from agent module
|
||||
"""
|
||||
all_disabled = False
|
||||
parsed_defaults: dict[str, DefaultSkillConfig] = {}
|
||||
|
||||
if default_skills:
|
||||
for name, config_dict in default_skills.items():
|
||||
if name == "_all":
|
||||
if isinstance(config_dict, dict) and not config_dict.get("enabled", True):
|
||||
all_disabled = True
|
||||
continue
|
||||
if isinstance(config_dict, dict):
|
||||
parsed_defaults[name] = DefaultSkillConfig.from_dict(config_dict)
|
||||
elif isinstance(config_dict, bool):
|
||||
parsed_defaults[name] = DefaultSkillConfig(enabled=config_dict)
|
||||
|
||||
return cls(
|
||||
default_skills=parsed_defaults,
|
||||
skills=list(skills or []),
|
||||
all_defaults_disabled=all_disabled,
|
||||
)
|
||||
@@ -0,0 +1,151 @@
|
||||
"""DefaultSkillManager — load, configure, and inject built-in default skills.
|
||||
|
||||
Default skills are SKILL.md packages shipped with the framework that provide
|
||||
runtime operational protocols (note-taking, batch tracking, error recovery, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from framework.skills.config import SkillsConfig
|
||||
from framework.skills.parser import ParsedSkill, parse_skill_md
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default skills directory relative to this module
|
||||
_DEFAULT_SKILLS_DIR = Path(__file__).parent / "_default_skills"
|
||||
|
||||
# Ordered list of default skills (name → directory)
|
||||
SKILL_REGISTRY: dict[str, str] = {
|
||||
"hive.note-taking": "note-taking",
|
||||
"hive.batch-ledger": "batch-ledger",
|
||||
"hive.context-preservation": "context-preservation",
|
||||
"hive.quality-monitor": "quality-monitor",
|
||||
"hive.error-recovery": "error-recovery",
|
||||
"hive.task-decomposition": "task-decomposition",
|
||||
}
|
||||
|
||||
# All shared memory keys used by default skills (for permission auto-inclusion)
|
||||
SHARED_MEMORY_KEYS: list[str] = [
|
||||
# note-taking
|
||||
"_working_notes",
|
||||
"_notes_updated_at",
|
||||
# batch-ledger
|
||||
"_batch_ledger",
|
||||
"_batch_total",
|
||||
"_batch_completed",
|
||||
"_batch_failed",
|
||||
# context-preservation
|
||||
"_handoff_context",
|
||||
"_preserved_data",
|
||||
# quality-monitor
|
||||
"_quality_log",
|
||||
"_quality_degradation_count",
|
||||
# error-recovery
|
||||
"_error_log",
|
||||
"_failed_tools",
|
||||
"_escalation_needed",
|
||||
# task-decomposition
|
||||
"_subtasks",
|
||||
"_iteration_budget_remaining",
|
||||
]
|
||||
|
||||
|
||||
class DefaultSkillManager:
|
||||
"""Manages loading, configuration, and prompt generation for default skills."""
|
||||
|
||||
def __init__(self, config: SkillsConfig | None = None):
|
||||
self._config = config or SkillsConfig()
|
||||
self._skills: dict[str, ParsedSkill] = {}
|
||||
self._loaded = False
|
||||
|
||||
def load(self) -> None:
|
||||
"""Load all enabled default skill SKILL.md files."""
|
||||
if self._loaded:
|
||||
return
|
||||
|
||||
for skill_name, dir_name in SKILL_REGISTRY.items():
|
||||
if not self._config.is_default_enabled(skill_name):
|
||||
logger.info("Default skill '%s' disabled by config", skill_name)
|
||||
continue
|
||||
|
||||
skill_path = _DEFAULT_SKILLS_DIR / dir_name / "SKILL.md"
|
||||
if not skill_path.is_file():
|
||||
logger.error("Default skill SKILL.md not found: %s", skill_path)
|
||||
continue
|
||||
|
||||
parsed = parse_skill_md(skill_path, source_scope="framework")
|
||||
if parsed is None:
|
||||
logger.error("Failed to parse default skill: %s", skill_path)
|
||||
continue
|
||||
|
||||
self._skills[skill_name] = parsed
|
||||
|
||||
self._loaded = True
|
||||
|
||||
def build_protocols_prompt(self) -> str:
|
||||
"""Build the combined operational protocols section.
|
||||
|
||||
Extracts protocol sections from all enabled default skills and
|
||||
combines them into a single ``## Operational Protocols`` block
|
||||
for system prompt injection.
|
||||
|
||||
Returns empty string if all defaults are disabled.
|
||||
"""
|
||||
if not self._skills:
|
||||
return ""
|
||||
|
||||
parts: list[str] = ["## Operational Protocols\n"]
|
||||
|
||||
for skill_name in SKILL_REGISTRY:
|
||||
skill = self._skills.get(skill_name)
|
||||
if skill is None:
|
||||
continue
|
||||
# Use the full body — each SKILL.md contains exactly one protocol section
|
||||
parts.append(skill.body)
|
||||
|
||||
if len(parts) <= 1:
|
||||
return ""
|
||||
|
||||
combined = "\n\n".join(parts)
|
||||
|
||||
# Token budget warning (approximate: 1 token ≈ 4 chars)
|
||||
approx_tokens = len(combined) // 4
|
||||
if approx_tokens > 2000:
|
||||
logger.warning(
|
||||
"Default skill protocols exceed 2000 token budget "
|
||||
"(~%d tokens, %d chars). Consider trimming.",
|
||||
approx_tokens,
|
||||
len(combined),
|
||||
)
|
||||
|
||||
return combined
|
||||
|
||||
def log_active_skills(self) -> None:
|
||||
"""Log which default skills are active and their configuration."""
|
||||
if not self._skills:
|
||||
logger.info("Default skills: all disabled")
|
||||
return
|
||||
|
||||
active = []
|
||||
for skill_name in SKILL_REGISTRY:
|
||||
if skill_name in self._skills:
|
||||
overrides = self._config.get_default_overrides(skill_name)
|
||||
if overrides:
|
||||
active.append(f"{skill_name} ({overrides})")
|
||||
else:
|
||||
active.append(skill_name)
|
||||
|
||||
logger.info("Default skills active: %s", ", ".join(active))
|
||||
|
||||
@property
|
||||
def active_skill_names(self) -> list[str]:
|
||||
"""Names of all currently active default skills."""
|
||||
return list(self._skills.keys())
|
||||
|
||||
@property
|
||||
def active_skills(self) -> dict[str, ParsedSkill]:
|
||||
"""All active default skills keyed by name."""
|
||||
return dict(self._skills)
|
||||
@@ -0,0 +1,183 @@
|
||||
"""Skill discovery — scan standard directories for SKILL.md files.
|
||||
|
||||
Implements the Agent Skills standard discovery paths plus Hive-specific
|
||||
locations. Resolves name collisions deterministically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from framework.skills.parser import ParsedSkill, parse_skill_md
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Directories to skip during scanning
|
||||
_SKIP_DIRS = frozenset(
|
||||
{
|
||||
".git",
|
||||
"node_modules",
|
||||
"__pycache__",
|
||||
".venv",
|
||||
"venv",
|
||||
".mypy_cache",
|
||||
".pytest_cache",
|
||||
".ruff_cache",
|
||||
}
|
||||
)
|
||||
|
||||
# Scope priority (higher = takes precedence)
|
||||
_SCOPE_PRIORITY = {
|
||||
"framework": 0,
|
||||
"user": 1,
|
||||
"project": 2,
|
||||
}
|
||||
|
||||
# Within the same scope, Hive-specific paths override cross-client paths.
|
||||
# We encode this by scanning cross-client first, then Hive-specific (later wins).
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiscoveryConfig:
|
||||
"""Configuration for skill discovery."""
|
||||
|
||||
project_root: Path | None = None
|
||||
skip_user_scope: bool = False
|
||||
skip_framework_scope: bool = False
|
||||
max_depth: int = 4
|
||||
max_dirs: int = 2000
|
||||
|
||||
|
||||
class SkillDiscovery:
|
||||
"""Scans standard directories for SKILL.md files and resolves collisions."""
|
||||
|
||||
def __init__(self, config: DiscoveryConfig | None = None):
|
||||
self._config = config or DiscoveryConfig()
|
||||
|
||||
def discover(self) -> list[ParsedSkill]:
|
||||
"""Scan all scopes and return deduplicated skill list.
|
||||
|
||||
Scanning order (lowest to highest precedence):
|
||||
1. Framework defaults
|
||||
2. User cross-client (~/.agents/skills/)
|
||||
3. User Hive-specific (~/.hive/skills/)
|
||||
4. Project cross-client (<project>/.agents/skills/)
|
||||
5. Project Hive-specific (<project>/.hive/skills/)
|
||||
|
||||
Later entries override earlier ones on name collision.
|
||||
"""
|
||||
all_skills: list[ParsedSkill] = []
|
||||
|
||||
# Framework scope (lowest precedence)
|
||||
if not self._config.skip_framework_scope:
|
||||
framework_dir = Path(__file__).parent / "_default_skills"
|
||||
if framework_dir.is_dir():
|
||||
all_skills.extend(self._scan_scope(framework_dir, "framework"))
|
||||
|
||||
# User scope
|
||||
if not self._config.skip_user_scope:
|
||||
home = Path.home()
|
||||
|
||||
# Cross-client (lower precedence within user scope)
|
||||
user_agents = home / ".agents" / "skills"
|
||||
if user_agents.is_dir():
|
||||
all_skills.extend(self._scan_scope(user_agents, "user"))
|
||||
|
||||
# Hive-specific (higher precedence within user scope)
|
||||
user_hive = home / ".hive" / "skills"
|
||||
if user_hive.is_dir():
|
||||
all_skills.extend(self._scan_scope(user_hive, "user"))
|
||||
|
||||
# Project scope (highest precedence)
|
||||
if self._config.project_root:
|
||||
root = self._config.project_root
|
||||
|
||||
# Cross-client
|
||||
project_agents = root / ".agents" / "skills"
|
||||
if project_agents.is_dir():
|
||||
all_skills.extend(self._scan_scope(project_agents, "project"))
|
||||
|
||||
# Hive-specific
|
||||
project_hive = root / ".hive" / "skills"
|
||||
if project_hive.is_dir():
|
||||
all_skills.extend(self._scan_scope(project_hive, "project"))
|
||||
|
||||
resolved = self._resolve_collisions(all_skills)
|
||||
|
||||
logger.info(
|
||||
"Skill discovery: found %d skills (%d after dedup) across all scopes",
|
||||
len(all_skills),
|
||||
len(resolved),
|
||||
)
|
||||
return resolved
|
||||
|
||||
def _scan_scope(self, root: Path, scope: str) -> list[ParsedSkill]:
|
||||
"""Scan a single directory for skill directories containing SKILL.md."""
|
||||
skills: list[ParsedSkill] = []
|
||||
dirs_scanned = 0
|
||||
|
||||
for skill_md in self._find_skill_files(root, depth=0):
|
||||
if dirs_scanned >= self._config.max_dirs:
|
||||
logger.warning(
|
||||
"Hit max directory limit (%d) scanning %s",
|
||||
self._config.max_dirs,
|
||||
root,
|
||||
)
|
||||
break
|
||||
|
||||
parsed = parse_skill_md(skill_md, source_scope=scope)
|
||||
if parsed is not None:
|
||||
skills.append(parsed)
|
||||
dirs_scanned += 1
|
||||
|
||||
return skills
|
||||
|
||||
def _find_skill_files(self, directory: Path, depth: int) -> list[Path]:
|
||||
"""Recursively find SKILL.md files up to max_depth."""
|
||||
if depth > self._config.max_depth:
|
||||
return []
|
||||
|
||||
results: list[Path] = []
|
||||
|
||||
try:
|
||||
entries = sorted(directory.iterdir())
|
||||
except OSError:
|
||||
return []
|
||||
|
||||
for entry in entries:
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
if entry.name in _SKIP_DIRS:
|
||||
continue
|
||||
|
||||
skill_md = entry / "SKILL.md"
|
||||
if skill_md.is_file():
|
||||
results.append(skill_md)
|
||||
else:
|
||||
# Recurse into subdirectories
|
||||
results.extend(self._find_skill_files(entry, depth + 1))
|
||||
|
||||
return results
|
||||
|
||||
def _resolve_collisions(self, skills: list[ParsedSkill]) -> list[ParsedSkill]:
|
||||
"""Resolve name collisions deterministically.
|
||||
|
||||
Later entries in the list override earlier ones (because we scan
|
||||
from lowest to highest precedence). On collision, log a warning.
|
||||
"""
|
||||
seen: dict[str, ParsedSkill] = {}
|
||||
|
||||
for skill in skills:
|
||||
if skill.name in seen:
|
||||
existing = seen[skill.name]
|
||||
logger.warning(
|
||||
"Skill name collision: '%s' from %s overrides %s",
|
||||
skill.name,
|
||||
skill.location,
|
||||
existing.location,
|
||||
)
|
||||
seen[skill.name] = skill
|
||||
|
||||
return list(seen.values())
|
||||
@@ -0,0 +1,165 @@
|
||||
"""Unified skill lifecycle manager.
|
||||
|
||||
``SkillsManager`` is the single facade that owns skill discovery, loading,
|
||||
and prompt renderation. The runtime creates one at startup and downstream
|
||||
layers read the cached prompt strings.
|
||||
|
||||
Typical usage — **config-driven** (runner passes configuration)::
|
||||
|
||||
config = SkillsManagerConfig(
|
||||
skills_config=SkillsConfig.from_agent_vars(...),
|
||||
project_root=agent_path,
|
||||
)
|
||||
mgr = SkillsManager(config)
|
||||
mgr.load()
|
||||
print(mgr.protocols_prompt) # default skill protocols
|
||||
print(mgr.skills_catalog_prompt) # community skills XML
|
||||
|
||||
Typical usage — **bare** (exported agents, SDK users)::
|
||||
|
||||
mgr = SkillsManager() # default config
|
||||
mgr.load() # loads all 6 default skills, no community discovery
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from framework.skills.config import SkillsConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillsManagerConfig:
|
||||
"""Everything the runtime needs to configure skills.
|
||||
|
||||
Attributes:
|
||||
skills_config: Per-skill enable/disable and overrides.
|
||||
project_root: Agent directory for community skill discovery.
|
||||
When ``None``, community discovery is skipped.
|
||||
skip_community_discovery: Explicitly skip community scanning
|
||||
even when ``project_root`` is set.
|
||||
"""
|
||||
|
||||
skills_config: SkillsConfig = field(default_factory=SkillsConfig)
|
||||
project_root: Path | None = None
|
||||
skip_community_discovery: bool = False
|
||||
|
||||
|
||||
class SkillsManager:
|
||||
"""Unified skill lifecycle: discovery → loading → prompt renderation.
|
||||
|
||||
The runtime creates one instance during init and owns it for the
|
||||
lifetime of the process. Downstream layers (``ExecutionStream``,
|
||||
``GraphExecutor``, ``NodeContext``, ``EventLoopNode``) receive the
|
||||
cached prompt strings via property accessors.
|
||||
"""
|
||||
|
||||
def __init__(self, config: SkillsManagerConfig | None = None) -> None:
|
||||
self._config = config or SkillsManagerConfig()
|
||||
self._loaded = False
|
||||
self._catalog_prompt: str = ""
|
||||
self._protocols_prompt: str = ""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Factory for backwards-compat bridge
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
def from_precomputed(
|
||||
cls,
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
) -> SkillsManager:
|
||||
"""Wrap pre-rendered prompt strings (legacy callers).
|
||||
|
||||
Returns a manager that skips discovery/loading and just returns
|
||||
the provided strings. Used by the deprecation bridge in
|
||||
``AgentRuntime`` when callers pass raw prompt strings.
|
||||
"""
|
||||
mgr = cls.__new__(cls)
|
||||
mgr._config = SkillsManagerConfig()
|
||||
mgr._loaded = True # skip load()
|
||||
mgr._catalog_prompt = skills_catalog_prompt
|
||||
mgr._protocols_prompt = protocols_prompt
|
||||
return mgr
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def load(self) -> None:
|
||||
"""Discover, load, and cache skill prompts. Idempotent."""
|
||||
if self._loaded:
|
||||
return
|
||||
self._loaded = True
|
||||
|
||||
try:
|
||||
self._do_load()
|
||||
except Exception:
|
||||
logger.warning("Skill system init failed (non-fatal)", exc_info=True)
|
||||
|
||||
def _do_load(self) -> None:
|
||||
"""Internal load — may raise; caller catches."""
|
||||
from framework.skills.catalog import SkillCatalog
|
||||
from framework.skills.defaults import DefaultSkillManager
|
||||
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
|
||||
|
||||
skills_config = self._config.skills_config
|
||||
|
||||
# 1. Community skill discovery (when project_root is available)
|
||||
catalog_prompt = ""
|
||||
if self._config.project_root is not None and not self._config.skip_community_discovery:
|
||||
discovery = SkillDiscovery(DiscoveryConfig(project_root=self._config.project_root))
|
||||
discovered = discovery.discover()
|
||||
catalog = SkillCatalog(discovered)
|
||||
catalog_prompt = catalog.to_prompt()
|
||||
|
||||
# Pre-activated community skills
|
||||
if skills_config.skills:
|
||||
pre_activated = catalog.build_pre_activated_prompt(skills_config.skills)
|
||||
if pre_activated:
|
||||
if catalog_prompt:
|
||||
catalog_prompt = f"{catalog_prompt}\n\n{pre_activated}"
|
||||
else:
|
||||
catalog_prompt = pre_activated
|
||||
|
||||
# 2. Default skills (always loaded unless explicitly disabled)
|
||||
default_mgr = DefaultSkillManager(config=skills_config)
|
||||
default_mgr.load()
|
||||
default_mgr.log_active_skills()
|
||||
protocols_prompt = default_mgr.build_protocols_prompt()
|
||||
|
||||
# 3. Cache
|
||||
self._catalog_prompt = catalog_prompt
|
||||
self._protocols_prompt = protocols_prompt
|
||||
|
||||
if protocols_prompt:
|
||||
logger.info(
|
||||
"Skill system ready: protocols=%d chars, catalog=%d chars",
|
||||
len(protocols_prompt),
|
||||
len(catalog_prompt),
|
||||
)
|
||||
else:
|
||||
logger.warning("Skill system produced empty protocols_prompt")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prompt accessors (consumed by downstream layers)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def skills_catalog_prompt(self) -> str:
|
||||
"""Community skills XML catalog for system prompt injection."""
|
||||
return self._catalog_prompt
|
||||
|
||||
@property
|
||||
def protocols_prompt(self) -> str:
|
||||
"""Default skill operational protocols for system prompt injection."""
|
||||
return self._protocols_prompt
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
return self._loaded
|
||||
@@ -0,0 +1,158 @@
|
||||
"""SKILL.md parser — extracts YAML frontmatter and markdown body.
|
||||
|
||||
Parses SKILL.md files per the Agent Skills standard (agentskills.io/specification).
|
||||
Lenient validation: warns on non-critical issues, skips only on missing description
|
||||
or completely unparseable YAML.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum name length before a warning is logged
|
||||
_MAX_NAME_LENGTH = 64
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedSkill:
|
||||
"""In-memory representation of a parsed SKILL.md file."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
location: str # absolute path to SKILL.md
|
||||
base_dir: str # parent directory of SKILL.md
|
||||
source_scope: str # "project", "user", or "framework"
|
||||
body: str # markdown body after closing ---
|
||||
|
||||
# Optional frontmatter fields
|
||||
license: str | None = None
|
||||
compatibility: list[str] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
allowed_tools: list[str] | None = None
|
||||
|
||||
|
||||
def _try_fix_yaml(raw: str) -> str:
|
||||
"""Attempt to fix common YAML issues (unquoted colon values).
|
||||
|
||||
Some SKILL.md files written for other clients may contain unquoted
|
||||
values with colons, e.g. ``description: Use for: research tasks``.
|
||||
This wraps such values in quotes as a best-effort fixup.
|
||||
"""
|
||||
lines = raw.split("\n")
|
||||
fixed = []
|
||||
for line in lines:
|
||||
# Match "key: value" where value contains an unquoted colon
|
||||
m = re.match(r"^(\s*\w[\w-]*:\s*)(.+)$", line)
|
||||
if m:
|
||||
key_part, value_part = m.group(1), m.group(2)
|
||||
# If value contains a colon and isn't already quoted
|
||||
if ":" in value_part and not (value_part.startswith('"') or value_part.startswith("'")):
|
||||
value_part = f'"{value_part}"'
|
||||
fixed.append(f"{key_part}{value_part}")
|
||||
else:
|
||||
fixed.append(line)
|
||||
return "\n".join(fixed)
|
||||
|
||||
|
||||
def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | None:
|
||||
"""Parse a SKILL.md file into a ParsedSkill record.
|
||||
|
||||
Args:
|
||||
path: Absolute path to the SKILL.md file.
|
||||
source_scope: One of "project", "user", or "framework".
|
||||
|
||||
Returns:
|
||||
ParsedSkill on success, None if the file is unparseable or
|
||||
missing required fields (description).
|
||||
"""
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
logger.error("Failed to read %s: %s", path, exc)
|
||||
return None
|
||||
|
||||
if not content.strip():
|
||||
logger.error("Empty SKILL.md: %s", path)
|
||||
return None
|
||||
|
||||
# Split on --- delimiters (first two occurrences)
|
||||
parts = content.split("---", 2)
|
||||
if len(parts) < 3:
|
||||
logger.error("SKILL.md missing YAML frontmatter delimiters (---): %s", path)
|
||||
return None
|
||||
|
||||
# parts[0] is content before first --- (should be empty or whitespace)
|
||||
# parts[1] is the YAML frontmatter
|
||||
# parts[2] is the markdown body
|
||||
raw_yaml = parts[1].strip()
|
||||
body = parts[2].strip()
|
||||
|
||||
if not raw_yaml:
|
||||
logger.error("Empty YAML frontmatter in %s", path)
|
||||
return None
|
||||
|
||||
# Parse YAML
|
||||
import yaml
|
||||
|
||||
frontmatter: dict[str, Any] | None = None
|
||||
try:
|
||||
frontmatter = yaml.safe_load(raw_yaml)
|
||||
except yaml.YAMLError:
|
||||
# Fallback: try fixing unquoted colon values
|
||||
try:
|
||||
fixed = _try_fix_yaml(raw_yaml)
|
||||
frontmatter = yaml.safe_load(fixed)
|
||||
logger.warning("Fixed YAML parse issues in %s (unquoted colons)", path)
|
||||
except yaml.YAMLError as exc:
|
||||
logger.error("Unparseable YAML in %s: %s", path, exc)
|
||||
return None
|
||||
|
||||
if not isinstance(frontmatter, dict):
|
||||
logger.error("YAML frontmatter is not a mapping in %s", path)
|
||||
return None
|
||||
|
||||
# Required: description
|
||||
description = frontmatter.get("description")
|
||||
if not description or not str(description).strip():
|
||||
logger.error("Missing or empty 'description' in %s — skipping skill", path)
|
||||
return None
|
||||
|
||||
# Required: name (fallback to parent directory name)
|
||||
name = frontmatter.get("name")
|
||||
parent_dir_name = path.parent.name
|
||||
if not name or not str(name).strip():
|
||||
name = parent_dir_name
|
||||
logger.warning("Missing 'name' in %s — using directory name '%s'", path, name)
|
||||
else:
|
||||
name = str(name).strip()
|
||||
|
||||
# Lenient warnings
|
||||
if len(name) > _MAX_NAME_LENGTH:
|
||||
logger.warning("Skill name exceeds %d chars in %s: '%s'", _MAX_NAME_LENGTH, path, name)
|
||||
|
||||
if name != parent_dir_name and not name.endswith(f".{parent_dir_name}"):
|
||||
logger.warning(
|
||||
"Skill name '%s' doesn't match parent directory '%s' in %s",
|
||||
name,
|
||||
parent_dir_name,
|
||||
path,
|
||||
)
|
||||
|
||||
return ParsedSkill(
|
||||
name=name,
|
||||
description=str(description).strip(),
|
||||
location=str(path.resolve()),
|
||||
base_dir=str(path.parent.resolve()),
|
||||
source_scope=source_scope,
|
||||
body=body,
|
||||
license=frontmatter.get("license"),
|
||||
compatibility=frontmatter.get("compatibility"),
|
||||
metadata=frontmatter.get("metadata"),
|
||||
allowed_tools=frontmatter.get("allowed-tools"),
|
||||
)
|
||||
@@ -36,6 +36,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
@@ -115,6 +116,9 @@ class QueenPhaseState:
|
||||
prompt_staging: str = ""
|
||||
prompt_running: str = ""
|
||||
|
||||
# Default skill operational protocols — appended to every phase prompt
|
||||
protocols_prompt: str = ""
|
||||
|
||||
def get_current_tools(self) -> list:
|
||||
"""Return tools for the current phase."""
|
||||
if self.phase == "planning":
|
||||
@@ -139,7 +143,12 @@ class QueenPhaseState:
|
||||
from framework.agents.queen.queen_memory import format_for_injection
|
||||
|
||||
memory = format_for_injection()
|
||||
return base + ("\n\n" + memory if memory else "")
|
||||
parts = [base]
|
||||
if self.protocols_prompt:
|
||||
parts.append(self.protocols_prompt)
|
||||
if memory:
|
||||
parts.append(memory)
|
||||
return "\n\n".join(parts)
|
||||
|
||||
async def _emit_phase_event(self) -> None:
|
||||
"""Publish a QUEEN_PHASE_CHANGED event so the frontend updates the tag."""
|
||||
@@ -399,10 +408,11 @@ async def _start_trigger_timer(session: Any, trigger_id: str, tdef: Any) -> None
|
||||
else:
|
||||
await asyncio.sleep(float(interval_minutes) * 60)
|
||||
|
||||
# Record next fire time for introspection
|
||||
# Record next fire time for introspection (monotonic, matches routes)
|
||||
fire_times = getattr(session, "trigger_next_fire", None)
|
||||
if fire_times is not None:
|
||||
fire_times[trigger_id] = datetime.now(tz=UTC).isoformat()
|
||||
_next_delay = float(interval_minutes) * 60 if interval_minutes else 60
|
||||
fire_times[trigger_id] = time.monotonic() + _next_delay
|
||||
|
||||
# Gate on worker being loaded
|
||||
if getattr(session, "worker_runtime", None) is None:
|
||||
@@ -717,6 +727,25 @@ def _dissolve_planning_nodes(
|
||||
return converted, flowchart_map
|
||||
|
||||
|
||||
def _update_meta_json(session_manager, manager_session_id, updates: dict) -> None:
|
||||
"""Merge updates into the queen session's meta.json."""
|
||||
if session_manager is None or not manager_session_id:
|
||||
return
|
||||
srv_session = session_manager.get_session(manager_session_id)
|
||||
if not srv_session:
|
||||
return
|
||||
storage_sid = getattr(srv_session, "queen_resume_from", None) or srv_session.id
|
||||
meta_path = Path.home() / ".hive" / "queen" / "session" / storage_sid / "meta.json"
|
||||
try:
|
||||
existing = {}
|
||||
if meta_path.exists():
|
||||
existing = json.loads(meta_path.read_text(encoding="utf-8"))
|
||||
existing.update(updates)
|
||||
meta_path.write_text(json.dumps(existing), encoding="utf-8")
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def register_queen_lifecycle_tools(
|
||||
registry: ToolRegistry,
|
||||
session: Any = None,
|
||||
@@ -965,6 +994,7 @@ def register_queen_lifecycle_tools(
|
||||
# Switch to building phase
|
||||
if phase_state is not None:
|
||||
await phase_state.switch_to_building()
|
||||
_update_meta_json(session_manager, manager_session_id, {"phase": "building"})
|
||||
|
||||
result = json.loads(stop_result)
|
||||
result["phase"] = "building"
|
||||
@@ -1549,12 +1579,22 @@ def register_queen_lifecycle_tools(
|
||||
# Find edges where this leaf node is the source
|
||||
out_edges = [e for e in validated_edges if e["source"] == leaf_id]
|
||||
in_edges = [e for e in validated_edges if e["target"] == leaf_id]
|
||||
if not out_edges:
|
||||
continue # already a proper leaf
|
||||
|
||||
# Identify the parent (predecessor that connects IN)
|
||||
parent_ids = [e["source"] for e in in_edges]
|
||||
|
||||
if not out_edges:
|
||||
# Already a proper leaf — still ensure sub_agents is set
|
||||
for pid in parent_ids:
|
||||
parent = node_by_id_v.get(pid)
|
||||
if parent is None:
|
||||
continue
|
||||
existing = parent.get("sub_agents") or []
|
||||
if leaf_id not in existing:
|
||||
existing.append(leaf_id)
|
||||
parent["sub_agents"] = existing
|
||||
continue
|
||||
|
||||
# Strip all outgoing edges from the leaf node that
|
||||
# don't go back to a parent (report edges are OK)
|
||||
illegal_targets: list[str] = []
|
||||
@@ -1968,6 +2008,17 @@ def register_queen_lifecycle_tools(
|
||||
"type": "string",
|
||||
"description": "What success looks like for this node",
|
||||
},
|
||||
"sub_agents": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"IDs of GCU/browser sub-agent nodes managed by this node. "
|
||||
"At build time, sub-agent nodes are dissolved into this list. "
|
||||
"Set this on the PARENT node — e.g. the orchestrator that "
|
||||
"delegates to GCU leaves. Visual delegation edges are "
|
||||
"synthesized automatically."
|
||||
),
|
||||
},
|
||||
"decision_clause": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
@@ -2085,8 +2136,22 @@ def register_queen_lifecycle_tools(
|
||||
phase_state.draft_graph = converted
|
||||
phase_state.flowchart_map = fmap
|
||||
|
||||
# Note: flowchart file is persisted later, in initialize_and_build_agent
|
||||
# (after the agent folder is scaffolded) or in load_built_agent.
|
||||
# Create agent folder early so flowchart and agent_path are available
|
||||
# throughout the entire BUILDING phase.
|
||||
_agent_name = phase_state.draft_graph.get("agent_name", "").strip()
|
||||
if _agent_name:
|
||||
_agent_folder = Path("exports") / _agent_name
|
||||
_agent_folder.mkdir(parents=True, exist_ok=True)
|
||||
_save_flowchart_file(_agent_folder, original_copy, fmap)
|
||||
phase_state.agent_path = str(_agent_folder)
|
||||
_update_meta_json(
|
||||
session_manager,
|
||||
manager_session_id,
|
||||
{
|
||||
"agent_path": str(_agent_folder),
|
||||
"agent_name": _agent_name.replace("_", " ").title(),
|
||||
},
|
||||
)
|
||||
|
||||
dissolved_count = len(original_nodes) - len(converted.get("nodes", []))
|
||||
decision_count = sum(1 for n in original_nodes if n.get("flowchart_type") == "decision")
|
||||
@@ -2218,6 +2283,7 @@ def register_queen_lifecycle_tools(
|
||||
if fallback_path:
|
||||
phase_state.agent_path = str(fallback_path)
|
||||
await phase_state.switch_to_building(source="tool")
|
||||
_update_meta_json(session_manager, manager_session_id, {"phase": "building"})
|
||||
if phase_state.inject_notification:
|
||||
await phase_state.inject_notification(
|
||||
"[PHASE CHANGE] Switched to BUILDING phase. "
|
||||
@@ -2260,8 +2326,13 @@ def register_queen_lifecycle_tools(
|
||||
if parsed.get("success", True):
|
||||
if phase_state is not None:
|
||||
# Set agent_path so the frontend can query credentials
|
||||
phase_state.agent_path = str(Path("exports") / agent_name)
|
||||
phase_state.agent_path = phase_state.agent_path or str(
|
||||
Path("exports") / agent_name
|
||||
)
|
||||
await phase_state.switch_to_building(source="tool")
|
||||
_update_meta_json(
|
||||
session_manager, manager_session_id, {"phase": "building"}
|
||||
)
|
||||
# Reset draft state after successful scaffolding
|
||||
phase_state.build_confirmed = False
|
||||
# Persist flowchart now that the agent folder exists
|
||||
@@ -2309,6 +2380,7 @@ def register_queen_lifecycle_tools(
|
||||
# Switch to staging phase
|
||||
if phase_state is not None:
|
||||
await phase_state.switch_to_staging()
|
||||
_update_meta_json(session_manager, manager_session_id, {"phase": "staging"})
|
||||
|
||||
result = json.loads(stop_result)
|
||||
result["phase"] = "staging"
|
||||
@@ -2337,6 +2409,30 @@ def register_queen_lifecycle_tools(
|
||||
"""Get the session's event bus for querying history."""
|
||||
return getattr(session, "event_bus", None)
|
||||
|
||||
def _get_worker_name() -> str | None:
|
||||
"""Return the worker agent directory name, used for diary lookups."""
|
||||
p = getattr(session, "worker_path", None)
|
||||
return p.name if p else None
|
||||
|
||||
def _format_diary(max_runs: int) -> str:
|
||||
"""Read recent run digests from disk — no EventBus required."""
|
||||
agent_name = _get_worker_name()
|
||||
if not agent_name:
|
||||
return "No worker loaded — diary unavailable."
|
||||
from framework.agents.worker_memory import read_recent_digests
|
||||
|
||||
entries = read_recent_digests(agent_name, max_runs)
|
||||
if not entries:
|
||||
return (
|
||||
f"No run digests for '{agent_name}' yet. "
|
||||
"Digests are written at the end of each completed run."
|
||||
)
|
||||
lines = [f"Worker '{agent_name}' — {len(entries)} recent run digest(s):", ""]
|
||||
for _run_id, content in entries:
|
||||
lines.append(content)
|
||||
lines.append("")
|
||||
return "\n".join(lines).rstrip()
|
||||
|
||||
# Tiered cooldowns: summary is free, detail has short cooldown, full keeps 30s
|
||||
_COOLDOWN_FULL = 30.0
|
||||
_COOLDOWN_DETAIL = 10.0
|
||||
@@ -2939,16 +3035,17 @@ def register_queen_lifecycle_tools(
|
||||
import time as _time
|
||||
|
||||
# --- Tiered cooldown ---
|
||||
# diary is free (file reads only), summary is free, detail has 10s, full has 30s
|
||||
now = _time.monotonic()
|
||||
if focus == "full":
|
||||
cooldown = _COOLDOWN_FULL
|
||||
tier = "full"
|
||||
elif focus is not None:
|
||||
elif focus == "diary" or focus is None:
|
||||
cooldown = 0.0
|
||||
tier = focus or "summary"
|
||||
else:
|
||||
cooldown = _COOLDOWN_DETAIL
|
||||
tier = "detail"
|
||||
else:
|
||||
cooldown = 0.0
|
||||
tier = "summary"
|
||||
|
||||
elapsed_since = now - _status_last_called.get(tier, 0.0)
|
||||
if elapsed_since < cooldown:
|
||||
@@ -2964,6 +3061,10 @@ def register_queen_lifecycle_tools(
|
||||
)
|
||||
_status_last_called[tier] = now
|
||||
|
||||
# --- Diary: pure file reads, no runtime required ---
|
||||
if focus == "diary":
|
||||
return _format_diary(last_n)
|
||||
|
||||
# --- Runtime check ---
|
||||
runtime = _get_runtime()
|
||||
if runtime is None:
|
||||
@@ -3013,7 +3114,7 @@ def register_queen_lifecycle_tools(
|
||||
else:
|
||||
return (
|
||||
f"Unknown focus '{focus}'. "
|
||||
"Valid options: activity, memory, tools, issues, progress, full."
|
||||
"Valid options: diary, activity, memory, tools, issues, progress, full."
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("get_worker_status error")
|
||||
@@ -3024,6 +3125,8 @@ def register_queen_lifecycle_tools(
|
||||
description=(
|
||||
"Check on the worker. Returns a brief prose summary by default. "
|
||||
"Use 'focus' to drill into specifics:\n"
|
||||
"- diary: persistent run digests from past executions — read this first "
|
||||
"before digging into live runtime logs\n"
|
||||
"- activity: current node, transitions, latest LLM output\n"
|
||||
"- memory: worker's accumulated knowledge and state\n"
|
||||
"- tools: running and recent tool calls\n"
|
||||
@@ -3036,8 +3139,11 @@ def register_queen_lifecycle_tools(
|
||||
"properties": {
|
||||
"focus": {
|
||||
"type": "string",
|
||||
"enum": ["activity", "memory", "tools", "issues", "progress", "full"],
|
||||
"description": ("Aspect to inspect. Omit for a brief summary."),
|
||||
"enum": ["diary", "activity", "memory", "tools", "issues", "progress", "full"],
|
||||
"description": (
|
||||
"Aspect to inspect. Omit for a brief summary. "
|
||||
"Use 'diary' to read persistent run history before checking live logs."
|
||||
),
|
||||
},
|
||||
"last_n": {
|
||||
"type": "integer",
|
||||
@@ -3436,6 +3542,7 @@ def register_queen_lifecycle_tools(
|
||||
if phase_state is not None:
|
||||
phase_state.agent_path = str(resolved_path)
|
||||
await phase_state.switch_to_staging()
|
||||
_update_meta_json(session_manager, manager_session_id, {"phase": "staging"})
|
||||
|
||||
worker_name = info.name if info else updated_session.worker_id
|
||||
return json.dumps(
|
||||
@@ -3555,6 +3662,7 @@ def register_queen_lifecycle_tools(
|
||||
# Switch to running phase
|
||||
if phase_state is not None:
|
||||
await phase_state.switch_to_running()
|
||||
_update_meta_json(session_manager, manager_session_id, {"phase": "running"})
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
|
||||
@@ -27,7 +27,14 @@ export default function MyAgents() {
|
||||
agentsApi
|
||||
.discover()
|
||||
.then((result) => {
|
||||
setAgents(result["Your Agents"] || []);
|
||||
const entries = result["Your Agents"] || [];
|
||||
entries.sort((a, b) => {
|
||||
if (!a.last_active && !b.last_active) return 0;
|
||||
if (!a.last_active) return 1;
|
||||
if (!b.last_active) return -1;
|
||||
return b.last_active.localeCompare(a.last_active);
|
||||
});
|
||||
setAgents(entries);
|
||||
})
|
||||
.catch((err) => {
|
||||
setError(err.message || "Failed to load agents");
|
||||
|
||||
@@ -252,6 +252,10 @@ function truncate(s: string, max: number): string {
|
||||
type SessionRestoreResult = {
|
||||
messages: ChatMessage[];
|
||||
restoredPhase: "planning" | "building" | "staging" | "running" | null;
|
||||
/** Last flowchart map from events — used to restore flowchart overlay on cold resume. */
|
||||
flowchartMap: Record<string, string[]> | null;
|
||||
/** Last original draft from events — used to restore flowchart overlay on cold resume. */
|
||||
originalDraft: DraftGraphData | null;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -268,6 +272,8 @@ async function restoreSessionMessages(
|
||||
if (events.length > 0) {
|
||||
const messages: ChatMessage[] = [];
|
||||
let runningPhase: ChatMessage["phase"] = undefined;
|
||||
let flowchartMap: Record<string, string[]> | null = null;
|
||||
let originalDraft: DraftGraphData | null = null;
|
||||
for (const evt of events) {
|
||||
// Track phase transitions so each message gets the phase it was created in
|
||||
const p = evt.type === "queen_phase_changed" ? evt.data?.phase as string
|
||||
@@ -276,6 +282,12 @@ async function restoreSessionMessages(
|
||||
if (p && ["planning", "building", "staging", "running"].includes(p)) {
|
||||
runningPhase = p as ChatMessage["phase"];
|
||||
}
|
||||
// Track last flowchart state for cold restore
|
||||
if (evt.type === "flowchart_map_updated" && evt.data) {
|
||||
const mapData = evt.data as { map?: Record<string, string[]>; original_draft?: DraftGraphData };
|
||||
flowchartMap = mapData.map ?? null;
|
||||
originalDraft = mapData.original_draft ?? null;
|
||||
}
|
||||
const msg = sseEventToChatMessage(evt, thread, agentDisplayName);
|
||||
if (!msg) continue;
|
||||
if (evt.stream_id === "queen") {
|
||||
@@ -284,12 +296,12 @@ async function restoreSessionMessages(
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
return { messages, restoredPhase: runningPhase ?? null };
|
||||
return { messages, restoredPhase: runningPhase ?? null, flowchartMap, originalDraft };
|
||||
}
|
||||
} catch {
|
||||
// Event log not available — session will start fresh.
|
||||
}
|
||||
return { messages: [], restoredPhase: null };
|
||||
return { messages: [], restoredPhase: null, flowchartMap: null, originalDraft: null };
|
||||
}
|
||||
|
||||
// --- Per-agent backend state (consolidated) ---
|
||||
@@ -799,6 +811,8 @@ export default function Workspace() {
|
||||
}
|
||||
|
||||
let restoredPhase: "planning" | "building" | "staging" | "running" | null = null;
|
||||
let restoredFlowchartMap: Record<string, string[]> | null = null;
|
||||
let restoredOriginalDraft: DraftGraphData | null = null;
|
||||
if (!liveSession) {
|
||||
// Fetch conversation history from disk BEFORE creating the new session.
|
||||
// SKIP if messages were already pre-populated by handleHistoryOpen.
|
||||
@@ -810,9 +824,22 @@ export default function Workspace() {
|
||||
const restored = await restoreSessionMessages(restoreFrom, agentType, "Queen Bee");
|
||||
preRestoredMsgs.push(...restored.messages);
|
||||
restoredPhase = restored.restoredPhase;
|
||||
restoredFlowchartMap = restored.flowchartMap;
|
||||
restoredOriginalDraft = restored.originalDraft;
|
||||
} catch {
|
||||
// Not available — will start fresh
|
||||
}
|
||||
} else if (restoreFrom && alreadyHasMessages) {
|
||||
// Messages already cached in localStorage — still fetch events for
|
||||
// non-message state (phase, flowchart) that isn't cached.
|
||||
try {
|
||||
const restored = await restoreSessionMessages(restoreFrom, agentType, "Queen Bee");
|
||||
restoredPhase = restored.restoredPhase;
|
||||
restoredFlowchartMap = restored.flowchartMap;
|
||||
restoredOriginalDraft = restored.originalDraft;
|
||||
} catch {
|
||||
// Not critical — UI will still show cached messages
|
||||
}
|
||||
}
|
||||
|
||||
// Suppress the queen's intro cycle whenever we are about to restore a
|
||||
@@ -835,7 +862,7 @@ export default function Workspace() {
|
||||
}));
|
||||
}
|
||||
restoredMessageCount = preRestoredMsgs.length;
|
||||
} else if (restoreFrom && activeId) {
|
||||
} else if (restoreFrom && activeId && !alreadyHasMessages) {
|
||||
// We had a stored session but no messages on disk — wipe stale localStorage cache
|
||||
setSessionsByAgent(prev => ({
|
||||
...prev,
|
||||
@@ -889,6 +916,9 @@ export default function Workspace() {
|
||||
queenReady: true,
|
||||
queenPhase: qPhase,
|
||||
queenBuilding: qPhase === "building",
|
||||
// Restore flowchart overlay from persisted events
|
||||
...(restoredFlowchartMap ? { flowchartMap: restoredFlowchartMap } : {}),
|
||||
...(restoredOriginalDraft ? { originalDraft: restoredOriginalDraft, draftGraph: null } : {}),
|
||||
});
|
||||
} catch (err: unknown) {
|
||||
const msg = err instanceof Error ? err.message : String(err);
|
||||
@@ -963,6 +993,8 @@ export default function Workspace() {
|
||||
|
||||
// Track the last queen phase seen in the event log for cold restore
|
||||
let restoredPhase: "planning" | "building" | "staging" | "running" | null = null;
|
||||
let restoredFlowchartMap: Record<string, string[]> | null = null;
|
||||
let restoredOriginalDraft: DraftGraphData | null = null;
|
||||
|
||||
if (!liveSession) {
|
||||
// Reconnect failed — clear stale cached messages from localStorage restore.
|
||||
@@ -990,6 +1022,19 @@ export default function Workspace() {
|
||||
const restored = await restoreSessionMessages(coldRestoreId, agentType, displayNameTemp);
|
||||
preQueenMsgs = restored.messages;
|
||||
restoredPhase = restored.restoredPhase;
|
||||
restoredFlowchartMap = restored.flowchartMap;
|
||||
restoredOriginalDraft = restored.originalDraft;
|
||||
} else if (coldRestoreId && alreadyHasMessages) {
|
||||
// Messages already cached — still fetch events for non-message state (phase, flowchart)
|
||||
try {
|
||||
const displayNameTemp = formatAgentDisplayName(agentPath);
|
||||
const restored = await restoreSessionMessages(coldRestoreId, agentType, displayNameTemp);
|
||||
restoredPhase = restored.restoredPhase;
|
||||
restoredFlowchartMap = restored.flowchartMap;
|
||||
restoredOriginalDraft = restored.originalDraft;
|
||||
} catch {
|
||||
// Not critical — UI will still show cached messages
|
||||
}
|
||||
}
|
||||
|
||||
// Suppress intro whenever we are about to restore a previous conversation.
|
||||
@@ -1070,6 +1115,9 @@ export default function Workspace() {
|
||||
displayName,
|
||||
queenPhase: initialPhase,
|
||||
queenBuilding: initialPhase === "building",
|
||||
// Restore flowchart overlay from persisted events
|
||||
...(restoredFlowchartMap ? { flowchartMap: restoredFlowchartMap } : {}),
|
||||
...(restoredOriginalDraft ? { originalDraft: restoredOriginalDraft, draftGraph: null } : {}),
|
||||
});
|
||||
|
||||
// Update the session label + backendSessionId. Also set historySourceId
|
||||
@@ -1107,6 +1155,11 @@ export default function Workspace() {
|
||||
if (historyId && !coldRestoreId) {
|
||||
const restored = await restoreSessionMessages(historyId, agentType, displayName);
|
||||
restoredMsgs.push(...restored.messages);
|
||||
// Use flowchart from event log if not already set
|
||||
if (restored.flowchartMap && !restoredFlowchartMap) {
|
||||
restoredFlowchartMap = restored.flowchartMap;
|
||||
restoredOriginalDraft = restored.originalDraft;
|
||||
}
|
||||
|
||||
// Check worker status (needed for isWorkerRunning flag)
|
||||
try {
|
||||
@@ -1149,6 +1202,9 @@ export default function Workspace() {
|
||||
loading: false,
|
||||
queenReady: !!(isResumedSession || hasRestoredContent),
|
||||
...(isWorkerRunning ? { workerRunState: "running" } : {}),
|
||||
// Restore flowchart overlay from persisted events
|
||||
...(restoredFlowchartMap ? { flowchartMap: restoredFlowchartMap } : {}),
|
||||
...(restoredOriginalDraft ? { originalDraft: restoredOriginalDraft, draftGraph: null } : {}),
|
||||
});
|
||||
} catch (err: unknown) {
|
||||
const msg = err instanceof Error ? err.message : String(err);
|
||||
|
||||
@@ -33,6 +33,7 @@ API_KEY_PROVIDERS = [
|
||||
("TOGETHER_API_KEY", "Together AI", "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo"),
|
||||
("DEEPSEEK_API_KEY", "DeepSeek", "deepseek-chat"),
|
||||
("MINIMAX_API_KEY", "MiniMax", "MiniMax-M2.5"),
|
||||
("HIVE_API_KEY", "Hive LLM", "hive/queen"),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,209 @@
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _load_check_llm_key_module():
|
||||
module_path = Path(__file__).resolve().parents[2] / "scripts" / "check_llm_key.py"
|
||||
spec = importlib.util.spec_from_file_location("check_llm_key_script", module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def _run_openrouter_check(monkeypatch, status_code: int):
|
||||
module = _load_check_llm_key_module()
|
||||
calls = {}
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, code):
|
||||
self.status_code = code
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self, timeout):
|
||||
calls["timeout"] = timeout
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def get(self, endpoint, headers):
|
||||
calls["endpoint"] = endpoint
|
||||
calls["headers"] = headers
|
||||
return FakeResponse(status_code)
|
||||
|
||||
monkeypatch.setattr(module.httpx, "Client", FakeClient)
|
||||
result = module.check_openrouter("test-key")
|
||||
return result, calls
|
||||
|
||||
|
||||
def _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
status_code: int,
|
||||
payload: dict | None = None,
|
||||
model: str = "openai/gpt-4o-mini",
|
||||
):
|
||||
module = _load_check_llm_key_module()
|
||||
calls = {}
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, code):
|
||||
self.status_code = code
|
||||
self._payload = payload
|
||||
self.text = ""
|
||||
|
||||
def json(self):
|
||||
if self._payload is None:
|
||||
raise ValueError("no json")
|
||||
return self._payload
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self, timeout):
|
||||
calls["timeout"] = timeout
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def get(self, endpoint, headers):
|
||||
calls["endpoint"] = endpoint
|
||||
calls["headers"] = headers
|
||||
return FakeResponse(status_code)
|
||||
|
||||
monkeypatch.setattr(module.httpx, "Client", FakeClient)
|
||||
result = module.check_openrouter_model("test-key", model)
|
||||
return result, calls
|
||||
|
||||
|
||||
def test_check_openrouter_200(monkeypatch):
|
||||
result, calls = _run_openrouter_check(monkeypatch, 200)
|
||||
assert result == {"valid": True, "message": "OpenRouter API key valid"}
|
||||
assert calls["endpoint"] == "https://openrouter.ai/api/v1/models"
|
||||
assert calls["headers"] == {"Authorization": "Bearer test-key"}
|
||||
|
||||
|
||||
def test_check_openrouter_401(monkeypatch):
|
||||
result, _ = _run_openrouter_check(monkeypatch, 401)
|
||||
assert result == {"valid": False, "message": "Invalid OpenRouter API key"}
|
||||
|
||||
|
||||
def test_check_openrouter_403(monkeypatch):
|
||||
result, _ = _run_openrouter_check(monkeypatch, 403)
|
||||
assert result == {"valid": False, "message": "OpenRouter API key lacks permissions"}
|
||||
|
||||
|
||||
def test_check_openrouter_429(monkeypatch):
|
||||
result, _ = _run_openrouter_check(monkeypatch, 429)
|
||||
assert result == {"valid": True, "message": "OpenRouter API key valid"}
|
||||
|
||||
|
||||
def test_check_openrouter_model_200(monkeypatch):
|
||||
result, calls = _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
200,
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"id": "openai/gpt-4o-mini",
|
||||
"canonical_slug": "openai/gpt-4o-mini",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
assert result == {
|
||||
"valid": True,
|
||||
"message": "OpenRouter model is available: openai/gpt-4o-mini",
|
||||
"model": "openai/gpt-4o-mini",
|
||||
}
|
||||
assert calls["endpoint"] == "https://openrouter.ai/api/v1/models/user"
|
||||
assert calls["headers"] == {"Authorization": "Bearer test-key"}
|
||||
|
||||
|
||||
def test_check_openrouter_model_200_matches_canonical_slug(monkeypatch):
|
||||
result, _ = _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
200,
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"id": "mistralai/mistral-small-4",
|
||||
"canonical_slug": "mistralai/mistral-small-2603",
|
||||
}
|
||||
]
|
||||
},
|
||||
model="mistralai/mistral-small-2603",
|
||||
)
|
||||
assert result == {
|
||||
"valid": True,
|
||||
"message": "OpenRouter model is available: mistralai/mistral-small-2603",
|
||||
"model": "mistralai/mistral-small-2603",
|
||||
}
|
||||
|
||||
|
||||
def test_check_openrouter_model_200_sanitizes_pasted_unicode(monkeypatch):
|
||||
result, _ = _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
200,
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"id": "z-ai/glm-5-turbo",
|
||||
"canonical_slug": "z-ai/glm-5-turbo",
|
||||
}
|
||||
]
|
||||
},
|
||||
model="openrouter/z-ai\u200b/glm\u20115\u2011turbo",
|
||||
)
|
||||
assert result == {
|
||||
"valid": True,
|
||||
"message": "OpenRouter model is available: z-ai/glm-5-turbo",
|
||||
"model": "z-ai/glm-5-turbo",
|
||||
}
|
||||
|
||||
|
||||
def test_check_openrouter_model_200_not_found_with_suggestions(monkeypatch):
|
||||
result, _ = _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
200,
|
||||
{
|
||||
"data": [
|
||||
{"id": "z-ai/glm-5-turbo"},
|
||||
{"id": "z-ai/glm-4.6v"},
|
||||
]
|
||||
},
|
||||
model="z-ai/glm-5-turb",
|
||||
)
|
||||
assert result == {
|
||||
"valid": False,
|
||||
"message": (
|
||||
"OpenRouter model is not available for this key/settings: z-ai/glm-5-turb. "
|
||||
"Closest matches: z-ai/glm-5-turbo"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def test_check_openrouter_model_404_with_error_message(monkeypatch):
|
||||
result, _ = _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
404,
|
||||
{"error": {"message": "No endpoints available for this model"}},
|
||||
)
|
||||
assert result == {
|
||||
"valid": False,
|
||||
"message": (
|
||||
"OpenRouter model is not available for this key/settings: openai/gpt-4o-mini. "
|
||||
"No endpoints available for this model"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def test_check_openrouter_model_429(monkeypatch):
|
||||
result, _ = _run_openrouter_model_check(monkeypatch, 429)
|
||||
assert result == {
|
||||
"valid": True,
|
||||
"message": "OpenRouter model check rate-limited; assuming model is reachable",
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import logging
|
||||
|
||||
from framework.config import get_hive_config
|
||||
from framework.config import get_api_base, get_hive_config, get_preferred_model
|
||||
|
||||
|
||||
class TestGetHiveConfig:
|
||||
@@ -21,3 +21,47 @@ class TestGetHiveConfig:
|
||||
assert result == {}
|
||||
assert "Failed to load Hive config" in caplog.text
|
||||
assert str(config_file) in caplog.text
|
||||
|
||||
|
||||
class TestOpenRouterConfig:
|
||||
"""OpenRouter config composition and fallback behavior."""
|
||||
|
||||
def test_get_preferred_model_for_openrouter(self, tmp_path, monkeypatch):
|
||||
config_file = tmp_path / "configuration.json"
|
||||
config_file.write_text(
|
||||
'{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta"}}',
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
|
||||
|
||||
assert get_preferred_model() == "openrouter/x-ai/grok-4.20-beta"
|
||||
|
||||
def test_get_preferred_model_normalizes_openrouter_prefixed_model(self, tmp_path, monkeypatch):
|
||||
config_file = tmp_path / "configuration.json"
|
||||
config_file.write_text(
|
||||
'{"llm":{"provider":"openrouter","model":"openrouter/x-ai/grok-4.20-beta"}}',
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
|
||||
|
||||
assert get_preferred_model() == "openrouter/x-ai/grok-4.20-beta"
|
||||
|
||||
def test_get_api_base_falls_back_to_openrouter_default(self, tmp_path, monkeypatch):
|
||||
config_file = tmp_path / "configuration.json"
|
||||
config_file.write_text(
|
||||
'{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta"}}',
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
|
||||
|
||||
assert get_api_base() == "https://openrouter.ai/api/v1"
|
||||
|
||||
def test_get_api_base_keeps_explicit_openrouter_api_base(self, tmp_path, monkeypatch):
|
||||
config_file = tmp_path / "configuration.json"
|
||||
config_file.write_text(
|
||||
'{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta","api_base":"https://proxy.example/v1"}}',
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
|
||||
|
||||
assert get_api_base() == "https://proxy.example/v1"
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
import os
|
||||
import sys
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
from framework.credentials import key_storage
|
||||
from framework.credentials.validation import ensure_credential_key_env
|
||||
|
||||
|
||||
def _install_fake_aden_modules(monkeypatch, check_fn, credential_specs):
|
||||
shell_config_module = ModuleType("aden_tools.credentials.shell_config")
|
||||
shell_config_module.check_env_var_in_shell_config = check_fn
|
||||
|
||||
credentials_module = ModuleType("aden_tools.credentials")
|
||||
credentials_module.CREDENTIAL_SPECS = credential_specs
|
||||
|
||||
monkeypatch.setitem(sys.modules, "aden_tools.credentials.shell_config", shell_config_module)
|
||||
monkeypatch.setitem(sys.modules, "aden_tools.credentials", credentials_module)
|
||||
|
||||
|
||||
def test_bootstrap_loads_configured_llm_env_var_from_shell_config(monkeypatch):
|
||||
monkeypatch.setattr(key_storage, "load_credential_key", lambda: None)
|
||||
monkeypatch.setattr(key_storage, "load_aden_api_key", lambda: None)
|
||||
monkeypatch.setattr(
|
||||
"framework.config.get_hive_config",
|
||||
lambda: {"llm": {"api_key_env_var": "OPENROUTER_API_KEY"}},
|
||||
)
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
|
||||
calls = []
|
||||
|
||||
def check_env(var_name):
|
||||
calls.append(var_name)
|
||||
if var_name == "OPENROUTER_API_KEY":
|
||||
return True, "or-key-123"
|
||||
return False, None
|
||||
|
||||
_install_fake_aden_modules(
|
||||
monkeypatch,
|
||||
check_env,
|
||||
{"anthropic": SimpleNamespace(env_var="ANTHROPIC_API_KEY")},
|
||||
)
|
||||
|
||||
ensure_credential_key_env()
|
||||
|
||||
assert os.environ.get("OPENROUTER_API_KEY") == "or-key-123"
|
||||
assert "OPENROUTER_API_KEY" in calls
|
||||
|
||||
|
||||
def test_bootstrap_does_not_override_existing_configured_llm_env_var(monkeypatch):
|
||||
monkeypatch.setattr(key_storage, "load_credential_key", lambda: None)
|
||||
monkeypatch.setattr(key_storage, "load_aden_api_key", lambda: None)
|
||||
monkeypatch.setattr(
|
||||
"framework.config.get_hive_config",
|
||||
lambda: {"llm": {"api_key_env_var": "OPENROUTER_API_KEY"}},
|
||||
)
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "already-set")
|
||||
|
||||
calls = []
|
||||
|
||||
def check_env(var_name):
|
||||
calls.append(var_name)
|
||||
return True, "new-value-should-not-apply"
|
||||
|
||||
_install_fake_aden_modules(monkeypatch, check_env, {})
|
||||
|
||||
ensure_credential_key_env()
|
||||
|
||||
assert os.environ.get("OPENROUTER_API_KEY") == "already-set"
|
||||
assert "OPENROUTER_API_KEY" not in calls
|
||||
@@ -0,0 +1,188 @@
|
||||
"""Tests for default skills — parsing, token budget, and configuration."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.skills.config import DefaultSkillConfig, SkillsConfig
|
||||
from framework.skills.defaults import (
|
||||
SHARED_MEMORY_KEYS,
|
||||
SKILL_REGISTRY,
|
||||
DefaultSkillManager,
|
||||
)
|
||||
from framework.skills.parser import parse_skill_md
|
||||
|
||||
_DEFAULT_SKILLS_DIR = (
|
||||
Path(__file__).resolve().parent.parent / "framework" / "skills" / "_default_skills"
|
||||
)
|
||||
|
||||
|
||||
class TestDefaultSkillFiles:
|
||||
"""Verify all 6 built-in SKILL.md files parse correctly."""
|
||||
|
||||
def test_all_six_skills_exist(self):
|
||||
assert len(SKILL_REGISTRY) == 6
|
||||
|
||||
@pytest.mark.parametrize("skill_name,dir_name", list(SKILL_REGISTRY.items()))
|
||||
def test_skill_parses(self, skill_name, dir_name):
|
||||
path = _DEFAULT_SKILLS_DIR / dir_name / "SKILL.md"
|
||||
assert path.is_file(), f"Missing SKILL.md at {path}"
|
||||
|
||||
parsed = parse_skill_md(path, source_scope="framework")
|
||||
assert parsed is not None, f"Failed to parse {path}"
|
||||
assert parsed.name == skill_name
|
||||
assert parsed.description
|
||||
assert parsed.body
|
||||
assert parsed.source_scope == "framework"
|
||||
|
||||
def test_combined_token_budget(self):
|
||||
"""All default skill bodies combined should be under 2000 tokens (~8000 chars)."""
|
||||
total_chars = 0
|
||||
for dir_name in SKILL_REGISTRY.values():
|
||||
path = _DEFAULT_SKILLS_DIR / dir_name / "SKILL.md"
|
||||
parsed = parse_skill_md(path, source_scope="framework")
|
||||
assert parsed is not None
|
||||
total_chars += len(parsed.body)
|
||||
|
||||
approx_tokens = total_chars // 4
|
||||
assert approx_tokens < 2000, (
|
||||
f"Combined default skill bodies are ~{approx_tokens} tokens "
|
||||
f"({total_chars} chars), exceeding the 2000 token budget"
|
||||
)
|
||||
|
||||
def test_shared_memory_keys_all_prefixed(self):
|
||||
"""All shared memory keys must start with underscore."""
|
||||
for key in SHARED_MEMORY_KEYS:
|
||||
assert key.startswith("_"), f"Shared memory key missing _ prefix: {key}"
|
||||
|
||||
|
||||
class TestDefaultSkillManager:
|
||||
def test_load_all_defaults(self):
|
||||
manager = DefaultSkillManager()
|
||||
manager.load()
|
||||
|
||||
assert len(manager.active_skill_names) == 6
|
||||
for name in SKILL_REGISTRY:
|
||||
assert name in manager.active_skill_names
|
||||
|
||||
def test_load_idempotent(self):
|
||||
manager = DefaultSkillManager()
|
||||
manager.load()
|
||||
first_skills = dict(manager.active_skills)
|
||||
manager.load()
|
||||
assert manager.active_skills == first_skills
|
||||
|
||||
def test_build_protocols_prompt(self):
|
||||
manager = DefaultSkillManager()
|
||||
manager.load()
|
||||
prompt = manager.build_protocols_prompt()
|
||||
|
||||
assert prompt.startswith("## Operational Protocols")
|
||||
# Should contain content from each active skill
|
||||
for name in SKILL_REGISTRY:
|
||||
skill = manager.active_skills[name]
|
||||
# At least some of the body should appear
|
||||
assert skill.body[:20] in prompt
|
||||
|
||||
def test_protocols_prompt_empty_when_all_disabled(self):
|
||||
config = SkillsConfig(all_defaults_disabled=True)
|
||||
manager = DefaultSkillManager(config)
|
||||
manager.load()
|
||||
|
||||
assert manager.build_protocols_prompt() == ""
|
||||
assert manager.active_skill_names == []
|
||||
|
||||
def test_disable_single_skill(self):
|
||||
config = SkillsConfig.from_agent_vars(
|
||||
default_skills={"hive.quality-monitor": {"enabled": False}}
|
||||
)
|
||||
manager = DefaultSkillManager(config)
|
||||
manager.load()
|
||||
|
||||
assert "hive.quality-monitor" not in manager.active_skill_names
|
||||
assert len(manager.active_skill_names) == 5
|
||||
|
||||
def test_disable_all_via_convention(self):
|
||||
config = SkillsConfig.from_agent_vars(default_skills={"_all": {"enabled": False}})
|
||||
manager = DefaultSkillManager(config)
|
||||
manager.load()
|
||||
|
||||
assert manager.active_skill_names == []
|
||||
|
||||
def test_log_active_skills(self, caplog):
|
||||
import logging
|
||||
|
||||
with caplog.at_level(logging.INFO, logger="framework.skills.defaults"):
|
||||
manager = DefaultSkillManager()
|
||||
manager.load()
|
||||
manager.log_active_skills()
|
||||
|
||||
assert "Default skills active:" in caplog.text
|
||||
|
||||
def test_log_all_disabled(self, caplog):
|
||||
import logging
|
||||
|
||||
config = SkillsConfig(all_defaults_disabled=True)
|
||||
with caplog.at_level(logging.INFO, logger="framework.skills.defaults"):
|
||||
manager = DefaultSkillManager(config)
|
||||
manager.load()
|
||||
manager.log_active_skills()
|
||||
|
||||
assert "all disabled" in caplog.text
|
||||
|
||||
|
||||
class TestSkillsConfig:
|
||||
def test_default_is_enabled(self):
|
||||
config = SkillsConfig()
|
||||
assert config.is_default_enabled("hive.note-taking") is True
|
||||
|
||||
def test_explicit_disable(self):
|
||||
config = SkillsConfig(
|
||||
default_skills={"hive.note-taking": DefaultSkillConfig(enabled=False)}
|
||||
)
|
||||
assert config.is_default_enabled("hive.note-taking") is False
|
||||
assert config.is_default_enabled("hive.batch-ledger") is True
|
||||
|
||||
def test_all_disabled_flag(self):
|
||||
config = SkillsConfig(all_defaults_disabled=True)
|
||||
assert config.is_default_enabled("hive.note-taking") is False
|
||||
assert config.is_default_enabled("anything") is False
|
||||
|
||||
def test_from_agent_vars_basic(self):
|
||||
config = SkillsConfig.from_agent_vars(
|
||||
default_skills={
|
||||
"hive.note-taking": {"enabled": True},
|
||||
"hive.quality-monitor": {"enabled": False},
|
||||
},
|
||||
skills=["deep-research"],
|
||||
)
|
||||
assert config.is_default_enabled("hive.note-taking") is True
|
||||
assert config.is_default_enabled("hive.quality-monitor") is False
|
||||
assert config.skills == ["deep-research"]
|
||||
|
||||
def test_from_agent_vars_bool_shorthand(self):
|
||||
config = SkillsConfig.from_agent_vars(default_skills={"hive.note-taking": False})
|
||||
assert config.is_default_enabled("hive.note-taking") is False
|
||||
|
||||
def test_from_agent_vars_all_disabled(self):
|
||||
config = SkillsConfig.from_agent_vars(default_skills={"_all": {"enabled": False}})
|
||||
assert config.all_defaults_disabled is True
|
||||
|
||||
def test_get_default_overrides(self):
|
||||
config = SkillsConfig.from_agent_vars(
|
||||
default_skills={
|
||||
"hive.batch-ledger": {"enabled": True, "checkpoint_every_n": 10},
|
||||
}
|
||||
)
|
||||
overrides = config.get_default_overrides("hive.batch-ledger")
|
||||
assert overrides == {"checkpoint_every_n": 10}
|
||||
|
||||
def test_get_default_overrides_empty(self):
|
||||
config = SkillsConfig()
|
||||
assert config.get_default_overrides("hive.note-taking") == {}
|
||||
|
||||
def test_from_agent_vars_none_inputs(self):
|
||||
config = SkillsConfig.from_agent_vars(default_skills=None, skills=None)
|
||||
assert config.skills == []
|
||||
assert config.default_skills == {}
|
||||
assert config.all_defaults_disabled is False
|
||||
@@ -1530,6 +1530,34 @@ class TestTransientErrorRetry:
|
||||
await node.execute(ctx)
|
||||
assert llm._call_index == 1 # only tried once
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_facing_non_transient_error_does_not_crash(
|
||||
self, runtime, node_spec, memory
|
||||
):
|
||||
"""Client-facing non-transient errors should wait for input, not crash on token vars."""
|
||||
node_spec.output_keys = []
|
||||
node_spec.client_facing = True
|
||||
llm = ErrorThenSuccessLLM(
|
||||
error=ValueError("bad request: blocked by policy"),
|
||||
fail_count=100, # always fails
|
||||
success_scenario=text_scenario("unreachable"),
|
||||
)
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(
|
||||
config=LoopConfig(
|
||||
max_iterations=1,
|
||||
max_stream_retries=0,
|
||||
stream_retry_backoff_base=0.01,
|
||||
),
|
||||
)
|
||||
node._await_user_input = AsyncMock(return_value=None)
|
||||
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is False
|
||||
assert "Max iterations" in (result.error or "")
|
||||
node._await_user_input.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transient_error_exhausts_retries(self, runtime, node_spec, memory):
|
||||
"""Transient errors that exhaust retries should raise."""
|
||||
|
||||
@@ -19,7 +19,11 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from framework.llm.anthropic import AnthropicProvider
|
||||
from framework.llm.litellm import LiteLLMProvider, _compute_retry_delay
|
||||
from framework.llm.litellm import (
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE,
|
||||
LiteLLMProvider,
|
||||
_compute_retry_delay,
|
||||
)
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
|
||||
|
||||
@@ -72,6 +76,20 @@ class TestLiteLLMProviderInit:
|
||||
)
|
||||
assert provider.api_base == "https://proxy.example/v1"
|
||||
|
||||
def test_init_openrouter_defaults_api_base(self):
|
||||
"""OpenRouter should default to the official OpenAI-compatible endpoint."""
|
||||
provider = LiteLLMProvider(model="openrouter/x-ai/grok-4.20-beta", api_key="my-key")
|
||||
assert provider.api_base == "https://openrouter.ai/api/v1"
|
||||
|
||||
def test_init_openrouter_keeps_custom_api_base(self):
|
||||
"""Explicit api_base should win over OpenRouter defaults."""
|
||||
provider = LiteLLMProvider(
|
||||
model="openrouter/x-ai/grok-4.20-beta",
|
||||
api_key="my-key",
|
||||
api_base="https://proxy.example/v1",
|
||||
)
|
||||
assert provider.api_base == "https://proxy.example/v1"
|
||||
|
||||
def test_init_ollama_no_key_needed(self):
|
||||
"""Test that Ollama models don't require API key."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
@@ -192,6 +210,34 @@ class TestToolConversion:
|
||||
assert result["function"]["parameters"]["properties"]["query"]["type"] == "string"
|
||||
assert result["function"]["parameters"]["required"] == ["query"]
|
||||
|
||||
def test_parse_tool_call_arguments_repairs_truncated_json(self):
|
||||
"""Truncated JSON fragments should be repaired into valid tool inputs."""
|
||||
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
||||
|
||||
parsed = provider._parse_tool_call_arguments(
|
||||
(
|
||||
'{"question":"What story structure should the agent use?",'
|
||||
'"options":["3-act structure","Beginning-Middle-End","Random paragraph"'
|
||||
),
|
||||
"ask_user",
|
||||
)
|
||||
|
||||
assert parsed == {
|
||||
"question": "What story structure should the agent use?",
|
||||
"options": [
|
||||
"3-act structure",
|
||||
"Beginning-Middle-End",
|
||||
"Random paragraph",
|
||||
],
|
||||
}
|
||||
|
||||
def test_parse_tool_call_arguments_raises_when_unrepairable(self):
|
||||
"""Completely invalid JSON should fail fast instead of producing _raw loops."""
|
||||
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to parse tool call arguments"):
|
||||
provider._parse_tool_call_arguments('{"question": foo', "ask_user")
|
||||
|
||||
|
||||
class TestAnthropicProviderBackwardCompatibility:
|
||||
"""Test AnthropicProvider backward compatibility with LiteLLM backend."""
|
||||
@@ -682,6 +728,315 @@ class TestMiniMaxStreamFallback:
|
||||
assert not LiteLLMProvider(model="gpt-4o-mini", api_key="x")._is_minimax_model()
|
||||
|
||||
|
||||
class TestOpenRouterToolCompatFallback:
|
||||
"""OpenRouter models should fall back when native tool use is unavailable."""
|
||||
|
||||
def teardown_method(self):
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE.clear()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.acompletion")
|
||||
async def test_stream_falls_back_to_json_tool_emulation(self, mock_acompletion):
|
||||
"""OpenRouter tool-use 404s should emit synthetic ToolCallEvents instead of errors."""
|
||||
from framework.llm.stream_events import FinishEvent, ToolCallEvent
|
||||
|
||||
provider = LiteLLMProvider(
|
||||
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
||||
api_key="test-key",
|
||||
)
|
||||
tools = [
|
||||
Tool(
|
||||
name="web_search",
|
||||
description="Search the web",
|
||||
parameters={
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"num_results": {"type": "integer"},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
compat_response = MagicMock()
|
||||
compat_response.choices = [MagicMock()]
|
||||
compat_response.choices[0].message.content = (
|
||||
'{"assistant_response":"","tool_calls":['
|
||||
'{"name":"web_search","arguments":'
|
||||
'{"query":"Python 3.13 release notes","num_results":3}}'
|
||||
"]}"
|
||||
)
|
||||
compat_response.choices[0].finish_reason = "stop"
|
||||
compat_response.model = provider.model
|
||||
compat_response.usage.prompt_tokens = 18
|
||||
compat_response.usage.completion_tokens = 9
|
||||
|
||||
async def side_effect(*args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
raise RuntimeError(
|
||||
'OpenrouterException - {"error":{"message":"No endpoints found '
|
||||
"that support tool use. To learn more about provider routing, "
|
||||
'visit: https://openrouter.ai/docs/guides/routing/provider-selection",'
|
||||
'"code":404}}'
|
||||
)
|
||||
return compat_response
|
||||
|
||||
mock_acompletion.side_effect = side_effect
|
||||
|
||||
events = []
|
||||
async for event in provider.stream(
|
||||
messages=[{"role": "user", "content": "Search for the Python 3.13 release notes."}],
|
||||
system="Use tools when needed.",
|
||||
tools=tools,
|
||||
max_tokens=256,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
tool_calls = [event for event in events if isinstance(event, ToolCallEvent)]
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].tool_name == "web_search"
|
||||
assert tool_calls[0].tool_input == {
|
||||
"query": "Python 3.13 release notes",
|
||||
"num_results": 3,
|
||||
}
|
||||
assert tool_calls[0].tool_use_id.startswith("openrouter_compat_")
|
||||
|
||||
finish_events = [event for event in events if isinstance(event, FinishEvent)]
|
||||
assert len(finish_events) == 1
|
||||
assert finish_events[0].stop_reason == "tool_calls"
|
||||
assert finish_events[0].input_tokens == 18
|
||||
assert finish_events[0].output_tokens == 9
|
||||
|
||||
assert mock_acompletion.call_count == 2
|
||||
first_call = mock_acompletion.call_args_list[0].kwargs
|
||||
assert first_call["stream"] is True
|
||||
assert "tools" in first_call
|
||||
|
||||
second_call = mock_acompletion.call_args_list[1].kwargs
|
||||
assert "tools" not in second_call
|
||||
assert "Tool compatibility mode is active" in second_call["messages"][0]["content"]
|
||||
assert provider.model in OPENROUTER_TOOL_COMPAT_MODEL_CACHE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.acompletion")
|
||||
async def test_stream_tool_compat_parses_textual_tool_calls_and_uses_cache(
|
||||
self,
|
||||
mock_acompletion,
|
||||
):
|
||||
"""Textual tool-call markers should become ToolCallEvents and skip repeat probing."""
|
||||
from framework.llm.stream_events import ToolCallEvent
|
||||
|
||||
provider = LiteLLMProvider(
|
||||
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
||||
api_key="test-key",
|
||||
)
|
||||
tools = [
|
||||
Tool(
|
||||
name="ask_user_multiple",
|
||||
description="Ask the user a multiple-choice question",
|
||||
parameters={
|
||||
"properties": {
|
||||
"options": {"type": "array"},
|
||||
"question": {"type": "string"},
|
||||
"prompt": {"type": "string"},
|
||||
},
|
||||
"required": ["options", "question", "prompt"],
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
compat_response = MagicMock()
|
||||
compat_response.choices = [MagicMock()]
|
||||
compat_response.choices[0].message.content = (
|
||||
"<|tool_call_start|>"
|
||||
"[ask_user_multiple(options=['Quartet Collaborator', 'Project Advisor'], "
|
||||
"question='Who are you?', prompt='Who are you?')]"
|
||||
"<|tool_call_end|>"
|
||||
)
|
||||
compat_response.choices[0].finish_reason = "stop"
|
||||
compat_response.model = provider.model
|
||||
compat_response.usage.prompt_tokens = 10
|
||||
compat_response.usage.completion_tokens = 5
|
||||
|
||||
call_state = {"count": 0}
|
||||
|
||||
async def side_effect(*args, **kwargs):
|
||||
call_state["count"] += 1
|
||||
if kwargs.get("stream"):
|
||||
raise RuntimeError(
|
||||
'OpenrouterException - {"error":{"message":"No endpoints found '
|
||||
'that support tool use.","code":404}}'
|
||||
)
|
||||
return compat_response
|
||||
|
||||
mock_acompletion.side_effect = side_effect
|
||||
|
||||
first_events = []
|
||||
async for event in provider.stream(
|
||||
messages=[{"role": "user", "content": "Who are you?"}],
|
||||
system="Use tools when needed.",
|
||||
tools=tools,
|
||||
max_tokens=128,
|
||||
):
|
||||
first_events.append(event)
|
||||
|
||||
tool_calls = [event for event in first_events if isinstance(event, ToolCallEvent)]
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].tool_name == "ask_user_multiple"
|
||||
assert tool_calls[0].tool_input == {
|
||||
"options": ["Quartet Collaborator", "Project Advisor"],
|
||||
"question": "Who are you?",
|
||||
"prompt": "Who are you?",
|
||||
}
|
||||
|
||||
second_events = []
|
||||
async for event in provider.stream(
|
||||
messages=[{"role": "user", "content": "Who are you?"}],
|
||||
system="Use tools when needed.",
|
||||
tools=tools,
|
||||
max_tokens=128,
|
||||
):
|
||||
second_events.append(event)
|
||||
|
||||
second_tool_calls = [event for event in second_events if isinstance(event, ToolCallEvent)]
|
||||
assert len(second_tool_calls) == 1
|
||||
assert mock_acompletion.call_count == 3
|
||||
assert mock_acompletion.call_args_list[0].kwargs["stream"] is True
|
||||
assert "stream" not in mock_acompletion.call_args_list[1].kwargs
|
||||
assert "stream" not in mock_acompletion.call_args_list[2].kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.acompletion")
|
||||
async def test_stream_tool_compat_parses_plain_text_tool_call_lines(
|
||||
self,
|
||||
mock_acompletion,
|
||||
):
|
||||
"""Plain textual tool-call lines should execute as tools, not user-visible text."""
|
||||
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
|
||||
|
||||
provider = LiteLLMProvider(
|
||||
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
||||
api_key="test-key",
|
||||
)
|
||||
tools = [
|
||||
Tool(
|
||||
name="ask_user",
|
||||
description="Ask the user a single multiple-choice question",
|
||||
parameters={
|
||||
"properties": {
|
||||
"question": {"type": "string"},
|
||||
"options": {"type": "array"},
|
||||
},
|
||||
"required": ["question", "options"],
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
compat_response = MagicMock()
|
||||
compat_response.choices = [MagicMock()]
|
||||
compat_response.choices[0].message.content = (
|
||||
"Queen has been loaded. It's ready to assist with your planning needs.\n\n"
|
||||
"ask_user('What would you like to do?', ['Define a new agent', "
|
||||
"'Diagnose an existing agent', 'Explore tools'])"
|
||||
)
|
||||
compat_response.choices[0].finish_reason = "stop"
|
||||
compat_response.model = provider.model
|
||||
compat_response.usage.prompt_tokens = 11
|
||||
compat_response.usage.completion_tokens = 7
|
||||
|
||||
async def side_effect(*args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
raise RuntimeError(
|
||||
'OpenrouterException - {"error":{"message":"No endpoints found '
|
||||
'that support tool use.","code":404}}'
|
||||
)
|
||||
return compat_response
|
||||
|
||||
mock_acompletion.side_effect = side_effect
|
||||
|
||||
events = []
|
||||
async for event in provider.stream(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
system="Use tools when needed.",
|
||||
tools=tools,
|
||||
max_tokens=128,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
tool_calls = [event for event in events if isinstance(event, ToolCallEvent)]
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].tool_name == "ask_user"
|
||||
assert tool_calls[0].tool_input == {
|
||||
"question": "What would you like to do?",
|
||||
"options": ["Define a new agent", "Diagnose an existing agent", "Explore tools"],
|
||||
}
|
||||
|
||||
text_events = [event for event in events if isinstance(event, TextDeltaEvent)]
|
||||
assert len(text_events) == 1
|
||||
assert "ask_user(" not in text_events[0].snapshot
|
||||
assert text_events[0].snapshot == (
|
||||
"Queen has been loaded. It's ready to assist with your planning needs."
|
||||
)
|
||||
|
||||
finish_events = [event for event in events if isinstance(event, FinishEvent)]
|
||||
assert len(finish_events) == 1
|
||||
assert finish_events[0].stop_reason == "tool_calls"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.acompletion")
|
||||
async def test_stream_tool_compat_treats_non_json_as_plain_text(self, mock_acompletion):
|
||||
"""If fallback output is not valid JSON, preserve it as assistant text."""
|
||||
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
|
||||
|
||||
provider = LiteLLMProvider(
|
||||
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
||||
api_key="test-key",
|
||||
)
|
||||
tools = [
|
||||
Tool(
|
||||
name="web_search",
|
||||
description="Search the web",
|
||||
parameters={"properties": {"query": {"type": "string"}}, "required": ["query"]},
|
||||
)
|
||||
]
|
||||
|
||||
compat_response = MagicMock()
|
||||
compat_response.choices = [MagicMock()]
|
||||
compat_response.choices[0].message.content = "I can answer directly without tools."
|
||||
compat_response.choices[0].finish_reason = "stop"
|
||||
compat_response.model = provider.model
|
||||
compat_response.usage.prompt_tokens = 12
|
||||
compat_response.usage.completion_tokens = 6
|
||||
|
||||
async def side_effect(*args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
raise RuntimeError(
|
||||
'OpenrouterException - {"error":{"message":"No endpoints found '
|
||||
'that support tool use.","code":404}}'
|
||||
)
|
||||
return compat_response
|
||||
|
||||
mock_acompletion.side_effect = side_effect
|
||||
|
||||
events = []
|
||||
async for event in provider.stream(
|
||||
messages=[{"role": "user", "content": "Say hello."}],
|
||||
system="Be concise.",
|
||||
tools=tools,
|
||||
max_tokens=128,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
text_events = [event for event in events if isinstance(event, TextDeltaEvent)]
|
||||
assert len(text_events) == 1
|
||||
assert text_events[0].snapshot == "I can answer directly without tools."
|
||||
assert not any(isinstance(event, ToolCallEvent) for event in events)
|
||||
|
||||
finish_events = [event for event in events if isinstance(event, FinishEvent)]
|
||||
assert len(finish_events) == 1
|
||||
assert finish_events[0].stop_reason == "stop"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AgentRunner._is_local_model — parameterized tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -21,3 +21,8 @@ def test_minimax_provider_prefix_maps_to_minimax_api_key():
|
||||
def test_minimax_model_name_prefix_maps_to_minimax_api_key():
|
||||
runner = _runner_for_unit_test()
|
||||
assert runner._get_api_key_env_var("minimax-chat") == "MINIMAX_API_KEY"
|
||||
|
||||
|
||||
def test_openrouter_provider_prefix_maps_to_openrouter_api_key():
|
||||
runner = _runner_for_unit_test()
|
||||
assert runner._get_api_key_env_var("openrouter/x-ai/grok-4.20-beta") == "OPENROUTER_API_KEY"
|
||||
|
||||
@@ -0,0 +1,520 @@
|
||||
"""Tests for safe_eval — the sandboxed expression evaluator used by edge conditions.
|
||||
|
||||
Covers: literals, data structures, arithmetic, comparisons, boolean logic
|
||||
(including short-circuit semantics), variable lookup, subscript/attribute
|
||||
access, whitelisted function calls, method calls, ternary expressions,
|
||||
chained comparisons, and security boundaries (private attrs, disallowed
|
||||
AST nodes, disallowed function calls).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.safe_eval import safe_eval
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Literals and constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLiterals:
|
||||
def test_integer(self):
|
||||
assert safe_eval("42") == 42
|
||||
|
||||
def test_negative_integer(self):
|
||||
assert safe_eval("-1") == -1
|
||||
|
||||
def test_float(self):
|
||||
assert safe_eval("3.14") == pytest.approx(3.14)
|
||||
|
||||
def test_string(self):
|
||||
assert safe_eval("'hello'") == "hello"
|
||||
|
||||
def test_double_quoted_string(self):
|
||||
assert safe_eval('"world"') == "world"
|
||||
|
||||
def test_boolean_true(self):
|
||||
assert safe_eval("True") is True
|
||||
|
||||
def test_boolean_false(self):
|
||||
assert safe_eval("False") is False
|
||||
|
||||
def test_none(self):
|
||||
assert safe_eval("None") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data structures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDataStructures:
|
||||
def test_list(self):
|
||||
assert safe_eval("[1, 2, 3]") == [1, 2, 3]
|
||||
|
||||
def test_empty_list(self):
|
||||
assert safe_eval("[]") == []
|
||||
|
||||
def test_nested_list(self):
|
||||
assert safe_eval("[[1, 2], [3, 4]]") == [[1, 2], [3, 4]]
|
||||
|
||||
def test_tuple(self):
|
||||
assert safe_eval("(1, 2, 3)") == (1, 2, 3)
|
||||
|
||||
def test_dict(self):
|
||||
assert safe_eval("{'a': 1, 'b': 2}") == {"a": 1, "b": 2}
|
||||
|
||||
def test_empty_dict(self):
|
||||
assert safe_eval("{}") == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Arithmetic and binary operators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestArithmetic:
|
||||
def test_addition(self):
|
||||
assert safe_eval("2 + 3") == 5
|
||||
|
||||
def test_subtraction(self):
|
||||
assert safe_eval("10 - 4") == 6
|
||||
|
||||
def test_multiplication(self):
|
||||
assert safe_eval("3 * 7") == 21
|
||||
|
||||
def test_division(self):
|
||||
assert safe_eval("10 / 4") == 2.5
|
||||
|
||||
def test_floor_division(self):
|
||||
assert safe_eval("10 // 3") == 3
|
||||
|
||||
def test_modulo(self):
|
||||
assert safe_eval("10 % 3") == 1
|
||||
|
||||
def test_power(self):
|
||||
assert safe_eval("2 ** 10") == 1024
|
||||
|
||||
def test_complex_expression(self):
|
||||
assert safe_eval("(2 + 3) * 4 - 1") == 19
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unary operators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUnaryOps:
|
||||
def test_negation(self):
|
||||
assert safe_eval("-5") == -5
|
||||
|
||||
def test_positive(self):
|
||||
assert safe_eval("+5") == 5
|
||||
|
||||
def test_not_true(self):
|
||||
assert safe_eval("not True") is False
|
||||
|
||||
def test_not_false(self):
|
||||
assert safe_eval("not False") is True
|
||||
|
||||
def test_bitwise_invert(self):
|
||||
assert safe_eval("~0") == -1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Comparisons
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComparisons:
|
||||
def test_equal(self):
|
||||
assert safe_eval("1 == 1") is True
|
||||
|
||||
def test_not_equal(self):
|
||||
assert safe_eval("1 != 2") is True
|
||||
|
||||
def test_less_than(self):
|
||||
assert safe_eval("1 < 2") is True
|
||||
|
||||
def test_greater_than(self):
|
||||
assert safe_eval("2 > 1") is True
|
||||
|
||||
def test_less_equal(self):
|
||||
assert safe_eval("2 <= 2") is True
|
||||
|
||||
def test_greater_equal(self):
|
||||
assert safe_eval("3 >= 2") is True
|
||||
|
||||
def test_is_none(self):
|
||||
assert safe_eval("x is None", {"x": None}) is True
|
||||
|
||||
def test_is_not_none(self):
|
||||
assert safe_eval("x is not None", {"x": 42}) is True
|
||||
|
||||
def test_in_list(self):
|
||||
assert safe_eval("'a' in x", {"x": ["a", "b", "c"]}) is True
|
||||
|
||||
def test_not_in_list(self):
|
||||
assert safe_eval("'z' not in x", {"x": ["a", "b"]}) is True
|
||||
|
||||
def test_chained_comparison(self):
|
||||
"""Chained comparisons like 1 < x < 10 should work."""
|
||||
assert safe_eval("1 < x < 10", {"x": 5}) is True
|
||||
|
||||
def test_chained_comparison_false(self):
|
||||
assert safe_eval("1 < x < 3", {"x": 5}) is False
|
||||
|
||||
def test_chained_three_way(self):
|
||||
assert safe_eval("0 <= x <= 100", {"x": 50}) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Boolean operators (with short-circuit semantics)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBooleanOps:
|
||||
def test_and_true(self):
|
||||
assert safe_eval("True and True") is True
|
||||
|
||||
def test_and_false(self):
|
||||
assert safe_eval("True and False") is False
|
||||
|
||||
def test_or_true(self):
|
||||
assert safe_eval("False or True") is True
|
||||
|
||||
def test_or_false(self):
|
||||
assert safe_eval("False or False") is False
|
||||
|
||||
def test_and_returns_last_truthy(self):
|
||||
"""Python `and` returns the last value if all truthy."""
|
||||
assert safe_eval("1 and 2 and 3") == 3
|
||||
|
||||
def test_and_returns_first_falsy(self):
|
||||
"""Python `and` returns the first falsy value."""
|
||||
assert safe_eval("1 and 0 and 3") == 0
|
||||
|
||||
def test_or_returns_first_truthy(self):
|
||||
"""Python `or` returns the first truthy value."""
|
||||
assert safe_eval("0 or '' or 42") == 42
|
||||
|
||||
def test_or_returns_last_falsy(self):
|
||||
"""Python `or` returns the last value if all falsy."""
|
||||
assert safe_eval("0 or '' or None") is None
|
||||
|
||||
def test_and_short_circuits(self):
|
||||
"""and should NOT evaluate the right side if left is falsy.
|
||||
|
||||
This is the bug we fixed — previously this would crash with
|
||||
TypeError because all operands were eagerly evaluated.
|
||||
"""
|
||||
# x is None, so `x.get("key")` would crash if evaluated
|
||||
assert safe_eval("x is not None and x.get('key')", {"x": None}) is False
|
||||
|
||||
def test_or_short_circuits(self):
|
||||
"""or should NOT evaluate the right side if left is truthy."""
|
||||
# x is truthy, so the crash-prone right side should never run
|
||||
assert safe_eval("x or y.get('missing')", {"x": "found", "y": {}}) == "found"
|
||||
|
||||
def test_and_guard_pattern_truthy(self):
|
||||
"""Guard pattern: check not None, then access — when value exists."""
|
||||
ctx = {"x": {"key": "value"}}
|
||||
assert safe_eval("x is not None and x.get('key')", ctx) == "value"
|
||||
|
||||
def test_multi_and(self):
|
||||
assert safe_eval("True and True and True") is True
|
||||
|
||||
def test_multi_or(self):
|
||||
assert safe_eval("False or False or True") is True
|
||||
|
||||
def test_mixed_and_or(self):
|
||||
assert safe_eval("True or False and False") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ternary (if/else) expressions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTernary:
|
||||
def test_ternary_true_branch(self):
|
||||
assert safe_eval("'yes' if True else 'no'") == "yes"
|
||||
|
||||
def test_ternary_false_branch(self):
|
||||
assert safe_eval("'yes' if False else 'no'") == "no"
|
||||
|
||||
def test_ternary_with_context(self):
|
||||
assert safe_eval("x * 2 if x > 0 else -x", {"x": 5}) == 10
|
||||
|
||||
def test_ternary_false_with_context(self):
|
||||
assert safe_eval("x * 2 if x > 0 else -x", {"x": -3}) == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Variable lookup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVariables:
|
||||
def test_simple_variable(self):
|
||||
assert safe_eval("x", {"x": 42}) == 42
|
||||
|
||||
def test_string_variable(self):
|
||||
assert safe_eval("name", {"name": "Alice"}) == "Alice"
|
||||
|
||||
def test_dict_variable(self):
|
||||
ctx = {"output": {"status": "ok"}}
|
||||
assert safe_eval("output", ctx) == {"status": "ok"}
|
||||
|
||||
def test_undefined_variable_raises(self):
|
||||
with pytest.raises(NameError, match="not defined"):
|
||||
safe_eval("undefined_var")
|
||||
|
||||
def test_multiple_variables(self):
|
||||
assert safe_eval("x + y", {"x": 10, "y": 20}) == 30
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subscript access (indexing)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubscript:
|
||||
def test_dict_subscript(self):
|
||||
assert safe_eval("d['key']", {"d": {"key": "value"}}) == "value"
|
||||
|
||||
def test_list_subscript(self):
|
||||
assert safe_eval("items[0]", {"items": [10, 20, 30]}) == 10
|
||||
|
||||
def test_nested_subscript(self):
|
||||
ctx = {"data": {"users": [{"name": "Alice"}]}}
|
||||
assert safe_eval("data['users'][0]['name']", ctx) == "Alice"
|
||||
|
||||
def test_missing_key_raises(self):
|
||||
with pytest.raises(KeyError):
|
||||
safe_eval("d['missing']", {"d": {}})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attribute access
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAttributeAccess:
|
||||
def test_private_attr_blocked(self):
|
||||
"""Attributes starting with _ must be blocked for security."""
|
||||
with pytest.raises(ValueError, match="private attribute"):
|
||||
safe_eval("x.__class__", {"x": 42})
|
||||
|
||||
def test_dunder_blocked(self):
|
||||
with pytest.raises(ValueError, match="private attribute"):
|
||||
safe_eval("x.__dict__", {"x": {}})
|
||||
|
||||
def test_single_underscore_blocked(self):
|
||||
with pytest.raises(ValueError, match="private attribute"):
|
||||
safe_eval("x._internal", {"x": {}})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Whitelisted function calls
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFunctionCalls:
|
||||
def test_len(self):
|
||||
assert safe_eval("len(x)", {"x": [1, 2, 3]}) == 3
|
||||
|
||||
def test_int_conversion(self):
|
||||
assert safe_eval("int('42')") == 42
|
||||
|
||||
def test_float_conversion(self):
|
||||
assert safe_eval("float('3.14')") == pytest.approx(3.14)
|
||||
|
||||
def test_str_conversion(self):
|
||||
assert safe_eval("str(42)") == "42"
|
||||
|
||||
def test_bool_conversion(self):
|
||||
assert safe_eval("bool(1)") is True
|
||||
|
||||
def test_abs(self):
|
||||
assert safe_eval("abs(-5)") == 5
|
||||
|
||||
def test_min(self):
|
||||
assert safe_eval("min(3, 1, 2)") == 1
|
||||
|
||||
def test_max(self):
|
||||
assert safe_eval("max(3, 1, 2)") == 3
|
||||
|
||||
def test_sum(self):
|
||||
assert safe_eval("sum(x)", {"x": [1, 2, 3]}) == 6
|
||||
|
||||
def test_round(self):
|
||||
assert safe_eval("round(3.7)") == 4
|
||||
|
||||
def test_all(self):
|
||||
assert safe_eval("all([True, True, True])") is True
|
||||
|
||||
def test_any(self):
|
||||
assert safe_eval("any([False, False, True])") is True
|
||||
|
||||
def test_list_constructor(self):
|
||||
assert safe_eval("list(x)", {"x": (1, 2, 3)}) == [1, 2, 3]
|
||||
|
||||
def test_dict_constructor(self):
|
||||
assert safe_eval("dict(a=1, b=2)") == {"a": 1, "b": 2}
|
||||
|
||||
def test_tuple_constructor(self):
|
||||
assert safe_eval("tuple(x)", {"x": [1, 2]}) == (1, 2)
|
||||
|
||||
def test_set_constructor(self):
|
||||
assert safe_eval("set(x)", {"x": [1, 2, 2, 3]}) == {1, 2, 3}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Whitelisted method calls
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMethodCalls:
|
||||
def test_dict_get(self):
|
||||
assert safe_eval("d.get('key', 'default')", {"d": {"key": "val"}}) == "val"
|
||||
|
||||
def test_dict_get_missing(self):
|
||||
assert safe_eval("d.get('missing', 'default')", {"d": {}}) == "default"
|
||||
|
||||
def test_dict_keys(self):
|
||||
result = safe_eval("list(d.keys())", {"d": {"a": 1, "b": 2}})
|
||||
assert sorted(result) == ["a", "b"]
|
||||
|
||||
def test_dict_values(self):
|
||||
result = safe_eval("list(d.values())", {"d": {"a": 1, "b": 2}})
|
||||
assert sorted(result) == [1, 2]
|
||||
|
||||
def test_string_lower(self):
|
||||
assert safe_eval("s.lower()", {"s": "HELLO"}) == "hello"
|
||||
|
||||
def test_string_upper(self):
|
||||
assert safe_eval("s.upper()", {"s": "hello"}) == "HELLO"
|
||||
|
||||
def test_string_strip(self):
|
||||
assert safe_eval("s.strip()", {"s": " hi "}) == "hi"
|
||||
|
||||
def test_string_split(self):
|
||||
assert safe_eval("s.split(',')", {"s": "a,b,c"}) == ["a", "b", "c"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Security: disallowed operations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSecurity:
|
||||
def test_import_blocked(self):
|
||||
"""__import__ is not in context, so NameError is raised."""
|
||||
with pytest.raises(NameError, match="not defined"):
|
||||
safe_eval("__import__('os')")
|
||||
|
||||
def test_lambda_blocked(self):
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
safe_eval("(lambda: 1)()")
|
||||
|
||||
def test_comprehension_blocked(self):
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
safe_eval("[x for x in range(10)]")
|
||||
|
||||
def test_assignment_blocked(self):
|
||||
"""Assignment expressions should not parse in eval mode."""
|
||||
with pytest.raises(SyntaxError):
|
||||
safe_eval("x = 5")
|
||||
|
||||
def test_disallowed_function_blocked(self):
|
||||
"""eval is not in safe functions, so NameError is raised."""
|
||||
with pytest.raises(NameError, match="not defined"):
|
||||
safe_eval("eval('1+1')")
|
||||
|
||||
def test_exec_blocked(self):
|
||||
"""exec is not in safe functions, so NameError is raised."""
|
||||
with pytest.raises(NameError, match="not defined"):
|
||||
safe_eval("exec('x=1')")
|
||||
|
||||
def test_type_call_blocked(self):
|
||||
"""type is not in safe functions, so NameError is raised."""
|
||||
with pytest.raises(NameError, match="not defined"):
|
||||
safe_eval("type(42)")
|
||||
|
||||
def test_getattr_builtin_blocked(self):
|
||||
"""getattr is not in safe functions, so NameError is raised."""
|
||||
with pytest.raises(NameError, match="not defined"):
|
||||
safe_eval("getattr(x, '__class__')", {"x": 42})
|
||||
|
||||
def test_empty_expression_raises(self):
|
||||
with pytest.raises(SyntaxError):
|
||||
safe_eval("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Real-world edge condition patterns (from graph executor usage)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEdgeConditionPatterns:
|
||||
"""Patterns commonly used in EdgeSpec.condition_expr."""
|
||||
|
||||
def test_output_key_exists_and_not_none(self):
|
||||
ctx = {"output": {"approved_contacts": ["alice@example.com"]}}
|
||||
assert safe_eval("output.get('approved_contacts') is not None", ctx) is True
|
||||
|
||||
def test_output_key_missing(self):
|
||||
ctx = {"output": {}}
|
||||
assert safe_eval("output.get('approved_contacts') is not None", ctx) is False
|
||||
|
||||
def test_output_key_check_with_fallback(self):
|
||||
ctx = {"output": {"redo_extraction": True}}
|
||||
assert safe_eval("output.get('redo_extraction') is not None", ctx) is True
|
||||
|
||||
def test_guard_then_length_check(self):
|
||||
"""Guard pattern: check key exists, then check length."""
|
||||
ctx = {"output": {"results": [1, 2, 3]}}
|
||||
assert (
|
||||
safe_eval(
|
||||
"output.get('results') is not None and len(output['results']) > 0",
|
||||
ctx,
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_guard_short_circuits_on_none(self):
|
||||
"""Guard pattern: short-circuit prevents crash on None."""
|
||||
ctx = {"output": {}}
|
||||
assert (
|
||||
safe_eval(
|
||||
"output.get('results') is not None and len(output['results']) > 0",
|
||||
ctx,
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_success_flag_check(self):
|
||||
ctx = {"output": {"success": True}, "memory": {"attempts": 2}}
|
||||
assert safe_eval("output.get('success') == True", ctx) is True
|
||||
|
||||
def test_memory_threshold(self):
|
||||
ctx = {"memory": {"score": 0.85}}
|
||||
assert safe_eval("memory.get('score', 0) >= 0.8", ctx) is True
|
||||
|
||||
def test_string_contains_check(self):
|
||||
ctx = {"output": {"status": "completed_with_warnings"}}
|
||||
assert safe_eval("'completed' in output.get('status', '')", ctx) is True
|
||||
|
||||
def test_fallback_chain(self):
|
||||
"""or-chain for fallback values."""
|
||||
ctx = {"output": {}}
|
||||
result = safe_eval(
|
||||
"output.get('primary') or output.get('secondary') or 'default'",
|
||||
ctx,
|
||||
)
|
||||
assert result == "default"
|
||||
|
||||
def test_no_context_needed(self):
|
||||
"""Some edges use constant expressions."""
|
||||
assert safe_eval("True") is True
|
||||
assert safe_eval("1 == 1") is True
|
||||
@@ -0,0 +1,170 @@
|
||||
"""Tests for the skill catalog and prompt generation."""
|
||||
|
||||
from framework.skills.catalog import SkillCatalog
|
||||
from framework.skills.parser import ParsedSkill
|
||||
|
||||
|
||||
def _make_skill(
|
||||
name: str = "my-skill",
|
||||
description: str = "A test skill.",
|
||||
source_scope: str = "project",
|
||||
body: str = "Instructions here.",
|
||||
location: str = "/tmp/skills/my-skill/SKILL.md",
|
||||
base_dir: str = "/tmp/skills/my-skill",
|
||||
) -> ParsedSkill:
|
||||
return ParsedSkill(
|
||||
name=name,
|
||||
description=description,
|
||||
location=location,
|
||||
base_dir=base_dir,
|
||||
source_scope=source_scope,
|
||||
body=body,
|
||||
)
|
||||
|
||||
|
||||
class TestSkillCatalog:
|
||||
def test_add_and_get(self):
|
||||
catalog = SkillCatalog()
|
||||
skill = _make_skill()
|
||||
catalog.add(skill)
|
||||
|
||||
assert catalog.get("my-skill") is skill
|
||||
assert catalog.get("nonexistent") is None
|
||||
assert catalog.skill_count == 1
|
||||
|
||||
def test_init_with_skills_list(self):
|
||||
skills = [_make_skill("a", "Skill A"), _make_skill("b", "Skill B")]
|
||||
catalog = SkillCatalog(skills)
|
||||
|
||||
assert catalog.skill_count == 2
|
||||
assert catalog.get("a") is not None
|
||||
assert catalog.get("b") is not None
|
||||
|
||||
def test_activation_tracking(self):
|
||||
catalog = SkillCatalog([_make_skill()])
|
||||
assert not catalog.is_activated("my-skill")
|
||||
|
||||
catalog.mark_activated("my-skill")
|
||||
assert catalog.is_activated("my-skill")
|
||||
|
||||
def test_allowlisted_dirs(self):
|
||||
skills = [
|
||||
_make_skill("a", base_dir="/skills/a"),
|
||||
_make_skill("b", base_dir="/skills/b"),
|
||||
]
|
||||
catalog = SkillCatalog(skills)
|
||||
dirs = catalog.allowlisted_dirs
|
||||
|
||||
assert "/skills/a" in dirs
|
||||
assert "/skills/b" in dirs
|
||||
|
||||
def test_to_prompt_empty_catalog(self):
|
||||
catalog = SkillCatalog()
|
||||
assert catalog.to_prompt() == ""
|
||||
|
||||
def test_to_prompt_framework_only(self):
|
||||
"""Framework-scope skills should NOT appear in the catalog prompt."""
|
||||
catalog = SkillCatalog([_make_skill(source_scope="framework")])
|
||||
assert catalog.to_prompt() == ""
|
||||
|
||||
def test_to_prompt_xml_generation(self):
|
||||
skills = [
|
||||
_make_skill("alpha", "Alpha skill", "project", location="/p/alpha/SKILL.md"),
|
||||
_make_skill("beta", "Beta skill", "user", location="/u/beta/SKILL.md"),
|
||||
]
|
||||
catalog = SkillCatalog(skills)
|
||||
prompt = catalog.to_prompt()
|
||||
|
||||
assert "<available_skills>" in prompt
|
||||
assert "</available_skills>" in prompt
|
||||
assert "<name>alpha</name>" in prompt
|
||||
assert "<name>beta</name>" in prompt
|
||||
assert "<description>Alpha skill</description>" in prompt
|
||||
assert "<location>/p/alpha/SKILL.md</location>" in prompt
|
||||
|
||||
def test_to_prompt_sorted_by_name(self):
|
||||
skills = [
|
||||
_make_skill("zebra", "Z skill", "project"),
|
||||
_make_skill("alpha", "A skill", "project"),
|
||||
]
|
||||
catalog = SkillCatalog(skills)
|
||||
prompt = catalog.to_prompt()
|
||||
|
||||
alpha_pos = prompt.index("alpha")
|
||||
zebra_pos = prompt.index("zebra")
|
||||
assert alpha_pos < zebra_pos
|
||||
|
||||
def test_to_prompt_xml_escaping(self):
|
||||
skill = _make_skill("test", 'Has <special> & "chars"', "project")
|
||||
catalog = SkillCatalog([skill])
|
||||
prompt = catalog.to_prompt()
|
||||
|
||||
assert "<special>" in prompt
|
||||
assert "&" in prompt
|
||||
|
||||
def test_to_prompt_excludes_framework_includes_others(self):
|
||||
"""Mixed scopes: only framework skills are excluded from catalog."""
|
||||
skills = [
|
||||
_make_skill("proj", "Project skill", "project"),
|
||||
_make_skill("usr", "User skill", "user"),
|
||||
_make_skill("fw", "Framework skill", "framework"),
|
||||
]
|
||||
catalog = SkillCatalog(skills)
|
||||
prompt = catalog.to_prompt()
|
||||
|
||||
assert "<name>proj</name>" in prompt
|
||||
assert "<name>usr</name>" in prompt
|
||||
assert "fw" not in prompt
|
||||
|
||||
def test_to_prompt_contains_behavioral_instruction(self):
|
||||
catalog = SkillCatalog([_make_skill(source_scope="project")])
|
||||
prompt = catalog.to_prompt()
|
||||
|
||||
assert "When a task matches a skill's description" in prompt
|
||||
assert "SKILL.md" in prompt
|
||||
|
||||
def test_build_pre_activated_prompt(self):
|
||||
skill = _make_skill("research", body="## Deep Research\nDo thorough research.")
|
||||
catalog = SkillCatalog([skill])
|
||||
prompt = catalog.build_pre_activated_prompt(["research"])
|
||||
|
||||
assert "Pre-Activated Skill: research" in prompt
|
||||
assert "## Deep Research" in prompt
|
||||
assert catalog.is_activated("research")
|
||||
|
||||
def test_build_pre_activated_skips_already_activated(self):
|
||||
skill = _make_skill("research", body="Research body")
|
||||
catalog = SkillCatalog([skill])
|
||||
catalog.mark_activated("research")
|
||||
|
||||
prompt = catalog.build_pre_activated_prompt(["research"])
|
||||
assert prompt == ""
|
||||
|
||||
def test_build_pre_activated_missing_skill(self):
|
||||
catalog = SkillCatalog()
|
||||
prompt = catalog.build_pre_activated_prompt(["nonexistent"])
|
||||
assert prompt == ""
|
||||
|
||||
def test_build_pre_activated_multiple(self):
|
||||
skills = [
|
||||
_make_skill("a", body="Body A"),
|
||||
_make_skill("b", body="Body B"),
|
||||
]
|
||||
catalog = SkillCatalog(skills)
|
||||
prompt = catalog.build_pre_activated_prompt(["a", "b"])
|
||||
|
||||
assert "Pre-Activated Skill: a" in prompt
|
||||
assert "Body A" in prompt
|
||||
assert "Pre-Activated Skill: b" in prompt
|
||||
assert "Body B" in prompt
|
||||
assert catalog.is_activated("a")
|
||||
assert catalog.is_activated("b")
|
||||
|
||||
def test_duplicate_add_overwrites(self):
|
||||
"""Adding a skill with the same name replaces the previous one."""
|
||||
catalog = SkillCatalog()
|
||||
catalog.add(_make_skill("x", "First"))
|
||||
catalog.add(_make_skill("x", "Second"))
|
||||
|
||||
assert catalog.skill_count == 1
|
||||
assert catalog.get("x").description == "Second"
|
||||
@@ -0,0 +1,160 @@
|
||||
"""Tests for skill discovery."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
|
||||
|
||||
|
||||
def _write_skill(base: Path, name: str, description: str = "A test skill.") -> Path:
|
||||
"""Create a minimal skill directory with SKILL.md."""
|
||||
skill_dir = base / name
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
f"---\nname: {name}\ndescription: {description}\n---\n\nInstructions.\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
return skill_dir
|
||||
|
||||
|
||||
class TestSkillDiscovery:
|
||||
def test_discover_project_skills(self, tmp_path):
|
||||
# Create project-level skills
|
||||
agents_skills = tmp_path / ".agents" / "skills"
|
||||
_write_skill(agents_skills, "skill-a")
|
||||
_write_skill(agents_skills, "skill-b")
|
||||
|
||||
discovery = SkillDiscovery(
|
||||
DiscoveryConfig(
|
||||
project_root=tmp_path,
|
||||
skip_user_scope=True,
|
||||
skip_framework_scope=True,
|
||||
)
|
||||
)
|
||||
skills = discovery.discover()
|
||||
|
||||
names = {s.name for s in skills}
|
||||
assert "skill-a" in names
|
||||
assert "skill-b" in names
|
||||
assert all(s.source_scope == "project" for s in skills)
|
||||
|
||||
def test_hive_skills_path(self, tmp_path):
|
||||
hive_skills = tmp_path / ".hive" / "skills"
|
||||
_write_skill(hive_skills, "hive-skill")
|
||||
|
||||
discovery = SkillDiscovery(
|
||||
DiscoveryConfig(
|
||||
project_root=tmp_path,
|
||||
skip_user_scope=True,
|
||||
skip_framework_scope=True,
|
||||
)
|
||||
)
|
||||
skills = discovery.discover()
|
||||
|
||||
assert len(skills) == 1
|
||||
assert skills[0].name == "hive-skill"
|
||||
|
||||
def test_collision_project_overrides_user(self, tmp_path, monkeypatch):
|
||||
# User-level skill
|
||||
user_skills = tmp_path / "home" / ".agents" / "skills"
|
||||
_write_skill(user_skills, "shared-skill", "User version")
|
||||
|
||||
# Project-level skill with same name
|
||||
project_skills = tmp_path / "project" / ".agents" / "skills"
|
||||
_write_skill(project_skills, "shared-skill", "Project version")
|
||||
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path / "home")
|
||||
|
||||
discovery = SkillDiscovery(
|
||||
DiscoveryConfig(
|
||||
project_root=tmp_path / "project",
|
||||
skip_framework_scope=True,
|
||||
)
|
||||
)
|
||||
skills = discovery.discover()
|
||||
|
||||
matching = [s for s in skills if s.name == "shared-skill"]
|
||||
assert len(matching) == 1
|
||||
assert matching[0].description == "Project version"
|
||||
|
||||
def test_collision_hive_overrides_agents(self, tmp_path):
|
||||
# Cross-client path
|
||||
agents_skills = tmp_path / ".agents" / "skills"
|
||||
_write_skill(agents_skills, "override-test", "Agents version")
|
||||
|
||||
# Hive-specific path (higher precedence)
|
||||
hive_skills = tmp_path / ".hive" / "skills"
|
||||
_write_skill(hive_skills, "override-test", "Hive version")
|
||||
|
||||
discovery = SkillDiscovery(
|
||||
DiscoveryConfig(
|
||||
project_root=tmp_path,
|
||||
skip_user_scope=True,
|
||||
skip_framework_scope=True,
|
||||
)
|
||||
)
|
||||
skills = discovery.discover()
|
||||
|
||||
matching = [s for s in skills if s.name == "override-test"]
|
||||
assert len(matching) == 1
|
||||
assert matching[0].description == "Hive version"
|
||||
|
||||
def test_skips_git_and_node_modules(self, tmp_path):
|
||||
skills_dir = tmp_path / ".agents" / "skills"
|
||||
_write_skill(skills_dir / ".git", "git-skill")
|
||||
_write_skill(skills_dir / "node_modules", "npm-skill")
|
||||
_write_skill(skills_dir, "real-skill")
|
||||
|
||||
discovery = SkillDiscovery(
|
||||
DiscoveryConfig(
|
||||
project_root=tmp_path,
|
||||
skip_user_scope=True,
|
||||
skip_framework_scope=True,
|
||||
)
|
||||
)
|
||||
skills = discovery.discover()
|
||||
|
||||
names = {s.name for s in skills}
|
||||
assert "real-skill" in names
|
||||
assert "git-skill" not in names
|
||||
assert "npm-skill" not in names
|
||||
|
||||
def test_empty_scan(self, tmp_path):
|
||||
discovery = SkillDiscovery(
|
||||
DiscoveryConfig(
|
||||
project_root=tmp_path,
|
||||
skip_user_scope=True,
|
||||
skip_framework_scope=True,
|
||||
)
|
||||
)
|
||||
skills = discovery.discover()
|
||||
assert skills == []
|
||||
|
||||
def test_framework_scope_loads_defaults(self):
|
||||
"""Framework scope should find the built-in default skills."""
|
||||
discovery = SkillDiscovery(
|
||||
DiscoveryConfig(
|
||||
skip_user_scope=True,
|
||||
)
|
||||
)
|
||||
skills = discovery.discover()
|
||||
|
||||
framework_skills = [s for s in skills if s.source_scope == "framework"]
|
||||
names = {s.name for s in framework_skills}
|
||||
assert "hive.note-taking" in names
|
||||
assert "hive.batch-ledger" in names
|
||||
|
||||
def test_max_depth_limit(self, tmp_path):
|
||||
# Create a skill nested beyond max_depth
|
||||
deep = tmp_path / ".agents" / "skills" / "a" / "b" / "c" / "d" / "e"
|
||||
_write_skill(deep, "too-deep")
|
||||
|
||||
discovery = SkillDiscovery(
|
||||
DiscoveryConfig(
|
||||
project_root=tmp_path,
|
||||
skip_user_scope=True,
|
||||
skip_framework_scope=True,
|
||||
max_depth=2,
|
||||
)
|
||||
)
|
||||
skills = discovery.discover()
|
||||
assert not any(s.name == "too-deep" for s in skills)
|
||||
@@ -0,0 +1,222 @@
|
||||
"""Integration tests for the skill system — prompt composition and backward compatibility."""
|
||||
|
||||
from framework.graph.prompt_composer import compose_system_prompt
|
||||
from framework.skills.catalog import SkillCatalog
|
||||
from framework.skills.config import SkillsConfig
|
||||
from framework.skills.defaults import DefaultSkillManager
|
||||
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
|
||||
from framework.skills.parser import ParsedSkill
|
||||
|
||||
|
||||
def _make_skill(
|
||||
name: str = "test-skill",
|
||||
description: str = "A test skill.",
|
||||
source_scope: str = "project",
|
||||
body: str = "Skill instructions.",
|
||||
location: str = "/tmp/skills/test-skill/SKILL.md",
|
||||
base_dir: str = "/tmp/skills/test-skill",
|
||||
) -> ParsedSkill:
|
||||
return ParsedSkill(
|
||||
name=name,
|
||||
description=description,
|
||||
location=location,
|
||||
base_dir=base_dir,
|
||||
source_scope=source_scope,
|
||||
body=body,
|
||||
)
|
||||
|
||||
|
||||
class TestPromptComposition:
|
||||
"""Test that skill prompts integrate correctly with compose_system_prompt."""
|
||||
|
||||
def test_backward_compat_no_skill_params(self):
|
||||
"""compose_system_prompt works without skill params (backward compat)."""
|
||||
prompt = compose_system_prompt(
|
||||
identity_prompt="You are a helpful agent.",
|
||||
focus_prompt="Focus on the task.",
|
||||
)
|
||||
assert "You are a helpful agent." in prompt
|
||||
assert "Focus on the task." in prompt
|
||||
assert "Current date and time" in prompt
|
||||
|
||||
def test_skills_catalog_in_prompt(self):
|
||||
catalog = SkillCatalog([_make_skill(source_scope="project")])
|
||||
catalog_prompt = catalog.to_prompt()
|
||||
|
||||
prompt = compose_system_prompt(
|
||||
identity_prompt="You are an agent.",
|
||||
focus_prompt=None,
|
||||
skills_catalog_prompt=catalog_prompt,
|
||||
)
|
||||
assert "<available_skills>" in prompt
|
||||
assert "<name>test-skill</name>" in prompt
|
||||
|
||||
def test_protocols_in_prompt(self):
|
||||
manager = DefaultSkillManager()
|
||||
manager.load()
|
||||
protocols_prompt = manager.build_protocols_prompt()
|
||||
|
||||
prompt = compose_system_prompt(
|
||||
identity_prompt="You are an agent.",
|
||||
focus_prompt=None,
|
||||
protocols_prompt=protocols_prompt,
|
||||
)
|
||||
assert "## Operational Protocols" in prompt
|
||||
|
||||
def test_full_prompt_ordering(self):
|
||||
"""Verify the three-layer onion ordering with all sections present."""
|
||||
catalog = SkillCatalog([_make_skill(source_scope="project")])
|
||||
|
||||
prompt = compose_system_prompt(
|
||||
identity_prompt="IDENTITY_SECTION",
|
||||
focus_prompt="FOCUS_SECTION",
|
||||
narrative="NARRATIVE_SECTION",
|
||||
accounts_prompt="ACCOUNTS_SECTION",
|
||||
skills_catalog_prompt=catalog.to_prompt(),
|
||||
protocols_prompt="PROTOCOLS_SECTION",
|
||||
)
|
||||
|
||||
identity_pos = prompt.index("IDENTITY_SECTION")
|
||||
accounts_pos = prompt.index("ACCOUNTS_SECTION")
|
||||
skills_pos = prompt.index("available_skills")
|
||||
protocols_pos = prompt.index("PROTOCOLS_SECTION")
|
||||
narrative_pos = prompt.index("NARRATIVE_SECTION")
|
||||
focus_pos = prompt.index("FOCUS_SECTION")
|
||||
|
||||
# Identity → Accounts → Skills → Protocols → Narrative → Focus
|
||||
assert identity_pos < accounts_pos
|
||||
assert accounts_pos < skills_pos
|
||||
assert skills_pos < protocols_pos
|
||||
assert protocols_pos < narrative_pos
|
||||
assert narrative_pos < focus_pos
|
||||
|
||||
def test_none_skill_prompts_excluded(self):
|
||||
"""None values for skill prompts should not add content."""
|
||||
prompt = compose_system_prompt(
|
||||
identity_prompt="Hello",
|
||||
focus_prompt=None,
|
||||
skills_catalog_prompt=None,
|
||||
protocols_prompt=None,
|
||||
)
|
||||
assert "available_skills" not in prompt
|
||||
assert "Operational Protocols" not in prompt
|
||||
|
||||
def test_empty_skill_prompts_excluded(self):
|
||||
"""Empty string skill prompts should not add content."""
|
||||
prompt = compose_system_prompt(
|
||||
identity_prompt="Hello",
|
||||
focus_prompt=None,
|
||||
skills_catalog_prompt="",
|
||||
protocols_prompt="",
|
||||
)
|
||||
assert "available_skills" not in prompt
|
||||
assert "Operational Protocols" not in prompt
|
||||
|
||||
|
||||
class TestEndToEndPipeline:
|
||||
"""Test the full discovery → catalog → prompt pipeline."""
|
||||
|
||||
def test_discovery_to_catalog_to_prompt(self, tmp_path):
|
||||
# Create a project skill
|
||||
skill_dir = tmp_path / ".agents" / "skills" / "my-tool"
|
||||
skill_dir.mkdir(parents=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\nname: my-tool\ndescription: Tool for testing.\n---\n\n"
|
||||
"## Usage\nUse this tool when testing.\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
# Discovery
|
||||
discovery = SkillDiscovery(
|
||||
DiscoveryConfig(
|
||||
project_root=tmp_path,
|
||||
skip_user_scope=True,
|
||||
skip_framework_scope=True,
|
||||
)
|
||||
)
|
||||
skills = discovery.discover()
|
||||
assert len(skills) == 1
|
||||
|
||||
# Catalog
|
||||
catalog = SkillCatalog(skills)
|
||||
assert catalog.skill_count == 1
|
||||
|
||||
# Prompt generation
|
||||
prompt = catalog.to_prompt()
|
||||
assert "<name>my-tool</name>" in prompt
|
||||
assert "<description>Tool for testing.</description>" in prompt
|
||||
|
||||
# Pre-activation
|
||||
activated = catalog.build_pre_activated_prompt(["my-tool"])
|
||||
assert "## Usage" in activated
|
||||
assert catalog.is_activated("my-tool")
|
||||
|
||||
def test_defaults_plus_community_skills(self, tmp_path):
|
||||
"""Default skills and community skills produce separate prompt sections."""
|
||||
# Create a community skill
|
||||
skill_dir = tmp_path / ".agents" / "skills" / "community-skill"
|
||||
skill_dir.mkdir(parents=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\nname: community-skill\ndescription: A community skill.\n---\n\nDo stuff.\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
# Discover community skills
|
||||
discovery = SkillDiscovery(
|
||||
DiscoveryConfig(
|
||||
project_root=tmp_path,
|
||||
skip_user_scope=True,
|
||||
skip_framework_scope=True,
|
||||
)
|
||||
)
|
||||
community_skills = discovery.discover()
|
||||
catalog = SkillCatalog(community_skills)
|
||||
catalog_prompt = catalog.to_prompt()
|
||||
|
||||
# Load default skills
|
||||
manager = DefaultSkillManager()
|
||||
manager.load()
|
||||
protocols_prompt = manager.build_protocols_prompt()
|
||||
|
||||
# Compose
|
||||
prompt = compose_system_prompt(
|
||||
identity_prompt="Agent identity.",
|
||||
focus_prompt=None,
|
||||
skills_catalog_prompt=catalog_prompt,
|
||||
protocols_prompt=protocols_prompt,
|
||||
)
|
||||
|
||||
# Both sections present
|
||||
assert "<available_skills>" in prompt
|
||||
assert "<name>community-skill</name>" in prompt
|
||||
assert "## Operational Protocols" in prompt
|
||||
|
||||
def test_config_disables_defaults_keeps_community(self, tmp_path):
|
||||
"""Disabling all defaults should still allow community skills."""
|
||||
skill_dir = tmp_path / ".agents" / "skills" / "still-here"
|
||||
skill_dir.mkdir(parents=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\nname: still-here\ndescription: Survives config.\n---\n\nBody.\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
# Community skills
|
||||
discovery = SkillDiscovery(
|
||||
DiscoveryConfig(
|
||||
project_root=tmp_path,
|
||||
skip_user_scope=True,
|
||||
skip_framework_scope=True,
|
||||
)
|
||||
)
|
||||
catalog = SkillCatalog(discovery.discover())
|
||||
|
||||
# Disabled defaults
|
||||
config = SkillsConfig(all_defaults_disabled=True)
|
||||
manager = DefaultSkillManager(config)
|
||||
manager.load()
|
||||
|
||||
catalog_prompt = catalog.to_prompt()
|
||||
protocols_prompt = manager.build_protocols_prompt()
|
||||
|
||||
assert "<name>still-here</name>" in catalog_prompt
|
||||
assert protocols_prompt == ""
|
||||
@@ -0,0 +1,183 @@
|
||||
"""Tests for SKILL.md parser."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.skills.parser import parse_skill_md
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_skill(tmp_path):
|
||||
"""Helper to create a SKILL.md file and return its path."""
|
||||
|
||||
def _create(content: str, dir_name: str = "my-skill") -> Path:
|
||||
skill_dir = tmp_path / dir_name
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
skill_md.write_text(content, encoding="utf-8")
|
||||
return skill_md
|
||||
|
||||
return _create
|
||||
|
||||
|
||||
class TestParseSkillMd:
|
||||
def test_happy_path(self, tmp_skill):
|
||||
content = """---
|
||||
name: my-skill
|
||||
description: A test skill for unit testing.
|
||||
license: MIT
|
||||
---
|
||||
|
||||
## Instructions
|
||||
|
||||
Do the thing.
|
||||
"""
|
||||
result = parse_skill_md(tmp_skill(content), source_scope="project")
|
||||
assert result is not None
|
||||
assert result.name == "my-skill"
|
||||
assert result.description == "A test skill for unit testing."
|
||||
assert result.license == "MIT"
|
||||
assert result.source_scope == "project"
|
||||
assert "Do the thing." in result.body
|
||||
|
||||
def test_missing_description_returns_none(self, tmp_skill):
|
||||
content = """---
|
||||
name: no-desc
|
||||
---
|
||||
|
||||
Body here.
|
||||
"""
|
||||
result = parse_skill_md(tmp_skill(content, "no-desc"))
|
||||
assert result is None
|
||||
|
||||
def test_missing_name_uses_directory(self, tmp_skill):
|
||||
content = """---
|
||||
description: Skill without a name field.
|
||||
---
|
||||
|
||||
Body.
|
||||
"""
|
||||
result = parse_skill_md(tmp_skill(content, "fallback-dir"))
|
||||
assert result is not None
|
||||
assert result.name == "fallback-dir"
|
||||
|
||||
def test_empty_file_returns_none(self, tmp_skill):
|
||||
result = parse_skill_md(tmp_skill("", "empty"))
|
||||
assert result is None
|
||||
|
||||
def test_no_frontmatter_delimiters_returns_none(self, tmp_skill):
|
||||
content = "Just plain text without YAML frontmatter."
|
||||
result = parse_skill_md(tmp_skill(content, "no-yaml"))
|
||||
assert result is None
|
||||
|
||||
def test_unparseable_yaml_returns_none(self, tmp_skill):
|
||||
content = """---
|
||||
name: [invalid yaml
|
||||
- broken: {{
|
||||
---
|
||||
|
||||
Body.
|
||||
"""
|
||||
result = parse_skill_md(tmp_skill(content, "bad-yaml"))
|
||||
assert result is None
|
||||
|
||||
def test_unquoted_colon_fixup(self, tmp_skill):
|
||||
content = """---
|
||||
name: colon-test
|
||||
description: Use for: research tasks
|
||||
---
|
||||
|
||||
Body.
|
||||
"""
|
||||
result = parse_skill_md(tmp_skill(content, "colon-test"))
|
||||
assert result is not None
|
||||
assert "research tasks" in result.description
|
||||
|
||||
def test_long_name_warns_but_loads(self, tmp_skill):
|
||||
long_name = "a" * 100
|
||||
content = f"""---
|
||||
name: {long_name}
|
||||
description: A skill with an excessively long name.
|
||||
---
|
||||
|
||||
Body.
|
||||
"""
|
||||
result = parse_skill_md(tmp_skill(content, "long-name"))
|
||||
assert result is not None
|
||||
assert result.name == long_name
|
||||
|
||||
def test_name_mismatch_warns_but_loads(self, tmp_skill):
|
||||
content = """---
|
||||
name: different-name
|
||||
description: Name doesn't match directory.
|
||||
---
|
||||
|
||||
Body.
|
||||
"""
|
||||
result = parse_skill_md(tmp_skill(content, "actual-dir"))
|
||||
assert result is not None
|
||||
assert result.name == "different-name"
|
||||
|
||||
def test_optional_fields(self, tmp_skill):
|
||||
content = """---
|
||||
name: full-skill
|
||||
description: Skill with all optional fields.
|
||||
license: Apache-2.0
|
||||
compatibility:
|
||||
- claude-code
|
||||
- cursor
|
||||
metadata:
|
||||
author: tester
|
||||
version: "1.0"
|
||||
allowed-tools:
|
||||
- web_search
|
||||
- read_file
|
||||
---
|
||||
|
||||
Instructions here.
|
||||
"""
|
||||
result = parse_skill_md(tmp_skill(content, "full-skill"))
|
||||
assert result is not None
|
||||
assert result.license == "Apache-2.0"
|
||||
assert result.compatibility == ["claude-code", "cursor"]
|
||||
assert result.metadata == {"author": "tester", "version": "1.0"}
|
||||
assert result.allowed_tools == ["web_search", "read_file"]
|
||||
|
||||
def test_body_extraction(self, tmp_skill):
|
||||
content = """---
|
||||
name: body-test
|
||||
description: Test body extraction.
|
||||
---
|
||||
|
||||
## Step 1
|
||||
|
||||
Do this first.
|
||||
|
||||
## Step 2
|
||||
|
||||
Then do this.
|
||||
"""
|
||||
result = parse_skill_md(tmp_skill(content, "body-test"))
|
||||
assert result is not None
|
||||
assert "## Step 1" in result.body
|
||||
assert "## Step 2" in result.body
|
||||
assert "Do this first." in result.body
|
||||
|
||||
def test_location_is_absolute(self, tmp_skill):
|
||||
content = """---
|
||||
name: abs-path
|
||||
description: Check absolute path.
|
||||
---
|
||||
|
||||
Body.
|
||||
"""
|
||||
path = tmp_skill(content, "abs-path")
|
||||
result = parse_skill_md(path)
|
||||
assert result is not None
|
||||
assert Path(result.location).is_absolute()
|
||||
assert Path(result.base_dir).is_absolute()
|
||||
|
||||
def test_nonexistent_file_returns_none(self, tmp_path):
|
||||
result = parse_skill_md(tmp_path / "nonexistent" / "SKILL.md")
|
||||
assert result is None
|
||||
@@ -299,6 +299,66 @@ class TestSubagentExecution:
|
||||
assert "metadata" in result_data
|
||||
assert result_data["metadata"]["agent_id"] == "researcher"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gcu_subagent_auto_populates_tools_from_catalog(self, runtime):
|
||||
"""GCU subagent with tools=[] should receive all catalog tools (auto-populate).
|
||||
|
||||
GCU nodes declare tools=[] because the runner expands them at setup time.
|
||||
But _execute_subagent filters by subagent_spec.tools, which is still empty.
|
||||
The fix: when subagent is GCU with no declared tools, include all catalog tools.
|
||||
"""
|
||||
gcu_spec = NodeSpec(
|
||||
id="browser_worker",
|
||||
name="Browser Worker",
|
||||
description="GCU browser subagent",
|
||||
node_type="gcu",
|
||||
output_keys=["result"],
|
||||
tools=[], # Empty — expects auto-population
|
||||
)
|
||||
|
||||
parent_spec = NodeSpec(
|
||||
id="parent",
|
||||
name="Parent",
|
||||
description="Orchestrator",
|
||||
node_type="event_loop",
|
||||
output_keys=["result"],
|
||||
sub_agents=["browser_worker"],
|
||||
)
|
||||
|
||||
spy_llm = MockStreamingLLM(
|
||||
[set_output_scenario("result", "scraped"), text_finish_scenario()]
|
||||
)
|
||||
|
||||
browser_tool = Tool(name="browser_snapshot", description="Snapshot")
|
||||
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=5))
|
||||
memory = SharedMemory()
|
||||
scoped = memory.with_permissions(read_keys=[], write_keys=["result"])
|
||||
|
||||
ctx = NodeContext(
|
||||
runtime=runtime,
|
||||
node_id="parent",
|
||||
node_spec=parent_spec,
|
||||
memory=scoped,
|
||||
input_data={},
|
||||
llm=spy_llm,
|
||||
available_tools=[],
|
||||
all_tools=[browser_tool],
|
||||
goal_context="",
|
||||
goal=None,
|
||||
node_registry={"browser_worker": gcu_spec},
|
||||
)
|
||||
|
||||
result = await node._execute_subagent(ctx, "browser_worker", "Scrape example.com")
|
||||
assert result.is_error is False
|
||||
|
||||
# Verify subagent LLM received browser tools from catalog
|
||||
assert spy_llm.stream_calls, "LLM should have been called"
|
||||
first_call_tools = spy_llm.stream_calls[0]["tools"]
|
||||
tool_names = {t.name for t in first_call_tools} if first_call_tools else set()
|
||||
assert "browser_snapshot" in tool_names
|
||||
assert "delegate_to_sub_agent" not in tool_names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for nested subagent prevention
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
# SDR Agent
|
||||
|
||||
An AI-powered sales development outreach automation template for [Hive](https://github.com/aden-hive/hive).
|
||||
|
||||
Score contacts by priority, filter suspicious profiles, generate personalized messages, and create Gmail drafts — all with human review before anything is sent.
|
||||
|
||||
## Overview
|
||||
|
||||
The SDR Agent automates the full outreach pipeline:
|
||||
|
||||
```
|
||||
Intake → Score Contacts → Filter Contacts → Personalize → Send Outreach → Report
|
||||
```
|
||||
|
||||
1. **Intake** — Accept a contact list and outreach goal; confirm strategy with user
|
||||
2. **Score Contacts** — Rank contacts 0–100 using priority factors (alumni, degree, domain, etc.)
|
||||
3. **Filter Contacts** — Detect and skip suspicious/fake profiles (risk score ≥ 7)
|
||||
4. **Personalize** — Generate an 80–120 word personalized message per contact
|
||||
5. **Send Outreach** — Create Gmail drafts for human review (never sends automatically)
|
||||
6. **Report** — Summarize campaign: contacts scored, filtered, drafted
|
||||
|
||||
## Quickstart
|
||||
|
||||
```bash
|
||||
cd examples/templates/sdr_agent
|
||||
|
||||
# Run interactively via TUI
|
||||
python -m sdr_agent tui
|
||||
|
||||
# Run via CLI with a contacts JSON string
|
||||
python -m sdr_agent run \
|
||||
--contacts '[{"name":"Jane Doe","company":"Acme","title":"Engineer","connection_degree":"2nd","is_alumni":true}]' \
|
||||
--goal "coffee chat" \
|
||||
--background "Learning Technologist at UWO" \
|
||||
--max-contacts 20
|
||||
|
||||
# Validate agent structure
|
||||
python -m sdr_agent validate
|
||||
```
|
||||
|
||||
## Contact Schema
|
||||
|
||||
Each contact in your list supports the following fields:
|
||||
|
||||
| Field | Type | Required | Description |
|
||||
|-------|------|----------|-------------|
|
||||
| `name` | string | ✅ | Contact's full name |
|
||||
| `email` | string | ❌ | Email address (draft placeholder if missing) |
|
||||
| `company` | string | ✅ | Current company |
|
||||
| `title` | string | ✅ | Job title |
|
||||
| `linkedin_url` | string | ❌ | LinkedIn profile URL |
|
||||
| `connection_degree` | string | ❌ | `"1st"`, `"2nd"`, or `"3rd"` |
|
||||
| `is_alumni` | boolean | ❌ | Shares school with user |
|
||||
| `school_name` | string | ❌ | School name for alumni messaging |
|
||||
| `connections_count` | integer | ❌ | Number of LinkedIn connections |
|
||||
| `mutual_connections` | integer | ❌ | Count of mutual connections |
|
||||
| `has_photo` | boolean | ❌ | Has a profile photo |
|
||||
|
||||
## Scoring Model
|
||||
|
||||
The `score-contacts` node ranks each contact 0–100:
|
||||
|
||||
| Factor | Points |
|
||||
|--------|--------|
|
||||
| Alumni | +30 |
|
||||
| 1st degree | +25 |
|
||||
| 2nd degree | +20 |
|
||||
| 3rd degree | +10 |
|
||||
| Domain verified | +10 |
|
||||
| Mutual connections (×1, max 10) | +10 |
|
||||
| Active job posting | +10 |
|
||||
| Has profile photo | +5 |
|
||||
| 500+ connections | +5 |
|
||||
|
||||
## Scam Detection
|
||||
|
||||
The `filter-contacts` node calculates a risk score and excludes contacts with risk ≥ 7:
|
||||
|
||||
| Red Flag | Risk |
|
||||
|----------|------|
|
||||
| Fewer than 50 connections | +3 |
|
||||
| No profile photo | +2 |
|
||||
| Fewer than 2 work positions | +2 |
|
||||
| Generic title + few connections | +2 |
|
||||
| Unverifiable company | +2 |
|
||||
| AI-generated-looking profile | +2 |
|
||||
| 5000+ connections, 0 mutual | +1 |
|
||||
|
||||
## Pipeline Output Files
|
||||
|
||||
Each run writes to `~/.hive/agents/sdr_agent/data/`:
|
||||
|
||||
| File | Contents |
|
||||
|------|----------|
|
||||
| `contacts.jsonl` | Raw contact list |
|
||||
| `scored_contacts.jsonl` | Contacts with `priority_score` |
|
||||
| `safe_contacts.jsonl` | Contacts passing scam filter |
|
||||
| `personalized_contacts.jsonl` | Contacts with `outreach_message` |
|
||||
| `drafts.jsonl` | Draft creation records |
|
||||
|
||||
## Safety Constraints
|
||||
|
||||
- **Never sends emails** — only `gmail_create_draft` is called; human must review and send
|
||||
- **Batch limit** — processes at most `max_contacts` per run (default: 20)
|
||||
- **Skip suspicious** — contacts with `risk_score ≥ 7` are always excluded
|
||||
|
||||
## Tools Required
|
||||
|
||||
- `gmail_create_draft` — create Gmail draft for each contact
|
||||
- `load_data` — read JSONL data files
|
||||
- `append_data` — write to JSONL data files
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌──────────────────────────────────────────────────────────────┐
|
||||
│ SDR Agent │
|
||||
│ │
|
||||
│ ┌────────┐ ┌───────────────┐ ┌────────────────┐ │
|
||||
│ │ Intake │──▶│ Score Contacts│──▶│ Filter Contacts│ │
|
||||
│ └────────┘ └───────────────┘ └────────────────┘ │
|
||||
│ ▲ │ │
|
||||
│ │ ▼ │
|
||||
│ ┌────────┐ ┌───────────────┐ ┌─────────────┐ │
|
||||
│ │ Report │◀──│ Send Outreach │◀──│ Personalize │ │
|
||||
│ └────────┘ └───────────────┘ └─────────────┘ │
|
||||
│ │
|
||||
│ ● client_facing nodes: intake, report │
|
||||
│ ● automated nodes: score-contacts, filter-contacts, │
|
||||
│ personalize, send-outreach │
|
||||
└──────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Inspiration
|
||||
|
||||
This template is inspired by real-world SDR automation patterns, including contact ranking, scam detection, and two-step personalization (hook extraction → message generation) — demonstrating how job-search and sales outreach workflows can be modeled as AI agent pipelines in Hive.
|
||||
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
SDR Agent — Automated sales development outreach pipeline.
|
||||
|
||||
Score contacts by priority, filter suspicious profiles, generate personalized
|
||||
outreach messages, and create Gmail drafts for human review before sending.
|
||||
"""
|
||||
|
||||
from .agent import (
|
||||
SDRAgent,
|
||||
default_agent,
|
||||
goal,
|
||||
nodes,
|
||||
edges,
|
||||
loop_config,
|
||||
async_entry_points,
|
||||
entry_node,
|
||||
entry_points,
|
||||
pause_nodes,
|
||||
terminal_nodes,
|
||||
conversation_mode,
|
||||
identity_prompt,
|
||||
)
|
||||
from .config import RuntimeConfig, AgentMetadata, default_config, metadata
|
||||
|
||||
__version__ = "1.0.0"
|
||||
|
||||
__all__ = [
|
||||
"SDRAgent",
|
||||
"default_agent",
|
||||
"goal",
|
||||
"nodes",
|
||||
"edges",
|
||||
"loop_config",
|
||||
"async_entry_points",
|
||||
"entry_node",
|
||||
"entry_points",
|
||||
"pause_nodes",
|
||||
"terminal_nodes",
|
||||
"conversation_mode",
|
||||
"identity_prompt",
|
||||
"RuntimeConfig",
|
||||
"AgentMetadata",
|
||||
"default_config",
|
||||
"metadata",
|
||||
]
|
||||
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
CLI entry point for SDR Agent.
|
||||
|
||||
Automates sales development outreach: score contacts, filter suspicious
|
||||
profiles, generate personalized messages, and create Gmail drafts.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import click
|
||||
|
||||
from .agent import default_agent, SDRAgent
|
||||
|
||||
|
||||
def setup_logging(verbose=False, debug=False):
|
||||
"""Configure logging for execution visibility."""
|
||||
if debug:
|
||||
level, fmt = logging.DEBUG, "%(asctime)s %(name)s: %(message)s"
|
||||
elif verbose:
|
||||
level, fmt = logging.INFO, "%(message)s"
|
||||
else:
|
||||
level, fmt = logging.WARNING, "%(levelname)s: %(message)s"
|
||||
logging.basicConfig(level=level, format=fmt, stream=sys.stderr)
|
||||
logging.getLogger("framework").setLevel(level)
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(version="1.0.0")
|
||||
def cli():
|
||||
"""SDR Agent - Automated outreach with contact scoring and personalization."""
|
||||
pass
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option(
|
||||
"--contacts",
|
||||
"-c",
|
||||
type=str,
|
||||
required=True,
|
||||
help="JSON string or file path of contacts list",
|
||||
)
|
||||
@click.option(
|
||||
"--goal",
|
||||
"-g",
|
||||
type=str,
|
||||
default="coffee chat",
|
||||
help="Outreach goal (e.g. 'coffee chat', 'sales pitch')",
|
||||
)
|
||||
@click.option(
|
||||
"--background",
|
||||
"-b",
|
||||
type=str,
|
||||
default="",
|
||||
help="Your background/role for personalization",
|
||||
)
|
||||
@click.option(
|
||||
"--max-contacts",
|
||||
"-m",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Max contacts to process per batch (default: 20)",
|
||||
)
|
||||
@click.option(
|
||||
"--mock", is_flag=True, help="Run in mock mode without LLM or Gmail calls"
|
||||
)
|
||||
@click.option("--quiet", "-q", is_flag=True, help="Only output result JSON")
|
||||
@click.option("--verbose", "-v", is_flag=True, help="Show execution details")
|
||||
@click.option("--debug", is_flag=True, help="Show debug logging")
|
||||
def run(contacts, goal, background, max_contacts, mock, quiet, verbose, debug):
|
||||
"""Execute an SDR outreach campaign for the given contacts."""
|
||||
if not quiet:
|
||||
setup_logging(verbose=verbose, debug=debug)
|
||||
|
||||
context = {
|
||||
"contacts": contacts,
|
||||
"outreach_goal": goal,
|
||||
"user_background": background,
|
||||
"max_contacts": str(max_contacts),
|
||||
}
|
||||
|
||||
result = asyncio.run(default_agent.run(context, mock_mode=mock))
|
||||
|
||||
output_data = {
|
||||
"success": result.success,
|
||||
"steps_executed": result.steps_executed,
|
||||
"output": result.output,
|
||||
}
|
||||
if result.error:
|
||||
output_data["error"] = result.error
|
||||
|
||||
click.echo(json.dumps(output_data, indent=2, default=str))
|
||||
sys.exit(0 if result.success else 1)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--mock", is_flag=True, help="Run in mock mode")
|
||||
@click.option("--verbose", "-v", is_flag=True, help="Show execution details")
|
||||
@click.option("--debug", is_flag=True, help="Show debug logging")
|
||||
def tui(mock, verbose, debug):
|
||||
"""Launch the TUI dashboard for interactive SDR outreach."""
|
||||
setup_logging(verbose=verbose, debug=debug)
|
||||
|
||||
try:
|
||||
from framework.tui.app import AdenTUI
|
||||
except ImportError:
|
||||
click.echo(
|
||||
"TUI requires the 'textual' package. Install with: pip install textual"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
async def run_with_tui():
|
||||
agent = SDRAgent()
|
||||
await agent.start(mock_mode=mock)
|
||||
|
||||
try:
|
||||
app = AdenTUI(agent._agent_runtime)
|
||||
await app.run_async()
|
||||
finally:
|
||||
await agent.stop()
|
||||
|
||||
asyncio.run(run_with_tui())
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--json", "output_json", is_flag=True)
|
||||
def info(output_json):
|
||||
"""Show agent information."""
|
||||
info_data = default_agent.info()
|
||||
if output_json:
|
||||
click.echo(json.dumps(info_data, indent=2))
|
||||
else:
|
||||
click.echo(f"Agent: {info_data['name']}")
|
||||
click.echo(f"Version: {info_data['version']}")
|
||||
click.echo(f"Description: {info_data['description']}")
|
||||
click.echo(f"\nNodes: {', '.join(info_data['nodes'])}")
|
||||
click.echo(f"Client-facing: {', '.join(info_data['client_facing_nodes'])}")
|
||||
click.echo(f"Entry: {info_data['entry_node']}")
|
||||
click.echo(f"Terminal: {', '.join(info_data['terminal_nodes'])}")
|
||||
|
||||
|
||||
@cli.command()
|
||||
def validate():
|
||||
"""Validate agent structure."""
|
||||
validation = default_agent.validate()
|
||||
if validation["valid"]:
|
||||
click.echo("Agent is valid")
|
||||
if validation["warnings"]:
|
||||
for warning in validation["warnings"]:
|
||||
click.echo(f" WARNING: {warning}")
|
||||
else:
|
||||
click.echo("Agent has errors:")
|
||||
for error in validation["errors"]:
|
||||
click.echo(f" ERROR: {error}")
|
||||
sys.exit(0 if validation["valid"] else 1)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--verbose", "-v", is_flag=True)
|
||||
def shell(verbose):
|
||||
"""Interactive SDR outreach session (CLI, no TUI)."""
|
||||
asyncio.run(_interactive_shell(verbose))
|
||||
|
||||
|
||||
async def _interactive_shell(verbose=False):
|
||||
"""Async interactive shell."""
|
||||
setup_logging(verbose=verbose)
|
||||
|
||||
click.echo("=== SDR Agent ===")
|
||||
click.echo("Automated contact scoring, filtering, and outreach personalization\n")
|
||||
|
||||
agent = SDRAgent()
|
||||
await agent.start()
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
goal = await asyncio.get_event_loop().run_in_executor(
|
||||
None, input, "Outreach goal (e.g. 'coffee chat')> "
|
||||
)
|
||||
if goal.lower() in ["quit", "exit", "q"]:
|
||||
click.echo("Goodbye!")
|
||||
break
|
||||
|
||||
contacts = await asyncio.get_event_loop().run_in_executor(
|
||||
None, input, "Contacts (JSON)> "
|
||||
)
|
||||
background = await asyncio.get_event_loop().run_in_executor(
|
||||
None, input, "Your background/role> "
|
||||
)
|
||||
|
||||
if not contacts.strip():
|
||||
continue
|
||||
|
||||
click.echo("\nRunning SDR campaign...\n")
|
||||
|
||||
result = await agent.trigger_and_wait(
|
||||
"start",
|
||||
{
|
||||
"contacts": contacts,
|
||||
"outreach_goal": goal,
|
||||
"user_background": background,
|
||||
"max_contacts": "20",
|
||||
},
|
||||
)
|
||||
|
||||
if result is None:
|
||||
click.echo("\n[Execution timed out]\n")
|
||||
continue
|
||||
|
||||
if result.success:
|
||||
output = result.output
|
||||
if "summary_report" in output:
|
||||
click.echo("\n--- Campaign Report ---\n")
|
||||
click.echo(output["summary_report"])
|
||||
click.echo("\n")
|
||||
else:
|
||||
click.echo(f"\nCampaign failed: {result.error}\n")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
click.echo("\nGoodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {e}", err=True)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
await agent.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
@@ -0,0 +1,378 @@
|
||||
{
|
||||
"agent": {
|
||||
"id": "sdr_agent",
|
||||
"name": "SDR Agent",
|
||||
"version": "1.0.0",
|
||||
"description": "Automate sales development outreach using AI-powered contact scoring, scam detection, and personalized message generation. Score contacts by priority, filter suspicious profiles, generate personalized outreach messages, and create Gmail drafts for review — all without sending emails automatically."
|
||||
},
|
||||
"graph": {
|
||||
"id": "sdr-agent-graph",
|
||||
"goal_id": "sdr-agent",
|
||||
"version": "1.0.0",
|
||||
"entry_node": "intake",
|
||||
"entry_points": {
|
||||
"start": "intake"
|
||||
},
|
||||
"pause_nodes": [],
|
||||
"terminal_nodes": ["complete"],
|
||||
"conversation_mode": "continuous",
|
||||
"identity_prompt": "You are an SDR (Sales Development Representative) assistant. You help users automate their outreach by scoring contacts, filtering suspicious profiles, generating personalized messages, and creating Gmail drafts — all with human review before anything is sent.",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "intake",
|
||||
"name": "Intake",
|
||||
"description": "Receive the contact list and outreach goal from the user. Confirm the strategy and batch size before proceeding.",
|
||||
"node_type": "event_loop",
|
||||
"input_keys": [
|
||||
"contacts",
|
||||
"outreach_goal",
|
||||
"max_contacts",
|
||||
"user_background"
|
||||
],
|
||||
"output_keys": [
|
||||
"contacts",
|
||||
"outreach_goal",
|
||||
"max_contacts",
|
||||
"user_background"
|
||||
],
|
||||
"nullable_output_keys": [],
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"system_prompt": "You are an SDR (Sales Development Representative) assistant helping automate outreach.\n\n**STEP 1 — Respond to the user (text only, NO tool calls):**\n\nRead the user's input from context. Confirm your understanding of:\n- The contact list they provided (or ask them to provide one)\n- Their outreach goal (e.g. \"coffee chat\", \"sales pitch\", \"networking\")\n- Their background/role (used to personalize messages)\n- The batch size (max_contacts). Default to 20 if not specified.\n\nPresent a summary like:\n\"Here's what I'll do:\n1. Score and rank your contacts by priority (alumni status, connection degree, etc.)\n2. Filter out suspicious or low-quality profiles (risk score ≥ 7)\n3. Generate a personalized outreach message for each contact\n4. Create Gmail draft emails for your review — I never send automatically\n\nReady to proceed with [N] contacts for [goal]?\"\n\n**STEP 2 — After the user confirms, call set_output:**\n\n- set_output(\"contacts\", <the contact list as a JSON string>)\n- set_output(\"outreach_goal\", <the confirmed goal, e.g. \"coffee chat\">)\n- set_output(\"max_contacts\", <the confirmed batch size as a string, e.g. \"20\">)\n- set_output(\"user_background\", <user's background/role, e.g. \"Learning Technologist at UWO\">)",
|
||||
"tools": [],
|
||||
"model": null,
|
||||
"function": null,
|
||||
"routes": {},
|
||||
"max_retries": 3,
|
||||
"retry_on": [],
|
||||
"max_node_visits": 0,
|
||||
"output_model": null,
|
||||
"max_validation_retries": 2,
|
||||
"client_facing": true,
|
||||
"success_criteria": null
|
||||
},
|
||||
{
|
||||
"id": "score-contacts",
|
||||
"name": "Score Contacts",
|
||||
"description": "Score and rank each contact from 0 to 100 based on priority factors: alumni status, connection degree, domain verification, mutual connections, and active job postings.",
|
||||
"node_type": "event_loop",
|
||||
"input_keys": [
|
||||
"contacts",
|
||||
"outreach_goal"
|
||||
],
|
||||
"output_keys": [
|
||||
"scored_contacts"
|
||||
],
|
||||
"nullable_output_keys": [],
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"system_prompt": "You are a contact prioritization engine. Score each contact from 0 to 100.\n\n**SCORING RULES (additive):**\n- Alumni of the user's school: +30 points\n- 1st degree connection: +25 points\n- 2nd degree connection: +20 points\n- 3rd degree connection: +10 points\n- Domain verified (company email matches LinkedIn company): +10 points\n- Has mutual connections (1 point each, max 10): up to +10 points\n- Active job posting at their company: +10 points\n- Has a profile photo: +5 points\n- Over 500 connections: +5 points\n\nCap the final score at 100.\n\n**STEP 1 — Load the contacts:**\nCall load_data(filename=\"contacts.jsonl\") to read the contact list.\nIf \"contacts\" in context is a JSON string (not a filename), write it first:\n- For each contact in the list, call append_data(filename=\"contacts.jsonl\", data=<JSON contact object>)\nThen read it back.\n\n**STEP 2 — Score each contact:**\nFor each contact, calculate the priority score using the rules above.\nAdd a \"priority_score\" field to each contact object.\n\n**STEP 3 — Write scored contacts and set output:**\n- Call append_data(filename=\"scored_contacts.jsonl\", data=<JSON contact with priority_score>) for each contact.\n- Sort contacts by priority_score (highest first) in your final output.\n- Call set_output(\"scored_contacts\", \"scored_contacts.jsonl\")",
|
||||
"tools": [
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"model": null,
|
||||
"function": null,
|
||||
"routes": {},
|
||||
"max_retries": 3,
|
||||
"retry_on": [],
|
||||
"max_node_visits": 0,
|
||||
"output_model": null,
|
||||
"max_validation_retries": 2,
|
||||
"client_facing": false,
|
||||
"success_criteria": null
|
||||
},
|
||||
{
|
||||
"id": "filter-contacts",
|
||||
"name": "Filter Contacts",
|
||||
"description": "Analyze each contact for authenticity and filter out suspicious profiles. Any contact with a risk score of 7 or higher is skipped.",
|
||||
"node_type": "event_loop",
|
||||
"input_keys": [
|
||||
"scored_contacts"
|
||||
],
|
||||
"output_keys": [
|
||||
"safe_contacts",
|
||||
"filtered_count"
|
||||
],
|
||||
"nullable_output_keys": [],
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"system_prompt": "You are a profile authenticity analyzer. Your job is to detect suspicious or fake LinkedIn profiles.\n\n**RISK SCORING RULES (additive):**\n- Fewer than 50 connections: +3 points\n- No profile photo: +2 points\n- Fewer than 2 positions in work history: +2 points\n- Generic title (e.g. \"entrepreneur\", \"CEO\", \"consultant\") AND fewer than 100 connections: +2 points\n- Company name appears generic or unverifiable: +2 points\n- Profile text seems auto-generated or overly promotional: +2 points\n- Connection count over 5000 with no mutual connections: +1 point\n\n**DECISION RULE:**\n- risk_score < 4: SAFE — include in outreach\n- risk_score 4–6: CAUTION — include but flag\n- risk_score ≥ 7: SKIP — exclude from outreach\n\n**STEP 1 — Load scored contacts:**\nCall load_data(filename=<the \"scored_contacts\" value from context>).\nProcess contacts chunk by chunk if has_more=true.\n\n**STEP 2 — Analyze each contact:**\nFor each contact, calculate a risk_score using the rules above.\nDetermine: is_safe (risk_score < 7), recommendation (safe/caution/skip), flags (list of triggered rules).\n\n**STEP 3 — Write safe contacts and set output:**\n- For each contact where risk_score < 7: call append_data(filename=\"safe_contacts.jsonl\", data=<contact JSON with risk_score and flags added>)\n- Track how many contacts were filtered (risk_score ≥ 7)\n- Call set_output(\"safe_contacts\", \"safe_contacts.jsonl\")\n- Call set_output(\"filtered_count\", <number of skipped contacts as string>)",
|
||||
"tools": [
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"model": null,
|
||||
"function": null,
|
||||
"routes": {},
|
||||
"max_retries": 3,
|
||||
"retry_on": [],
|
||||
"max_node_visits": 0,
|
||||
"output_model": null,
|
||||
"max_validation_retries": 2,
|
||||
"client_facing": false,
|
||||
"success_criteria": null
|
||||
},
|
||||
{
|
||||
"id": "personalize",
|
||||
"name": "Personalize",
|
||||
"description": "Generate a personalized outreach message for each contact based on their profile, shared background, and the user's outreach goal.",
|
||||
"node_type": "event_loop",
|
||||
"input_keys": [
|
||||
"safe_contacts",
|
||||
"outreach_goal",
|
||||
"user_background"
|
||||
],
|
||||
"output_keys": [
|
||||
"personalized_contacts"
|
||||
],
|
||||
"nullable_output_keys": [],
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"system_prompt": "You are a professional outreach message writer. Generate personalized messages for each contact.\n\n**TWO-STEP PERSONALIZATION:**\n\nFor each contact, follow this two-step approach:\n\nSTEP A — Extract hooks (analyze the profile):\nLook for 2-3 specific talking points from the contact's profile:\n- Shared alumni connection\n- Specific role, company, or career transition worth mentioning\n- Any mutual interests aligned with the user's background\n\nSTEP B — Generate the message:\nWrite a warm, professional outreach message using the hooks.\n\n**MESSAGE REQUIREMENTS:**\n- 80-120 words (LinkedIn message length)\n- Start with a specific observation (\"I noticed you...\" or \"Fellow [school] alum here...\")\n- Mention the shared connection or interest naturally\n- State the outreach goal clearly but softly (e.g. \"Open to a brief 15-min chat?\")\n- Professional but warm tone — NOT templated or AI-sounding\n- Do NOT mention job postings directly unless the goal is job-related\n- Do NOT use generic openers like \"I hope this finds you well\"\n- End with a low-pressure ask\n\n**STEP 1 — Load safe contacts:**\nCall load_data(filename=<the \"safe_contacts\" value from context>).\n\n**STEP 2 — Generate message for each contact:**\nFor each contact: generate the personalized message using the two-step approach above.\nAdd \"outreach_message\" field to each contact object.\n\n**STEP 3 — Write output and set:**\n- Call append_data(filename=\"personalized_contacts.jsonl\", data=<contact JSON with outreach_message>) for each.\n- Call set_output(\"personalized_contacts\", \"personalized_contacts.jsonl\")",
|
||||
"tools": [
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"model": null,
|
||||
"function": null,
|
||||
"routes": {},
|
||||
"max_retries": 3,
|
||||
"retry_on": [],
|
||||
"max_node_visits": 0,
|
||||
"output_model": null,
|
||||
"max_validation_retries": 2,
|
||||
"client_facing": false,
|
||||
"success_criteria": null
|
||||
},
|
||||
{
|
||||
"id": "send-outreach",
|
||||
"name": "Send Outreach",
|
||||
"description": "Create Gmail draft emails for each contact using their personalized message. Drafts are created for human review — emails are never sent automatically.",
|
||||
"node_type": "event_loop",
|
||||
"input_keys": [
|
||||
"personalized_contacts",
|
||||
"outreach_goal"
|
||||
],
|
||||
"output_keys": [
|
||||
"drafts_created"
|
||||
],
|
||||
"nullable_output_keys": [],
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"system_prompt": "You are an outreach execution assistant. Create Gmail draft emails for each contact.\n\n**CRITICAL RULE: NEVER send emails automatically. Only create drafts.**\n\n**STEP 1 — Load personalized contacts:**\nCall load_data(filename=<the \"personalized_contacts\" value from context>).\nProcess chunk by chunk if has_more=true.\n\n**STEP 2 — Create Gmail draft for each contact:**\nFor each contact with an \"outreach_message\":\n- subject: \"Coffee Chat Request\" (or appropriate subject based on outreach_goal)\n- to: contact's email address (use LinkedIn profile URL if email not available — note this in body)\n- body: the \"outreach_message\" from the contact object\n\nCall gmail_create_draft(\n to=<contact email or linkedin_url as placeholder>,\n subject=<appropriate subject line>,\n body=<outreach_message>\n)\n\nRecord each draft: call append_data(\n filename=\"drafts.jsonl\",\n data=<JSON: {contact_name, contact_email, subject, status: \"draft_created\"}>\n)\n\n**STEP 3 — Set output:**\n- Call set_output(\"drafts_created\", \"drafts.jsonl\")\n\n**IMPORTANT:** If a contact has no email address, create the draft with their LinkedIn URL as a placeholder and add a note in the body: \"Note: Please find the recipient's email before sending.\"",
|
||||
"tools": [
|
||||
"gmail_create_draft",
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"model": null,
|
||||
"function": null,
|
||||
"routes": {},
|
||||
"max_retries": 3,
|
||||
"retry_on": [],
|
||||
"max_node_visits": 0,
|
||||
"output_model": null,
|
||||
"max_validation_retries": 2,
|
||||
"client_facing": false,
|
||||
"success_criteria": null
|
||||
},
|
||||
{
|
||||
"id": "report",
|
||||
"name": "Report",
|
||||
"description": "Generate a summary report of the outreach campaign: contacts scored, filtered, messaged, and drafts created. Present to user for review.",
|
||||
"node_type": "event_loop",
|
||||
"input_keys": [
|
||||
"drafts_created",
|
||||
"filtered_count",
|
||||
"outreach_goal"
|
||||
],
|
||||
"output_keys": [
|
||||
"summary_report"
|
||||
],
|
||||
"nullable_output_keys": [],
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"system_prompt": "You are an SDR assistant. Generate a clear campaign summary report and present it to the user.\n\n**STEP 1 — Load draft records:**\nCall load_data(filename=<the \"drafts_created\" value from context>) to read the draft records.\nIf has_more=true, load additional chunks until all records are loaded.\n\n**STEP 2 — Present the report (text only, NO tool calls):**\n\nPresent a clean summary:\n\n📊 **SDR Campaign Summary — [outreach_goal]**\n\n**Overview:**\n- Total contacts processed: [N]\n- Contacts filtered (suspicious profiles): [filtered_count]\n- Safe contacts messaged: [N - filtered_count]\n- Gmail drafts created: [N]\n\n**Drafts Created:**\nList each draft: Contact Name | Company | Subject\n\n**Next Steps:**\n\"Your Gmail drafts are ready for review. Please:\n1. Open Gmail and review each draft\n2. Personalize further if needed\n3. Send when ready\n\nCampaign complete!\"\n\n**STEP 3 — After the user responds, call set_output:**\n- set_output(\"summary_report\", <the formatted report text>)",
|
||||
"tools": [
|
||||
"load_data"
|
||||
],
|
||||
"model": null,
|
||||
"function": null,
|
||||
"routes": {},
|
||||
"max_retries": 3,
|
||||
"retry_on": [],
|
||||
"max_node_visits": 0,
|
||||
"output_model": null,
|
||||
"max_validation_retries": 2,
|
||||
"client_facing": true,
|
||||
"success_criteria": null
|
||||
},
|
||||
{
|
||||
"id": "complete",
|
||||
"name": "Complete",
|
||||
"description": "Terminal node - campaign complete.",
|
||||
"node_type": "event_loop",
|
||||
"input_keys": [
|
||||
"summary_report"
|
||||
],
|
||||
"output_keys": [
|
||||
"final_report"
|
||||
],
|
||||
"nullable_output_keys": [],
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"system_prompt": "Campaign is complete. Set the final output.\n\nCall set_output(\"final_report\", <summary_report value from context>)",
|
||||
"tools": [],
|
||||
"model": null,
|
||||
"function": null,
|
||||
"routes": {},
|
||||
"max_retries": 3,
|
||||
"retry_on": [],
|
||||
"max_node_visits": 1,
|
||||
"output_model": null,
|
||||
"max_validation_retries": 2,
|
||||
"client_facing": false,
|
||||
"success_criteria": null
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "intake-to-score",
|
||||
"source": "intake",
|
||||
"target": "score-contacts",
|
||||
"condition": "on_success",
|
||||
"condition_expr": null,
|
||||
"priority": 1,
|
||||
"input_mapping": {}
|
||||
},
|
||||
{
|
||||
"id": "score-to-filter",
|
||||
"source": "score-contacts",
|
||||
"target": "filter-contacts",
|
||||
"condition": "on_success",
|
||||
"condition_expr": null,
|
||||
"priority": 1,
|
||||
"input_mapping": {}
|
||||
},
|
||||
{
|
||||
"id": "filter-to-personalize",
|
||||
"source": "filter-contacts",
|
||||
"target": "personalize",
|
||||
"condition": "on_success",
|
||||
"condition_expr": null,
|
||||
"priority": 1,
|
||||
"input_mapping": {}
|
||||
},
|
||||
{
|
||||
"id": "personalize-to-send",
|
||||
"source": "personalize",
|
||||
"target": "send-outreach",
|
||||
"condition": "on_success",
|
||||
"condition_expr": null,
|
||||
"priority": 1,
|
||||
"input_mapping": {}
|
||||
},
|
||||
{
|
||||
"id": "send-to-report",
|
||||
"source": "send-outreach",
|
||||
"target": "report",
|
||||
"condition": "on_success",
|
||||
"condition_expr": null,
|
||||
"priority": 1,
|
||||
"input_mapping": {}
|
||||
},
|
||||
{
|
||||
"id": "report-to-complete",
|
||||
"source": "report",
|
||||
"target": "complete",
|
||||
"condition": "on_success",
|
||||
"condition_expr": null,
|
||||
"priority": 1,
|
||||
"input_mapping": {}
|
||||
}
|
||||
],
|
||||
"max_steps": 100,
|
||||
"max_retries_per_node": 3,
|
||||
"description": "Automated SDR outreach pipeline: score contacts by priority, filter suspicious profiles, generate personalized messages, and create Gmail drafts for human review."
|
||||
},
|
||||
"goal": {
|
||||
"id": "sdr-agent",
|
||||
"name": "SDR Agent",
|
||||
"description": "Automate sales development outreach: score contacts by priority, filter suspicious profiles, generate personalized messages, and create Gmail drafts for human review.",
|
||||
"status": "draft",
|
||||
"success_criteria": [
|
||||
{
|
||||
"id": "contact-scoring-accuracy",
|
||||
"description": "Contacts are correctly scored and ranked by priority factors (alumni status, connection degree, domain verification)",
|
||||
"metric": "scoring_accuracy",
|
||||
"target": ">=90%",
|
||||
"weight": 0.30,
|
||||
"met": false
|
||||
},
|
||||
{
|
||||
"id": "scam-filter-effectiveness",
|
||||
"description": "Suspicious profiles (risk_score >= 7) are correctly identified and excluded from outreach",
|
||||
"metric": "filter_precision",
|
||||
"target": ">=95%",
|
||||
"weight": 0.25,
|
||||
"met": false
|
||||
},
|
||||
{
|
||||
"id": "message-personalization",
|
||||
"description": "Generated messages reference specific profile details (alumni connection, role, company) and match the outreach goal",
|
||||
"metric": "personalization_score",
|
||||
"target": ">=80%",
|
||||
"weight": 0.30,
|
||||
"met": false
|
||||
},
|
||||
{
|
||||
"id": "draft-creation",
|
||||
"description": "Gmail drafts are created for all safe contacts without errors",
|
||||
"metric": "draft_success_rate",
|
||||
"target": "100%",
|
||||
"weight": 0.15,
|
||||
"met": false
|
||||
}
|
||||
],
|
||||
"constraints": [
|
||||
{
|
||||
"id": "draft-not-send",
|
||||
"description": "Agent creates Gmail drafts but NEVER sends emails automatically",
|
||||
"constraint_type": "hard",
|
||||
"category": "safety",
|
||||
"check": ""
|
||||
},
|
||||
{
|
||||
"id": "respect-batch-limit",
|
||||
"description": "Must not process more contacts than the configured max_contacts parameter",
|
||||
"constraint_type": "hard",
|
||||
"category": "operational",
|
||||
"check": ""
|
||||
},
|
||||
{
|
||||
"id": "skip-suspicious",
|
||||
"description": "Contacts with risk_score >= 7 must be excluded from outreach",
|
||||
"constraint_type": "hard",
|
||||
"category": "safety",
|
||||
"check": ""
|
||||
}
|
||||
],
|
||||
"context": {},
|
||||
"required_capabilities": [],
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"version": "1.0.0",
|
||||
"parent_version": null,
|
||||
"evolution_reason": null
|
||||
},
|
||||
"required_tools": [
|
||||
"gmail_create_draft",
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"metadata": {
|
||||
"node_count": 7,
|
||||
"edge_count": 6
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,375 @@
|
||||
"""Agent graph construction for SDR Agent."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from framework.graph import EdgeSpec, EdgeCondition, Goal, SuccessCriterion, Constraint
|
||||
from framework.graph.checkpoint_config import CheckpointConfig
|
||||
from framework.graph.edge import AsyncEntryPointSpec, GraphSpec
|
||||
from framework.graph.executor import ExecutionResult
|
||||
from framework.llm import LiteLLMProvider
|
||||
from framework.runner.tool_registry import ToolRegistry
|
||||
from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime
|
||||
from framework.runtime.execution_stream import EntryPointSpec
|
||||
|
||||
from .config import default_config, metadata
|
||||
from .nodes import (
|
||||
intake_node,
|
||||
score_contacts_node,
|
||||
filter_contacts_node,
|
||||
personalize_node,
|
||||
send_outreach_node,
|
||||
report_node,
|
||||
)
|
||||
|
||||
# Goal definition
|
||||
goal = Goal(
|
||||
id="sdr-agent",
|
||||
name="SDR Agent",
|
||||
description=(
|
||||
"Automate sales development outreach: score contacts by priority, "
|
||||
"filter suspicious profiles, generate personalized messages, "
|
||||
"and create Gmail drafts for human review."
|
||||
),
|
||||
success_criteria=[
|
||||
SuccessCriterion(
|
||||
id="contact-scoring-accuracy",
|
||||
description=(
|
||||
"Contacts are correctly scored and ranked by priority factors "
|
||||
"(alumni status, connection degree, domain verification)"
|
||||
),
|
||||
metric="scoring_accuracy",
|
||||
target=">=90%",
|
||||
weight=0.30,
|
||||
),
|
||||
SuccessCriterion(
|
||||
id="scam-filter-effectiveness",
|
||||
description=(
|
||||
"Suspicious profiles (risk_score >= 7) are correctly identified "
|
||||
"and excluded from outreach"
|
||||
),
|
||||
metric="filter_precision",
|
||||
target=">=95%",
|
||||
weight=0.25,
|
||||
),
|
||||
SuccessCriterion(
|
||||
id="message-personalization",
|
||||
description=(
|
||||
"Generated messages reference specific profile details "
|
||||
"(alumni connection, role, company) and match the outreach goal"
|
||||
),
|
||||
metric="personalization_score",
|
||||
target=">=80%",
|
||||
weight=0.30,
|
||||
),
|
||||
SuccessCriterion(
|
||||
id="draft-creation",
|
||||
description="Gmail drafts are created for all safe contacts without errors",
|
||||
metric="draft_success_rate",
|
||||
target="100%",
|
||||
weight=0.15,
|
||||
),
|
||||
],
|
||||
constraints=[
|
||||
Constraint(
|
||||
id="draft-not-send",
|
||||
description="Agent creates Gmail drafts but NEVER sends emails automatically",
|
||||
constraint_type="hard",
|
||||
category="safety",
|
||||
),
|
||||
Constraint(
|
||||
id="respect-batch-limit",
|
||||
description="Must not process more contacts than the configured max_contacts parameter",
|
||||
constraint_type="hard",
|
||||
category="operational",
|
||||
),
|
||||
Constraint(
|
||||
id="skip-suspicious",
|
||||
description="Contacts with risk_score >= 7 must be excluded from outreach",
|
||||
constraint_type="hard",
|
||||
category="safety",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Node list
|
||||
nodes = [
|
||||
intake_node,
|
||||
score_contacts_node,
|
||||
filter_contacts_node,
|
||||
personalize_node,
|
||||
send_outreach_node,
|
||||
report_node,
|
||||
]
|
||||
|
||||
# Edge definitions
|
||||
edges = [
|
||||
EdgeSpec(
|
||||
id="intake-to-score",
|
||||
source="intake",
|
||||
target="score-contacts",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
priority=1,
|
||||
),
|
||||
EdgeSpec(
|
||||
id="score-to-filter",
|
||||
source="score-contacts",
|
||||
target="filter-contacts",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
priority=1,
|
||||
),
|
||||
EdgeSpec(
|
||||
id="filter-to-personalize",
|
||||
source="filter-contacts",
|
||||
target="personalize",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
priority=1,
|
||||
),
|
||||
EdgeSpec(
|
||||
id="personalize-to-send",
|
||||
source="personalize",
|
||||
target="send-outreach",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
priority=1,
|
||||
),
|
||||
EdgeSpec(
|
||||
id="send-to-report",
|
||||
source="send-outreach",
|
||||
target="report",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
priority=1,
|
||||
),
|
||||
EdgeSpec(
|
||||
id="report-to-intake",
|
||||
source="report",
|
||||
target="intake",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
priority=1,
|
||||
),
|
||||
]
|
||||
|
||||
# Graph configuration
|
||||
entry_node = "intake"
|
||||
entry_points = {"start": "intake"}
|
||||
async_entry_points: list[AsyncEntryPointSpec] = [] # SDR Agent is manually triggered
|
||||
pause_nodes = []
|
||||
terminal_nodes = []
|
||||
loop_config = {
|
||||
"max_iterations": 100,
|
||||
"max_tool_calls_per_turn": 30,
|
||||
"max_tool_result_chars": 8000,
|
||||
"max_history_tokens": 32000,
|
||||
}
|
||||
conversation_mode = "continuous"
|
||||
identity_prompt = (
|
||||
"You are an SDR (Sales Development Representative) assistant. "
|
||||
"You help users automate their outreach by scoring contacts, filtering "
|
||||
"suspicious profiles, generating personalized messages, and creating "
|
||||
"Gmail drafts — all with human review before anything is sent."
|
||||
)
|
||||
|
||||
|
||||
class SDRAgent:
|
||||
"""
|
||||
SDR Agent — 6-node pipeline for automated outreach.
|
||||
|
||||
Flow: intake -> score-contacts -> filter-contacts -> personalize
|
||||
-> send-outreach -> report -> intake (loop)
|
||||
|
||||
Pipeline:
|
||||
1. intake: Receive contact list and outreach goal
|
||||
2. score-contacts: Rank contacts 0-100 by priority factors
|
||||
3. filter-contacts: Remove suspicious profiles (risk >= 7)
|
||||
4. personalize: Generate personalized messages for each contact
|
||||
5. send-outreach: Create Gmail drafts (never sends automatically)
|
||||
6. report: Summarize campaign results and present to user
|
||||
"""
|
||||
|
||||
def __init__(self, config=None):
|
||||
self.config = config or default_config
|
||||
self.goal = goal
|
||||
self.nodes = nodes
|
||||
self.edges = edges
|
||||
self.entry_node = entry_node
|
||||
self.entry_points = entry_points
|
||||
self.pause_nodes = pause_nodes
|
||||
self.terminal_nodes = terminal_nodes
|
||||
self._agent_runtime: AgentRuntime | None = None
|
||||
self._graph: GraphSpec | None = None
|
||||
self._tool_registry: ToolRegistry | None = None
|
||||
|
||||
def _build_graph(self) -> GraphSpec:
|
||||
"""Build the GraphSpec."""
|
||||
return GraphSpec(
|
||||
id="sdr-agent-graph",
|
||||
goal_id=self.goal.id,
|
||||
version="1.0.0",
|
||||
entry_node=self.entry_node,
|
||||
entry_points=self.entry_points,
|
||||
terminal_nodes=self.terminal_nodes,
|
||||
pause_nodes=self.pause_nodes,
|
||||
nodes=self.nodes,
|
||||
edges=self.edges,
|
||||
default_model=self.config.model,
|
||||
max_tokens=self.config.max_tokens,
|
||||
loop_config=loop_config,
|
||||
conversation_mode=conversation_mode,
|
||||
identity_prompt=identity_prompt,
|
||||
)
|
||||
|
||||
def _setup(self, mock_mode=False) -> None:
|
||||
"""Set up the agent runtime with sessions, checkpoints, and logging."""
|
||||
self._storage_path = Path.home() / ".hive" / "agents" / "sdr_agent"
|
||||
self._storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._tool_registry = ToolRegistry()
|
||||
|
||||
mcp_config_path = Path(__file__).parent / "mcp_servers.json"
|
||||
if mcp_config_path.exists():
|
||||
self._tool_registry.load_mcp_config(mcp_config_path)
|
||||
|
||||
tools_path = Path(__file__).parent / "tools.py"
|
||||
if tools_path.exists():
|
||||
self._tool_registry.discover_from_module(tools_path)
|
||||
|
||||
if mock_mode:
|
||||
from framework.llm.mock import MockLLMProvider
|
||||
|
||||
llm = MockLLMProvider()
|
||||
else:
|
||||
llm = LiteLLMProvider(
|
||||
model=self.config.model,
|
||||
api_key=self.config.api_key,
|
||||
api_base=self.config.api_base,
|
||||
)
|
||||
|
||||
tool_executor = self._tool_registry.get_executor()
|
||||
tools = list(self._tool_registry.get_tools().values())
|
||||
|
||||
self._graph = self._build_graph()
|
||||
|
||||
checkpoint_config = CheckpointConfig(
|
||||
enabled=True,
|
||||
checkpoint_on_node_start=False,
|
||||
checkpoint_on_node_complete=True,
|
||||
checkpoint_max_age_days=7,
|
||||
async_checkpoint=True,
|
||||
)
|
||||
|
||||
entry_point_specs = [
|
||||
EntryPointSpec(
|
||||
id="default",
|
||||
name="Default",
|
||||
entry_node=self.entry_node,
|
||||
trigger_type="manual",
|
||||
isolation_level="shared",
|
||||
),
|
||||
]
|
||||
|
||||
self._agent_runtime = create_agent_runtime(
|
||||
graph=self._graph,
|
||||
goal=self.goal,
|
||||
storage_path=self._storage_path,
|
||||
entry_points=entry_point_specs,
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
tool_executor=tool_executor,
|
||||
checkpoint_config=checkpoint_config,
|
||||
)
|
||||
|
||||
async def start(self, mock_mode=False) -> None:
|
||||
"""Set up and start the agent runtime."""
|
||||
if self._agent_runtime is None:
|
||||
self._setup(mock_mode=mock_mode)
|
||||
if not self._agent_runtime.is_running:
|
||||
await self._agent_runtime.start()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the agent runtime and clean up."""
|
||||
if self._agent_runtime and self._agent_runtime.is_running:
|
||||
await self._agent_runtime.stop()
|
||||
self._agent_runtime = None
|
||||
|
||||
async def trigger_and_wait(
|
||||
self,
|
||||
entry_point: str,
|
||||
input_data: dict,
|
||||
timeout: float | None = None,
|
||||
session_state: dict | None = None,
|
||||
) -> ExecutionResult | None:
|
||||
"""Execute the graph and wait for completion."""
|
||||
if self._agent_runtime is None:
|
||||
raise RuntimeError("Agent not started. Call start() first.")
|
||||
|
||||
return await self._agent_runtime.trigger_and_wait(
|
||||
entry_point_id=entry_point,
|
||||
input_data=input_data,
|
||||
timeout=timeout,
|
||||
session_state=session_state,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, context: dict, mock_mode=False, session_state=None
|
||||
) -> ExecutionResult:
|
||||
"""Run the agent (convenience method for single execution)."""
|
||||
await self.start(mock_mode=mock_mode)
|
||||
try:
|
||||
result = await self.trigger_and_wait(
|
||||
"default", context, session_state=session_state
|
||||
)
|
||||
return result or ExecutionResult(success=False, error="Execution timeout")
|
||||
finally:
|
||||
await self.stop()
|
||||
|
||||
def info(self):
|
||||
"""Get agent information."""
|
||||
return {
|
||||
"name": metadata.name,
|
||||
"version": metadata.version,
|
||||
"description": metadata.description,
|
||||
"goal": {
|
||||
"name": self.goal.name,
|
||||
"description": self.goal.description,
|
||||
},
|
||||
"nodes": [n.id for n in self.nodes],
|
||||
"edges": [e.id for e in self.edges],
|
||||
"entry_node": self.entry_node,
|
||||
"entry_points": self.entry_points,
|
||||
"pause_nodes": self.pause_nodes,
|
||||
"terminal_nodes": self.terminal_nodes,
|
||||
"client_facing_nodes": [n.id for n in self.nodes if n.client_facing],
|
||||
}
|
||||
|
||||
def validate(self):
|
||||
"""Validate agent structure."""
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
node_ids = {node.id for node in self.nodes}
|
||||
for edge in self.edges:
|
||||
if edge.source not in node_ids:
|
||||
errors.append(f"Edge {edge.id}: source '{edge.source}' not found")
|
||||
if edge.target not in node_ids:
|
||||
errors.append(f"Edge {edge.id}: target '{edge.target}' not found")
|
||||
|
||||
if self.entry_node not in node_ids:
|
||||
errors.append(f"Entry node '{self.entry_node}' not found")
|
||||
|
||||
for terminal in self.terminal_nodes:
|
||||
if terminal not in node_ids:
|
||||
errors.append(f"Terminal node '{terminal}' not found")
|
||||
|
||||
for ep_id, node_id in self.entry_points.items():
|
||||
if node_id not in node_ids:
|
||||
errors.append(
|
||||
f"Entry point '{ep_id}' references unknown node '{node_id}'"
|
||||
)
|
||||
|
||||
return {
|
||||
"valid": len(errors) == 0,
|
||||
"errors": errors,
|
||||
"warnings": warnings,
|
||||
}
|
||||
|
||||
|
||||
# Create default instance
|
||||
default_agent = SDRAgent()
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Runtime configuration for SDR Agent."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from framework.config import RuntimeConfig
|
||||
|
||||
default_config = RuntimeConfig()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentMetadata:
|
||||
name: str = "SDR Agent"
|
||||
version: str = "1.0.0"
|
||||
description: str = (
|
||||
"Automate sales development outreach using AI-powered contact scoring, "
|
||||
"scam detection, and personalized message generation. "
|
||||
"Score contacts by priority, filter suspicious profiles, generate "
|
||||
"personalized outreach messages, and create Gmail drafts for review."
|
||||
)
|
||||
intro_message: str = (
|
||||
"Hi! I'm your SDR (Sales Development Representative) assistant. "
|
||||
"Provide a list of contacts and your outreach goal, and I'll "
|
||||
"score them by priority, filter out suspicious profiles, generate "
|
||||
"personalized messages for each contact, and create Gmail drafts "
|
||||
"for your review. I never send emails automatically — you stay in control. "
|
||||
"To get started, share your contact list and tell me about your outreach goal!"
|
||||
)
|
||||
|
||||
|
||||
metadata = AgentMetadata()
|
||||
@@ -0,0 +1,97 @@
|
||||
[
|
||||
{
|
||||
"name": "Sarah Chen",
|
||||
"email": "sarah.chen@techcorp.io",
|
||||
"company": "TechCorp",
|
||||
"title": "Learning & Development Manager",
|
||||
"linkedin_url": "https://linkedin.com/in/sarah-chen-ld",
|
||||
"connection_degree": "2nd",
|
||||
"is_alumni": true,
|
||||
"school_name": "University of Western Ontario",
|
||||
"connections_count": 843,
|
||||
"mutual_connections": 7,
|
||||
"has_photo": true,
|
||||
"company_domain_verified": true
|
||||
},
|
||||
{
|
||||
"name": "James Okafor",
|
||||
"email": "james.okafor@edventure.co",
|
||||
"company": "EdVenture",
|
||||
"title": "Instructional Designer",
|
||||
"linkedin_url": "https://linkedin.com/in/james-okafor-id",
|
||||
"connection_degree": "1st",
|
||||
"is_alumni": false,
|
||||
"connections_count": 621,
|
||||
"mutual_connections": 12,
|
||||
"has_photo": true,
|
||||
"company_domain_verified": true
|
||||
},
|
||||
{
|
||||
"name": "Emily Zhao",
|
||||
"email": "emily.zhao@univedu.ca",
|
||||
"company": "UniEdu",
|
||||
"title": "Director of Digital Learning",
|
||||
"linkedin_url": "https://linkedin.com/in/emily-zhao-dl",
|
||||
"connection_degree": "2nd",
|
||||
"is_alumni": true,
|
||||
"school_name": "University of Western Ontario",
|
||||
"connections_count": 1204,
|
||||
"mutual_connections": 3,
|
||||
"has_photo": true,
|
||||
"company_domain_verified": true,
|
||||
"active_job_posting": true
|
||||
},
|
||||
{
|
||||
"name": "Marcus Williams",
|
||||
"email": "marcus@growthsales.io",
|
||||
"company": "GrowthSales",
|
||||
"title": "CEO",
|
||||
"linkedin_url": "https://linkedin.com/in/marcus-williams-ceo",
|
||||
"connection_degree": "3rd",
|
||||
"is_alumni": false,
|
||||
"connections_count": 6300,
|
||||
"mutual_connections": 0,
|
||||
"has_photo": true,
|
||||
"company_domain_verified": false
|
||||
},
|
||||
{
|
||||
"name": "Priya Patel",
|
||||
"email": "",
|
||||
"company": "FutureLearn Inc.",
|
||||
"title": "EdTech Product Manager",
|
||||
"linkedin_url": "https://linkedin.com/in/priya-patel-edtech",
|
||||
"connection_degree": "2nd",
|
||||
"is_alumni": false,
|
||||
"connections_count": 512,
|
||||
"mutual_connections": 5,
|
||||
"has_photo": true,
|
||||
"company_domain_verified": true
|
||||
},
|
||||
{
|
||||
"name": "Alex Johnson",
|
||||
"email": "alex@bizopp.biz",
|
||||
"company": "Biz Opportunity Global",
|
||||
"title": "Entrepreneur",
|
||||
"linkedin_url": "https://linkedin.com/in/alex-johnson-biz",
|
||||
"connection_degree": "3rd",
|
||||
"is_alumni": false,
|
||||
"connections_count": 38,
|
||||
"mutual_connections": 0,
|
||||
"has_photo": false,
|
||||
"company_domain_verified": false
|
||||
},
|
||||
{
|
||||
"name": "Natalie Brown",
|
||||
"email": "natalie.brown@learningpro.com",
|
||||
"company": "LearningPro",
|
||||
"title": "HR Learning Specialist",
|
||||
"linkedin_url": "https://linkedin.com/in/natalie-brown-hr",
|
||||
"connection_degree": "1st",
|
||||
"is_alumni": true,
|
||||
"school_name": "University of Western Ontario",
|
||||
"connections_count": 389,
|
||||
"mutual_connections": 9,
|
||||
"has_photo": true,
|
||||
"company_domain_verified": true
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,270 @@
|
||||
{
|
||||
"original_draft": {
|
||||
"agent_name": "sdr_agent",
|
||||
"goal": "Automate sales development outreach: score contacts by priority, filter suspicious profiles, generate personalized messages, and create Gmail drafts for human review.",
|
||||
"description": "",
|
||||
"success_criteria": [
|
||||
"Contacts are correctly scored and ranked by priority factors (alumni status, connection degree, domain verification)",
|
||||
"Suspicious profiles (risk_score >= 7) are correctly identified and excluded from outreach",
|
||||
"Generated messages reference specific profile details (alumni connection, role, company) and match the outreach goal",
|
||||
"Gmail drafts are created for all safe contacts without errors"
|
||||
],
|
||||
"constraints": [
|
||||
"Agent creates Gmail drafts but NEVER sends emails automatically",
|
||||
"Must not process more contacts than the configured max_contacts parameter",
|
||||
"Contacts with risk_score >= 7 must be excluded from outreach"
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"id": "intake",
|
||||
"name": "Intake",
|
||||
"description": "Receive the contact list and outreach goal from the user. Confirm the strategy and batch size before proceeding.",
|
||||
"node_type": "event_loop",
|
||||
"tools": [
|
||||
"load_contacts_from_file"
|
||||
],
|
||||
"input_keys": [
|
||||
"contacts",
|
||||
"outreach_goal",
|
||||
"max_contacts",
|
||||
"user_background"
|
||||
],
|
||||
"output_keys": [
|
||||
"contacts",
|
||||
"outreach_goal",
|
||||
"max_contacts",
|
||||
"user_background"
|
||||
],
|
||||
"success_criteria": "The user has confirmed the contact list, outreach goal, batch size, and their background. All four keys have been written via set_output.",
|
||||
"sub_agents": [],
|
||||
"flowchart_type": "start",
|
||||
"flowchart_shape": "stadium",
|
||||
"flowchart_color": "#8aad3f"
|
||||
},
|
||||
{
|
||||
"id": "score-contacts",
|
||||
"name": "Score Contacts",
|
||||
"description": "Score and rank each contact from 0 to 100 based on priority factors: alumni status, connection degree, domain verification, mutual connections, and active job postings.",
|
||||
"node_type": "event_loop",
|
||||
"tools": [
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"input_keys": [
|
||||
"contacts",
|
||||
"outreach_goal"
|
||||
],
|
||||
"output_keys": [
|
||||
"scored_contacts"
|
||||
],
|
||||
"success_criteria": "Every contact has a priority_score field (0-100) and scored_contacts.jsonl has been written and referenced via set_output.",
|
||||
"sub_agents": [],
|
||||
"flowchart_type": "database",
|
||||
"flowchart_shape": "cylinder",
|
||||
"flowchart_color": "#508878"
|
||||
},
|
||||
{
|
||||
"id": "filter-contacts",
|
||||
"name": "Filter Contacts",
|
||||
"description": "Analyze each contact for authenticity and filter out suspicious profiles. Any contact with a risk score of 7 or higher is skipped.",
|
||||
"node_type": "event_loop",
|
||||
"tools": [
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"input_keys": [
|
||||
"scored_contacts"
|
||||
],
|
||||
"output_keys": [
|
||||
"safe_contacts",
|
||||
"filtered_count"
|
||||
],
|
||||
"success_criteria": "Each contact has a risk_score and recommendation field. Contacts with risk_score >= 7 are excluded. safe_contacts.jsonl and filtered_count are set via set_output.",
|
||||
"sub_agents": [],
|
||||
"flowchart_type": "database",
|
||||
"flowchart_shape": "cylinder",
|
||||
"flowchart_color": "#508878"
|
||||
},
|
||||
{
|
||||
"id": "personalize",
|
||||
"name": "Personalize",
|
||||
"description": "Generate a personalized outreach message for each contact based on their profile, shared background, and the user's outreach goal.",
|
||||
"node_type": "event_loop",
|
||||
"tools": [
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"input_keys": [
|
||||
"safe_contacts",
|
||||
"outreach_goal",
|
||||
"user_background"
|
||||
],
|
||||
"output_keys": [
|
||||
"personalized_contacts"
|
||||
],
|
||||
"success_criteria": "Every safe contact has an outreach_message field of 80-120 words that references a specific hook from their profile. personalized_contacts.jsonl is set via set_output.",
|
||||
"sub_agents": [],
|
||||
"flowchart_type": "database",
|
||||
"flowchart_shape": "cylinder",
|
||||
"flowchart_color": "#508878"
|
||||
},
|
||||
{
|
||||
"id": "send-outreach",
|
||||
"name": "Send Outreach",
|
||||
"description": "Create Gmail draft emails for each contact using their personalized message. Drafts are created for human review \u2014 emails are never sent automatically.",
|
||||
"node_type": "event_loop",
|
||||
"tools": [
|
||||
"gmail_create_draft",
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"input_keys": [
|
||||
"personalized_contacts",
|
||||
"outreach_goal"
|
||||
],
|
||||
"output_keys": [
|
||||
"drafts_created"
|
||||
],
|
||||
"success_criteria": "A Gmail draft has been created for every safe contact. drafts.jsonl records each draft and drafts_created is set via set_output.",
|
||||
"sub_agents": [],
|
||||
"flowchart_type": "database",
|
||||
"flowchart_shape": "cylinder",
|
||||
"flowchart_color": "#508878"
|
||||
},
|
||||
{
|
||||
"id": "report",
|
||||
"name": "Report",
|
||||
"description": "Generate a summary report of the outreach campaign: contacts scored, filtered, messaged, and drafts created. Present to user for review.",
|
||||
"node_type": "event_loop",
|
||||
"tools": [
|
||||
"load_data"
|
||||
],
|
||||
"input_keys": [
|
||||
"drafts_created",
|
||||
"filtered_count",
|
||||
"outreach_goal"
|
||||
],
|
||||
"output_keys": [
|
||||
"summary_report"
|
||||
],
|
||||
"success_criteria": "A campaign summary has been presented to the user listing totals for contacts scored, filtered, messaged, and drafts created. summary_report is set via set_output.",
|
||||
"sub_agents": [],
|
||||
"flowchart_type": "terminal",
|
||||
"flowchart_shape": "stadium",
|
||||
"flowchart_color": "#b5453a"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "edge-0",
|
||||
"source": "intake",
|
||||
"target": "score-contacts",
|
||||
"condition": "on_success",
|
||||
"description": "",
|
||||
"label": ""
|
||||
},
|
||||
{
|
||||
"id": "edge-1",
|
||||
"source": "score-contacts",
|
||||
"target": "filter-contacts",
|
||||
"condition": "on_success",
|
||||
"description": "",
|
||||
"label": ""
|
||||
},
|
||||
{
|
||||
"id": "edge-2",
|
||||
"source": "filter-contacts",
|
||||
"target": "personalize",
|
||||
"condition": "on_success",
|
||||
"description": "",
|
||||
"label": ""
|
||||
},
|
||||
{
|
||||
"id": "edge-3",
|
||||
"source": "personalize",
|
||||
"target": "send-outreach",
|
||||
"condition": "on_success",
|
||||
"description": "",
|
||||
"label": ""
|
||||
},
|
||||
{
|
||||
"id": "edge-4",
|
||||
"source": "send-outreach",
|
||||
"target": "report",
|
||||
"condition": "on_success",
|
||||
"description": "",
|
||||
"label": ""
|
||||
},
|
||||
{
|
||||
"id": "edge-5",
|
||||
"source": "report",
|
||||
"target": "intake",
|
||||
"condition": "on_success",
|
||||
"description": "",
|
||||
"label": ""
|
||||
}
|
||||
],
|
||||
"entry_node": "intake",
|
||||
"terminal_nodes": [
|
||||
"report"
|
||||
],
|
||||
"flowchart_legend": {
|
||||
"start": {
|
||||
"shape": "stadium",
|
||||
"color": "#8aad3f"
|
||||
},
|
||||
"terminal": {
|
||||
"shape": "stadium",
|
||||
"color": "#b5453a"
|
||||
},
|
||||
"process": {
|
||||
"shape": "rectangle",
|
||||
"color": "#b5a575"
|
||||
},
|
||||
"decision": {
|
||||
"shape": "diamond",
|
||||
"color": "#d89d26"
|
||||
},
|
||||
"io": {
|
||||
"shape": "parallelogram",
|
||||
"color": "#d06818"
|
||||
},
|
||||
"document": {
|
||||
"shape": "document",
|
||||
"color": "#c4b830"
|
||||
},
|
||||
"database": {
|
||||
"shape": "cylinder",
|
||||
"color": "#508878"
|
||||
},
|
||||
"subprocess": {
|
||||
"shape": "subroutine",
|
||||
"color": "#887a48"
|
||||
},
|
||||
"browser": {
|
||||
"shape": "hexagon",
|
||||
"color": "#cc8850"
|
||||
}
|
||||
}
|
||||
},
|
||||
"flowchart_map": {
|
||||
"intake": [
|
||||
"intake"
|
||||
],
|
||||
"score-contacts": [
|
||||
"score-contacts"
|
||||
],
|
||||
"filter-contacts": [
|
||||
"filter-contacts"
|
||||
],
|
||||
"personalize": [
|
||||
"personalize"
|
||||
],
|
||||
"send-outreach": [
|
||||
"send-outreach"
|
||||
],
|
||||
"report": [
|
||||
"report"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"hive-tools": {
|
||||
"transport": "stdio",
|
||||
"command": "uv",
|
||||
"args": [
|
||||
"run",
|
||||
"python",
|
||||
"mcp_server.py",
|
||||
"--stdio"
|
||||
],
|
||||
"cwd": "../../../tools",
|
||||
"description": "Hive tools MCP server"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,339 @@
|
||||
"""Node definitions for SDR Agent."""
|
||||
|
||||
from framework.graph import NodeSpec
|
||||
|
||||
# Node 1: Intake (client-facing)
|
||||
# Receives contact list and outreach goal, confirms with user before proceeding.
|
||||
intake_node = NodeSpec(
|
||||
id="intake",
|
||||
name="Intake",
|
||||
description=(
|
||||
"Receive the contact list and outreach goal from the user. "
|
||||
"Confirm the strategy and batch size before proceeding."
|
||||
),
|
||||
node_type="event_loop",
|
||||
client_facing=True,
|
||||
max_node_visits=0,
|
||||
input_keys=["contacts", "outreach_goal", "max_contacts", "user_background"],
|
||||
output_keys=["contacts", "outreach_goal", "max_contacts", "user_background"],
|
||||
success_criteria=(
|
||||
"The user has confirmed the contact list, outreach goal, batch size, and "
|
||||
"their background. All four keys have been written via set_output."
|
||||
),
|
||||
system_prompt="""\
|
||||
You are an SDR (Sales Development Representative) assistant helping automate outreach.
|
||||
|
||||
**STEP 1 — Understand the input (text only, NO tool calls):**
|
||||
|
||||
Read the user's input from context. Determine what they provided:
|
||||
- If "contacts" is a **file path** (ends in .json or .jsonl), note that you'll load it in step 2.
|
||||
- If "contacts" is a **JSON string**, you'll use it directly.
|
||||
- Identify the outreach goal, background, and batch size (default 20).
|
||||
|
||||
**STEP 2 — Load contacts if needed:**
|
||||
If the user provided a file path for contacts, call:
|
||||
- load_contacts_from_file(file_path=<the path>)
|
||||
This writes the contacts to contacts.jsonl in the session directory.
|
||||
|
||||
**STEP 3 — Confirm with the user (text only, NO tool calls):**
|
||||
|
||||
Present a summary like:
|
||||
"Here's what I'll do:
|
||||
1. Score and rank your contacts by priority (alumni status, connection degree, etc.)
|
||||
2. Filter out suspicious or low-quality profiles (risk score ≥ 7)
|
||||
3. Generate a personalized outreach message for each contact
|
||||
4. Create Gmail draft emails for your review — I never send automatically
|
||||
|
||||
Ready to proceed with [N] contacts for [goal]?"
|
||||
|
||||
**STEP 4 — After the user confirms, call set_output:**
|
||||
|
||||
- set_output("contacts", <the contact list as a JSON string, or "contacts.jsonl" if loaded from file>)
|
||||
- set_output("outreach_goal", <the confirmed goal, e.g. "coffee chat">)
|
||||
- set_output("max_contacts", <the confirmed batch size as a string, e.g. "20">)
|
||||
- set_output("user_background", <user's background/role, e.g. "Learning Technologist at UWO">)
|
||||
""",
|
||||
tools=["load_contacts_from_file"],
|
||||
)
|
||||
|
||||
# Node 2: Score Contacts
|
||||
# Ranks contacts 0-100 based on alumni status, connection degree, domain, etc.
|
||||
score_contacts_node = NodeSpec(
|
||||
id="score-contacts",
|
||||
name="Score Contacts",
|
||||
description=(
|
||||
"Score and rank each contact from 0 to 100 based on priority factors: "
|
||||
"alumni status, connection degree, domain verification, mutual connections, "
|
||||
"and active job postings."
|
||||
),
|
||||
node_type="event_loop",
|
||||
client_facing=False,
|
||||
max_node_visits=0,
|
||||
input_keys=["contacts", "outreach_goal"],
|
||||
output_keys=["scored_contacts"],
|
||||
success_criteria=(
|
||||
"Every contact has a priority_score field (0-100) and scored_contacts.jsonl "
|
||||
"has been written and referenced via set_output."
|
||||
),
|
||||
system_prompt="""\
|
||||
You are a contact prioritization engine. Score each contact from 0 to 100.
|
||||
|
||||
**SCORING RULES (additive):**
|
||||
- Alumni of the user's school: +30 points
|
||||
- 1st degree connection: +25 points
|
||||
- 2nd degree connection: +20 points
|
||||
- 3rd degree connection: +10 points
|
||||
- Domain verified (company email matches LinkedIn company): +10 points
|
||||
- Has mutual connections (1 point each, max 10): up to +10 points
|
||||
- Active job posting at their company: +10 points
|
||||
- Has a profile photo: +5 points
|
||||
- Over 500 connections: +5 points
|
||||
|
||||
Cap the final score at 100.
|
||||
|
||||
**STEP 1 — Load the contacts:**
|
||||
Call load_data(filename="contacts.jsonl") to read the contact list.
|
||||
If "contacts" in context is a JSON string (not a filename), write it first:
|
||||
- For each contact in the list, call append_data(filename="contacts.jsonl", data=<JSON contact object>)
|
||||
Then read it back.
|
||||
|
||||
**STEP 2 — Score each contact:**
|
||||
For each contact, calculate the priority score using the rules above.
|
||||
Add a "priority_score" field to each contact object.
|
||||
|
||||
**STEP 3 — Write scored contacts and set output:**
|
||||
- Call append_data(filename="scored_contacts.jsonl", data=<JSON contact with priority_score>) for each contact.
|
||||
- Sort contacts by priority_score (highest first) in your final output.
|
||||
- Call set_output("scored_contacts", "scored_contacts.jsonl")
|
||||
""",
|
||||
tools=["load_data", "append_data"],
|
||||
)
|
||||
|
||||
# Node 3: Filter Contacts (Scam Detection)
|
||||
# Filters out suspicious or fake profiles using a risk scoring system.
|
||||
filter_contacts_node = NodeSpec(
|
||||
id="filter-contacts",
|
||||
name="Filter Contacts",
|
||||
description=(
|
||||
"Analyze each contact for authenticity and filter out suspicious profiles. "
|
||||
"Any contact with a risk score of 7 or higher is skipped."
|
||||
),
|
||||
node_type="event_loop",
|
||||
client_facing=False,
|
||||
max_node_visits=0,
|
||||
input_keys=["scored_contacts"],
|
||||
output_keys=["safe_contacts", "filtered_count"],
|
||||
success_criteria=(
|
||||
"Each contact has a risk_score and recommendation field. Contacts with "
|
||||
"risk_score >= 7 are excluded. safe_contacts.jsonl and filtered_count are "
|
||||
"set via set_output."
|
||||
),
|
||||
system_prompt="""\
|
||||
You are a profile authenticity analyzer. Your job is to detect suspicious or fake LinkedIn profiles.
|
||||
|
||||
**RISK SCORING RULES (additive):**
|
||||
- Fewer than 50 connections: +3 points
|
||||
- No profile photo: +2 points
|
||||
- Fewer than 2 positions in work history: +2 points
|
||||
- Generic title (e.g. "entrepreneur", "CEO", "consultant") AND fewer than 100 connections: +2 points
|
||||
- Company name appears generic or unverifiable: +2 points
|
||||
- Profile text seems auto-generated or overly promotional: +2 points
|
||||
- Connection count over 5000 with no mutual connections: +1 point
|
||||
|
||||
**DECISION RULE:**
|
||||
- risk_score < 4: SAFE — include in outreach
|
||||
- risk_score 4–6: CAUTION — include but flag
|
||||
- risk_score ≥ 7: SKIP — exclude from outreach
|
||||
|
||||
**STEP 1 — Load scored contacts:**
|
||||
Call load_data(filename=<the "scored_contacts" value from context>).
|
||||
Process contacts chunk by chunk if has_more=true.
|
||||
|
||||
**STEP 2 — Analyze each contact:**
|
||||
For each contact, calculate a risk_score using the rules above.
|
||||
Determine: is_safe (risk_score < 7), recommendation (safe/caution/skip), flags (list of triggered rules).
|
||||
|
||||
**STEP 3 — Write safe contacts and set output:**
|
||||
- For each contact where risk_score < 7: call append_data(filename="safe_contacts.jsonl", data=<contact JSON with risk_score and flags added>)
|
||||
- Track how many contacts were filtered (risk_score ≥ 7)
|
||||
- Call set_output("safe_contacts", "safe_contacts.jsonl")
|
||||
- Call set_output("filtered_count", <number of skipped contacts as string>)
|
||||
""",
|
||||
tools=["load_data", "append_data"],
|
||||
)
|
||||
|
||||
# Node 4: Personalize Messages
|
||||
# Generates personalized outreach messages for each safe contact.
|
||||
personalize_node = NodeSpec(
|
||||
id="personalize",
|
||||
name="Personalize",
|
||||
description=(
|
||||
"Generate a personalized outreach message for each contact based on "
|
||||
"their profile, shared background, and the user's outreach goal."
|
||||
),
|
||||
node_type="event_loop",
|
||||
client_facing=False,
|
||||
max_node_visits=0,
|
||||
input_keys=["safe_contacts", "outreach_goal", "user_background"],
|
||||
output_keys=["personalized_contacts"],
|
||||
success_criteria=(
|
||||
"Every safe contact has an outreach_message field of 80-120 words that "
|
||||
"references a specific hook from their profile. personalized_contacts.jsonl "
|
||||
"is set via set_output."
|
||||
),
|
||||
system_prompt="""\
|
||||
You are a professional outreach message writer. Generate personalized messages for each contact.
|
||||
|
||||
**TWO-STEP PERSONALIZATION:**
|
||||
|
||||
For each contact, follow this two-step approach:
|
||||
|
||||
STEP A — Extract hooks (analyze the profile):
|
||||
Look for 2-3 specific talking points from the contact's profile:
|
||||
- Shared alumni connection
|
||||
- Specific role, company, or career transition worth mentioning
|
||||
- Any mutual interests aligned with the user's background
|
||||
|
||||
STEP B — Generate the message:
|
||||
Write a warm, professional outreach message using the hooks.
|
||||
|
||||
**MESSAGE REQUIREMENTS:**
|
||||
- 80-120 words (LinkedIn message length)
|
||||
- Start with a specific observation ("I noticed you..." or "Fellow [school] alum here...")
|
||||
- Mention the shared connection or interest naturally
|
||||
- State the outreach goal clearly but softly (e.g. "Open to a brief 15-min chat?")
|
||||
- Professional but warm tone — NOT templated or AI-sounding
|
||||
- Do NOT mention job postings directly unless the goal is job-related
|
||||
- Do NOT use generic openers like "I hope this finds you well"
|
||||
- End with a low-pressure ask
|
||||
|
||||
**STEP 1 — Load safe contacts:**
|
||||
Call load_data(filename=<the "safe_contacts" value from context>).
|
||||
|
||||
**STEP 2 — Generate message for each contact:**
|
||||
For each contact: generate the personalized message using the two-step approach above.
|
||||
Add "outreach_message" field to each contact object.
|
||||
|
||||
**STEP 3 — Write output and set:**
|
||||
- Call append_data(filename="personalized_contacts.jsonl", data=<contact JSON with outreach_message>) for each.
|
||||
- Call set_output("personalized_contacts", "personalized_contacts.jsonl")
|
||||
""",
|
||||
tools=["load_data", "append_data"],
|
||||
)
|
||||
|
||||
# Node 5: Send Outreach (Create Gmail Drafts)
|
||||
# Creates Gmail draft emails for each personalized contact. Never sends automatically.
|
||||
send_outreach_node = NodeSpec(
|
||||
id="send-outreach",
|
||||
name="Send Outreach",
|
||||
description=(
|
||||
"Create Gmail draft emails for each contact using their personalized message. "
|
||||
"Drafts are created for human review — emails are never sent automatically."
|
||||
),
|
||||
node_type="event_loop",
|
||||
client_facing=False,
|
||||
max_node_visits=0,
|
||||
input_keys=["personalized_contacts", "outreach_goal"],
|
||||
output_keys=["drafts_created"],
|
||||
success_criteria=(
|
||||
"A Gmail draft has been created for every safe contact. "
|
||||
"drafts.jsonl records each draft and drafts_created is set via set_output."
|
||||
),
|
||||
system_prompt="""\
|
||||
You are an outreach execution assistant. Create Gmail draft emails for each contact.
|
||||
|
||||
**CRITICAL RULE: NEVER send emails automatically. Only create drafts.**
|
||||
|
||||
**STEP 1 — Load personalized contacts:**
|
||||
Call load_data(filename=<the "personalized_contacts" value from context>).
|
||||
Process chunk by chunk if has_more=true.
|
||||
|
||||
**STEP 2 — Create Gmail draft for each contact:**
|
||||
For each contact with an "outreach_message":
|
||||
- subject: "Coffee Chat Request" (or appropriate subject based on outreach_goal)
|
||||
- to: contact's email address (use LinkedIn profile URL if email not available — note this in body)
|
||||
- body: the "outreach_message" from the contact object
|
||||
|
||||
Call gmail_create_draft(
|
||||
to=<contact email or linkedin_url as placeholder>,
|
||||
subject=<appropriate subject line>,
|
||||
body=<outreach_message>
|
||||
)
|
||||
|
||||
Record each draft: call append_data(
|
||||
filename="drafts.jsonl",
|
||||
data=<JSON: {contact_name, contact_email, subject, status: "draft_created"}>
|
||||
)
|
||||
|
||||
**STEP 3 — Set output:**
|
||||
- Call set_output("drafts_created", "drafts.jsonl")
|
||||
|
||||
**IMPORTANT:** If a contact has no email address, create the draft with their LinkedIn URL as a placeholder
|
||||
and add a note in the body: "Note: Please find the recipient's email before sending."
|
||||
""",
|
||||
tools=["gmail_create_draft", "load_data", "append_data"],
|
||||
)
|
||||
|
||||
# Node 6: Report (client-facing)
|
||||
# Summarizes results and presents to user for review.
|
||||
report_node = NodeSpec(
|
||||
id="report",
|
||||
name="Report",
|
||||
description=(
|
||||
"Generate a summary report of the outreach campaign: contacts scored, "
|
||||
"filtered, messaged, and drafts created. Present to user for review."
|
||||
),
|
||||
node_type="event_loop",
|
||||
client_facing=True,
|
||||
max_node_visits=0,
|
||||
input_keys=["drafts_created", "filtered_count", "outreach_goal"],
|
||||
output_keys=["summary_report"],
|
||||
success_criteria=(
|
||||
"A campaign summary has been presented to the user listing totals for "
|
||||
"contacts scored, filtered, messaged, and drafts created. "
|
||||
"summary_report is set via set_output."
|
||||
),
|
||||
system_prompt="""\
|
||||
You are an SDR assistant. Generate a clear campaign summary report and present it to the user.
|
||||
|
||||
**STEP 1 — Load draft records:**
|
||||
Call load_data(filename=<the "drafts_created" value from context>) to read the draft records.
|
||||
If has_more=true, load additional chunks until all records are loaded.
|
||||
|
||||
**STEP 2 — Present the report (text only, NO tool calls):**
|
||||
|
||||
Present a clean summary:
|
||||
|
||||
📊 **SDR Campaign Summary — [outreach_goal]**
|
||||
|
||||
**Overview:**
|
||||
- Total contacts processed: [N]
|
||||
- Contacts filtered (suspicious profiles): [filtered_count]
|
||||
- Safe contacts messaged: [N - filtered_count]
|
||||
- Gmail drafts created: [N]
|
||||
|
||||
**Drafts Created:**
|
||||
List each draft: Contact Name | Company | Subject
|
||||
|
||||
**Next Steps:**
|
||||
"Your Gmail drafts are ready for review. Please:
|
||||
1. Open Gmail and review each draft
|
||||
2. Personalize further if needed
|
||||
3. Send when ready
|
||||
|
||||
Would you like to run another outreach batch or adjust the strategy?"
|
||||
|
||||
**STEP 3 — After the user responds, call set_output:**
|
||||
- set_output("summary_report", <the formatted report text>)
|
||||
""",
|
||||
tools=["load_data"],
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"intake_node",
|
||||
"score_contacts_node",
|
||||
"filter_contacts_node",
|
||||
"personalize_node",
|
||||
"send_outreach_node",
|
||||
"report_node",
|
||||
]
|
||||
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
Custom tool functions for SDR Agent.
|
||||
|
||||
Follows the ToolRegistry.discover_from_module() contract:
|
||||
- TOOLS: dict[str, Tool] — tool definitions
|
||||
- tool_executor(tool_use) — unified dispatcher
|
||||
|
||||
These tools provide SDR-specific utilities for loading contact data
|
||||
from a JSON file and writing it to the session's data directory for
|
||||
downstream nodes to process via the standard load_data/append_data tools.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from framework.llm.provider import Tool, ToolResult, ToolUse
|
||||
from framework.runner.tool_registry import _execution_context
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool definitions (auto-discovered by ToolRegistry.discover_from_module)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TOOLS = {
|
||||
"load_contacts_from_file": Tool(
|
||||
name="load_contacts_from_file",
|
||||
description=(
|
||||
"Load a contacts JSON file from an absolute or relative path "
|
||||
"and write its contents to contacts.jsonl in the session data directory. "
|
||||
"Returns the number of contacts loaded and the output filename."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Absolute or relative path to a JSON file containing "
|
||||
"a list of contact objects."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["file_path"],
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_data_dir() -> str:
|
||||
"""Get the session-scoped data_dir from ToolRegistry execution context."""
|
||||
ctx = _execution_context.get()
|
||||
if not ctx or "data_dir" not in ctx:
|
||||
raise RuntimeError(
|
||||
"data_dir not set in execution context. "
|
||||
"Is the tool running inside a GraphExecutor?"
|
||||
)
|
||||
return ctx["data_dir"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core implementation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _load_contacts_from_file(file_path: str) -> dict:
|
||||
"""Read a contacts JSON file and write it as contacts.jsonl to data_dir.
|
||||
|
||||
Args:
|
||||
file_path: Path to the contacts JSON file.
|
||||
|
||||
Returns:
|
||||
dict with ``filename`` (always ``"contacts.jsonl"``) and ``count``.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
data_dir = _get_data_dir()
|
||||
Path(data_dir).mkdir(parents=True, exist_ok=True)
|
||||
output_path = Path(data_dir) / "contacts.jsonl"
|
||||
|
||||
try:
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
contacts = json.load(f)
|
||||
except FileNotFoundError:
|
||||
return {"error": f"File not found: {file_path}"}
|
||||
except json.JSONDecodeError as e:
|
||||
return {"error": f"Invalid JSON: {e}"}
|
||||
|
||||
if not isinstance(contacts, list):
|
||||
contacts = [contacts]
|
||||
|
||||
count = 0
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
for contact in contacts:
|
||||
f.write(json.dumps(contact, ensure_ascii=False) + "\n")
|
||||
count += 1
|
||||
|
||||
return {"filename": "contacts.jsonl", "count": count}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unified tool executor (auto-discovered by ToolRegistry.discover_from_module)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def tool_executor(tool_use: ToolUse) -> ToolResult:
|
||||
"""Dispatch tool calls to their implementations."""
|
||||
if tool_use.name == "load_contacts_from_file":
|
||||
try:
|
||||
file_path = tool_use.input.get("file_path", "")
|
||||
result = _load_contacts_from_file(file_path=file_path)
|
||||
return ToolResult(
|
||||
tool_use_id=tool_use.id,
|
||||
content=json.dumps(result),
|
||||
is_error="error" in result,
|
||||
)
|
||||
except Exception as e:
|
||||
return ToolResult(
|
||||
tool_use_id=tool_use.id,
|
||||
content=json.dumps({"error": str(e)}),
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
return ToolResult(
|
||||
tool_use_id=tool_use.id,
|
||||
content=json.dumps({"error": f"Unknown tool: {tool_use.name}"}),
|
||||
is_error=True,
|
||||
)
|
||||
+253
-28
@@ -21,6 +21,9 @@ $ErrorActionPreference = "Continue"
|
||||
$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Definition
|
||||
$UvHelperPath = Join-Path $ScriptDir "scripts\uv-discovery.ps1"
|
||||
|
||||
# Hive LLM router endpoint
|
||||
$HiveLlmEndpoint = "https://api.adenhq.com"
|
||||
|
||||
. $UvHelperPath
|
||||
|
||||
# ============================================================
|
||||
@@ -775,6 +778,7 @@ $ProviderMap = [ordered]@{
|
||||
GOOGLE_API_KEY = @{ Name = "Google AI"; Id = "google" }
|
||||
GROQ_API_KEY = @{ Name = "Groq"; Id = "groq" }
|
||||
CEREBRAS_API_KEY = @{ Name = "Cerebras"; Id = "cerebras" }
|
||||
OPENROUTER_API_KEY = @{ Name = "OpenRouter"; Id = "openrouter" }
|
||||
MISTRAL_API_KEY = @{ Name = "Mistral"; Id = "mistral" }
|
||||
TOGETHER_API_KEY = @{ Name = "Together AI"; Id = "together" }
|
||||
DEEPSEEK_API_KEY = @{ Name = "DeepSeek"; Id = "deepseek" }
|
||||
@@ -817,9 +821,81 @@ $ModelChoices = @{
|
||||
)
|
||||
}
|
||||
|
||||
function Normalize-OpenRouterModelId {
|
||||
param([string]$ModelId)
|
||||
$normalized = if ($ModelId) { $ModelId.Trim() } else { "" }
|
||||
if ($normalized -match '(?i)^openrouter/(.+)$') {
|
||||
$normalized = $matches[1]
|
||||
}
|
||||
return $normalized
|
||||
}
|
||||
|
||||
function Get-ModelSelection {
|
||||
param([string]$ProviderId)
|
||||
|
||||
if ($ProviderId -eq "openrouter") {
|
||||
$defaultModel = ""
|
||||
if ($PrevModel -and $PrevProvider -eq $ProviderId) {
|
||||
$defaultModel = Normalize-OpenRouterModelId $PrevModel
|
||||
}
|
||||
Write-Host ""
|
||||
Write-Color -Text "Enter your OpenRouter model id:" -Color White
|
||||
Write-Color -Text " Paste from openrouter.ai (example: x-ai/grok-4.20-beta)" -Color DarkGray
|
||||
Write-Color -Text " If calls fail with guardrail/privacy errors: openrouter.ai/settings/privacy" -Color DarkGray
|
||||
Write-Host ""
|
||||
while ($true) {
|
||||
if ($defaultModel) {
|
||||
$rawModel = Read-Host "Model id [$defaultModel]"
|
||||
if ([string]::IsNullOrWhiteSpace($rawModel)) { $rawModel = $defaultModel }
|
||||
} else {
|
||||
$rawModel = Read-Host "Model id"
|
||||
}
|
||||
$normalizedModel = Normalize-OpenRouterModelId $rawModel
|
||||
if (-not [string]::IsNullOrWhiteSpace($normalizedModel)) {
|
||||
$openrouterKey = $null
|
||||
if ($SelectedEnvVar) {
|
||||
$openrouterKey = [System.Environment]::GetEnvironmentVariable($SelectedEnvVar, "Process")
|
||||
if (-not $openrouterKey) {
|
||||
$openrouterKey = [System.Environment]::GetEnvironmentVariable($SelectedEnvVar, "User")
|
||||
}
|
||||
}
|
||||
|
||||
if ($openrouterKey) {
|
||||
Write-Host " Verifying model id... " -NoNewline
|
||||
try {
|
||||
$modelApiBase = if ($SelectedApiBase) { $SelectedApiBase } else { "https://openrouter.ai/api/v1" }
|
||||
$hcResult = & uv run python (Join-Path $ScriptDir "scripts/check_llm_key.py") "openrouter" $openrouterKey $modelApiBase $normalizedModel 2>$null
|
||||
$hcJson = $hcResult | ConvertFrom-Json
|
||||
if ($hcJson.valid -eq $true) {
|
||||
if ($hcJson.model) {
|
||||
$normalizedModel = [string]$hcJson.model
|
||||
}
|
||||
Write-Color -Text "ok" -Color Green
|
||||
} elseif ($hcJson.valid -eq $false) {
|
||||
Write-Color -Text "failed" -Color Red
|
||||
Write-Warn $hcJson.message
|
||||
Write-Host ""
|
||||
continue
|
||||
} else {
|
||||
Write-Color -Text "--" -Color Yellow
|
||||
Write-Color -Text " Could not verify model id (network issue). Continuing with your selection." -Color DarkGray
|
||||
}
|
||||
} catch {
|
||||
Write-Color -Text "--" -Color Yellow
|
||||
Write-Color -Text " Could not verify model id (network issue). Continuing with your selection." -Color DarkGray
|
||||
}
|
||||
} else {
|
||||
Write-Color -Text " Skipping model verification (OpenRouter key not available in current shell)." -Color DarkGray
|
||||
}
|
||||
|
||||
Write-Host ""
|
||||
Write-Ok "Model: $normalizedModel"
|
||||
return @{ Model = $normalizedModel; MaxTokens = 8192; MaxContextTokens = 120000 }
|
||||
}
|
||||
Write-Color -Text "Model id cannot be empty." -Color Red
|
||||
}
|
||||
}
|
||||
|
||||
$choices = $ModelChoices[$ProviderId]
|
||||
if (-not $choices -or $choices.Count -eq 0) {
|
||||
return @{ Model = $DefaultModels[$ProviderId]; MaxTokens = 8192; MaxContextTokens = 120000 }
|
||||
@@ -880,6 +956,7 @@ $SelectedEnvVar = ""
|
||||
$SelectedModel = ""
|
||||
$SelectedMaxTokens = 8192
|
||||
$SelectedMaxContextTokens = 120000
|
||||
$SelectedApiBase = ""
|
||||
$SubscriptionMode = ""
|
||||
|
||||
# ── Credential detection (silent — just set flags) ───────────
|
||||
@@ -903,16 +980,22 @@ $kimiKey = [System.Environment]::GetEnvironmentVariable("KIMI_API_KEY", "User")
|
||||
if (-not $kimiKey) { $kimiKey = $env:KIMI_API_KEY }
|
||||
if ($kimiKey) { $KimiCredDetected = $true }
|
||||
|
||||
$HiveCredDetected = $false
|
||||
$hiveKey = [System.Environment]::GetEnvironmentVariable("HIVE_API_KEY", "User")
|
||||
if (-not $hiveKey) { $hiveKey = $env:HIVE_API_KEY }
|
||||
if ($hiveKey) { $HiveCredDetected = $true }
|
||||
|
||||
# Detect API key providers
|
||||
$ProviderMenuEnvVars = @("ANTHROPIC_API_KEY", "OPENAI_API_KEY", "GEMINI_API_KEY", "GROQ_API_KEY", "CEREBRAS_API_KEY")
|
||||
$ProviderMenuNames = @("Anthropic (Claude) - Recommended", "OpenAI (GPT)", "Google Gemini - Free tier available", "Groq - Fast, free tier", "Cerebras - Fast, free tier")
|
||||
$ProviderMenuIds = @("anthropic", "openai", "gemini", "groq", "cerebras")
|
||||
$ProviderMenuEnvVars = @("ANTHROPIC_API_KEY", "OPENAI_API_KEY", "GEMINI_API_KEY", "GROQ_API_KEY", "CEREBRAS_API_KEY", "OPENROUTER_API_KEY")
|
||||
$ProviderMenuNames = @("Anthropic (Claude) - Recommended", "OpenAI (GPT)", "Google Gemini - Free tier available", "Groq - Fast, free tier", "Cerebras - Fast, free tier", "OpenRouter - Bring any OpenRouter model")
|
||||
$ProviderMenuIds = @("anthropic", "openai", "gemini", "groq", "cerebras", "openrouter")
|
||||
$ProviderMenuUrls = @(
|
||||
"https://console.anthropic.com/settings/keys",
|
||||
"https://platform.openai.com/api-keys",
|
||||
"https://aistudio.google.com/apikey",
|
||||
"https://console.groq.com/keys",
|
||||
"https://cloud.cerebras.ai/"
|
||||
"https://cloud.cerebras.ai/",
|
||||
"https://openrouter.ai/keys"
|
||||
)
|
||||
|
||||
# ── Read previous configuration (if any) ──────────────────────
|
||||
@@ -933,6 +1016,7 @@ if (Test-Path $HiveConfigFile) {
|
||||
elseif ($prevLlm.use_kimi_code_subscription) { $PrevSubMode = "kimi_code" }
|
||||
elseif ($prevLlm.api_base -and $prevLlm.api_base -like "*api.z.ai*") { $PrevSubMode = "zai_code" }
|
||||
elseif ($prevLlm.api_base -and $prevLlm.api_base -like "*api.kimi.com*") { $PrevSubMode = "kimi_code" }
|
||||
elseif ($prevLlm.provider -eq "hive" -or ($prevLlm.api_base -and $prevLlm.api_base -like "*adenhq.com*")) { $PrevSubMode = "hive_llm" }
|
||||
}
|
||||
} catch { }
|
||||
}
|
||||
@@ -946,6 +1030,7 @@ if ($PrevSubMode -or $PrevProvider) {
|
||||
"zai_code" { if ($ZaiCredDetected) { $prevCredValid = $true } }
|
||||
"codex" { if ($CodexCredDetected) { $prevCredValid = $true } }
|
||||
"kimi_code" { if ($KimiCredDetected) { $prevCredValid = $true } }
|
||||
"hive_llm" { if ($HiveCredDetected) { $prevCredValid = $true } }
|
||||
default {
|
||||
if ($PrevEnvVar) {
|
||||
$envVal = [System.Environment]::GetEnvironmentVariable($PrevEnvVar, "Process")
|
||||
@@ -960,14 +1045,16 @@ if ($PrevSubMode -or $PrevProvider) {
|
||||
"zai_code" { $DefaultChoice = "2" }
|
||||
"codex" { $DefaultChoice = "3" }
|
||||
"kimi_code" { $DefaultChoice = "4" }
|
||||
"hive_llm" { $DefaultChoice = "5" }
|
||||
}
|
||||
if (-not $DefaultChoice) {
|
||||
switch ($PrevProvider) {
|
||||
"anthropic" { $DefaultChoice = "5" }
|
||||
"openai" { $DefaultChoice = "6" }
|
||||
"gemini" { $DefaultChoice = "7" }
|
||||
"groq" { $DefaultChoice = "8" }
|
||||
"cerebras" { $DefaultChoice = "9" }
|
||||
"anthropic" { $DefaultChoice = "6" }
|
||||
"openai" { $DefaultChoice = "7" }
|
||||
"gemini" { $DefaultChoice = "8" }
|
||||
"groq" { $DefaultChoice = "9" }
|
||||
"cerebras" { $DefaultChoice = "10" }
|
||||
"openrouter" { $DefaultChoice = "11" }
|
||||
"kimi" { $DefaultChoice = "4" }
|
||||
}
|
||||
}
|
||||
@@ -1007,12 +1094,19 @@ Write-Host ") Kimi Code Subscription " -NoNewline
|
||||
Write-Color -Text "(use your Kimi Code plan)" -Color DarkGray -NoNewline
|
||||
if ($KimiCredDetected) { Write-Color -Text " (credential detected)" -Color Green } else { Write-Host "" }
|
||||
|
||||
# 5) Hive LLM
|
||||
Write-Host " " -NoNewline
|
||||
Write-Color -Text "5" -Color Cyan -NoNewline
|
||||
Write-Host ") Hive LLM " -NoNewline
|
||||
Write-Color -Text "(use your Hive API key)" -Color DarkGray -NoNewline
|
||||
if ($HiveCredDetected) { Write-Color -Text " (credential detected)" -Color Green } else { Write-Host "" }
|
||||
|
||||
Write-Host ""
|
||||
Write-Color -Text " API key providers:" -Color Cyan
|
||||
|
||||
# 5-9) API key providers
|
||||
# 6-11) API key providers
|
||||
for ($idx = 0; $idx -lt $ProviderMenuEnvVars.Count; $idx++) {
|
||||
$num = $idx + 5
|
||||
$num = $idx + 6
|
||||
$envVal = [System.Environment]::GetEnvironmentVariable($ProviderMenuEnvVars[$idx], "Process")
|
||||
if (-not $envVal) { $envVal = [System.Environment]::GetEnvironmentVariable($ProviderMenuEnvVars[$idx], "User") }
|
||||
Write-Host " " -NoNewline
|
||||
@@ -1021,8 +1115,9 @@ for ($idx = 0; $idx -lt $ProviderMenuEnvVars.Count; $idx++) {
|
||||
if ($envVal) { Write-Color -Text " (credential detected)" -Color Green } else { Write-Host "" }
|
||||
}
|
||||
|
||||
$SkipChoice = 6 + $ProviderMenuEnvVars.Count
|
||||
Write-Host " " -NoNewline
|
||||
Write-Color -Text "10" -Color Cyan -NoNewline
|
||||
Write-Color -Text "$SkipChoice" -Color Cyan -NoNewline
|
||||
Write-Host ") Skip for now"
|
||||
Write-Host ""
|
||||
|
||||
@@ -1033,16 +1128,16 @@ if ($DefaultChoice) {
|
||||
|
||||
while ($true) {
|
||||
if ($DefaultChoice) {
|
||||
$raw = Read-Host "Enter choice (1-10) [$DefaultChoice]"
|
||||
$raw = Read-Host "Enter choice (1-$SkipChoice) [$DefaultChoice]"
|
||||
if ([string]::IsNullOrWhiteSpace($raw)) { $raw = $DefaultChoice }
|
||||
} else {
|
||||
$raw = Read-Host "Enter choice (1-10)"
|
||||
$raw = Read-Host "Enter choice (1-$SkipChoice)"
|
||||
}
|
||||
if ($raw -match '^\d+$') {
|
||||
$num = [int]$raw
|
||||
if ($num -ge 1 -and $num -le 10) { break }
|
||||
if ($num -ge 1 -and $num -le $SkipChoice) { break }
|
||||
}
|
||||
Write-Color -Text "Invalid choice. Please enter 1-10" -Color Red
|
||||
Write-Color -Text "Invalid choice. Please enter 1-$SkipChoice" -Color Red
|
||||
}
|
||||
|
||||
switch ($num) {
|
||||
@@ -1121,13 +1216,42 @@ switch ($num) {
|
||||
Write-Ok "Using Kimi Code subscription"
|
||||
Write-Color -Text " Model: kimi-k2.5 | API: api.kimi.com/coding" -Color DarkGray
|
||||
}
|
||||
{ $_ -ge 5 -and $_ -le 9 } {
|
||||
5 {
|
||||
# Hive LLM
|
||||
$SubscriptionMode = "hive_llm"
|
||||
$SelectedProviderId = "hive"
|
||||
$SelectedEnvVar = "HIVE_API_KEY"
|
||||
$SelectedMaxTokens = 32768
|
||||
$SelectedMaxContextTokens = 120000
|
||||
Write-Host ""
|
||||
Write-Ok "Using Hive LLM"
|
||||
Write-Host ""
|
||||
Write-Host " Select a model:"
|
||||
Write-Host " " -NoNewline; Write-Color -Text "1)" -Color Cyan -NoNewline; Write-Host " queen " -NoNewline; Write-Color -Text "(default - Hive flagship)" -Color DarkGray
|
||||
Write-Host " " -NoNewline; Write-Color -Text "2)" -Color Cyan -NoNewline; Write-Host " kimi-2.5"
|
||||
Write-Host " " -NoNewline; Write-Color -Text "3)" -Color Cyan -NoNewline; Write-Host " GLM-5"
|
||||
Write-Host ""
|
||||
$hiveModelChoice = Read-Host " Enter model choice (1-3) [1]"
|
||||
if (-not $hiveModelChoice) { $hiveModelChoice = "1" }
|
||||
switch ($hiveModelChoice) {
|
||||
"2" { $SelectedModel = "kimi-2.5" }
|
||||
"3" { $SelectedModel = "GLM-5" }
|
||||
default { $SelectedModel = "queen" }
|
||||
}
|
||||
Write-Color -Text " Model: $SelectedModel | API: $HiveLlmEndpoint" -Color DarkGray
|
||||
}
|
||||
{ $_ -ge 6 -and $_ -le 11 } {
|
||||
# API key providers
|
||||
$provIdx = $num - 5
|
||||
$provIdx = $num - 6
|
||||
$SelectedEnvVar = $ProviderMenuEnvVars[$provIdx]
|
||||
$SelectedProviderId = $ProviderMenuIds[$provIdx]
|
||||
$providerName = $ProviderMenuNames[$provIdx] -replace ' - .*', '' # strip description
|
||||
$signupUrl = $ProviderMenuUrls[$provIdx]
|
||||
if ($SelectedProviderId -eq "openrouter") {
|
||||
$SelectedApiBase = "https://openrouter.ai/api/v1"
|
||||
} else {
|
||||
$SelectedApiBase = ""
|
||||
}
|
||||
|
||||
# Prompt for key (allow replacement if already set) with verification + retry
|
||||
while ($true) {
|
||||
@@ -1156,7 +1280,11 @@ switch ($num) {
|
||||
# Health check the new key
|
||||
Write-Host " Verifying API key... " -NoNewline
|
||||
try {
|
||||
$hcResult = & $UvCmd run python (Join-Path $ScriptDir "scripts/check_llm_key.py") $SelectedProviderId $apiKey 2>$null
|
||||
if ($SelectedApiBase) {
|
||||
$hcResult = & uv run python (Join-Path $ScriptDir "scripts/check_llm_key.py") $SelectedProviderId $apiKey $SelectedApiBase 2>$null
|
||||
} else {
|
||||
$hcResult = & uv run python (Join-Path $ScriptDir "scripts/check_llm_key.py") $SelectedProviderId $apiKey 2>$null
|
||||
}
|
||||
$hcJson = $hcResult | ConvertFrom-Json
|
||||
if ($hcJson.valid -eq $true) {
|
||||
Write-Color -Text "ok" -Color Green
|
||||
@@ -1194,7 +1322,7 @@ switch ($num) {
|
||||
}
|
||||
}
|
||||
}
|
||||
10 {
|
||||
{ $_ -eq $SkipChoice } {
|
||||
Write-Host ""
|
||||
Write-Warn "Skipped. An LLM API key is required to test and use worker agents."
|
||||
Write-Host " Add your API key later by running:"
|
||||
@@ -1335,6 +1463,70 @@ if ($SubscriptionMode -eq "kimi_code") {
|
||||
}
|
||||
}
|
||||
|
||||
# For Hive LLM: prompt for API key with verification + retry
|
||||
if ($SubscriptionMode -eq "hive_llm") {
|
||||
while ($true) {
|
||||
$existingHive = [System.Environment]::GetEnvironmentVariable("HIVE_API_KEY", "User")
|
||||
if (-not $existingHive) { $existingHive = $env:HIVE_API_KEY }
|
||||
|
||||
if ($existingHive) {
|
||||
$masked = $existingHive.Substring(0, [Math]::Min(4, $existingHive.Length)) + "..." + $existingHive.Substring([Math]::Max(0, $existingHive.Length - 4))
|
||||
Write-Host ""
|
||||
Write-Color -Text " $([char]0x2B22) Current Hive key: $masked" -Color Green
|
||||
Write-Host ""
|
||||
$apiKey = Read-Host "Paste a new Hive API key (or press Enter to keep current)"
|
||||
} else {
|
||||
Write-Host ""
|
||||
Write-Host " Get your API key from: " -NoNewline
|
||||
Write-Color -Text "https://discord.com/invite/hQdU7QDkgR" -Color Cyan
|
||||
Write-Host ""
|
||||
$apiKey = Read-Host "Paste your Hive API key (or press Enter to skip)"
|
||||
}
|
||||
|
||||
if ($apiKey) {
|
||||
[System.Environment]::SetEnvironmentVariable("HIVE_API_KEY", $apiKey, "User")
|
||||
$env:HIVE_API_KEY = $apiKey
|
||||
Write-Host ""
|
||||
Write-Ok "Hive API key saved as User environment variable"
|
||||
|
||||
# Health check the new key
|
||||
Write-Host " Verifying Hive API key... " -NoNewline
|
||||
try {
|
||||
$hcOutput = & $PythonCmd scripts/check_llm_key.py hive $apiKey "$HiveLlmEndpoint" 2>&1
|
||||
$hcJson = $hcOutput | ConvertFrom-Json
|
||||
if ($hcJson.valid -eq $true) {
|
||||
Write-Color -Text "ok" -Color Green
|
||||
break
|
||||
} elseif ($hcJson.valid -eq $false) {
|
||||
Write-Color -Text "failed" -Color Red
|
||||
Write-Warn $hcJson.message
|
||||
[System.Environment]::SetEnvironmentVariable("HIVE_API_KEY", $null, "User")
|
||||
Remove-Item -Path "Env:\HIVE_API_KEY" -ErrorAction SilentlyContinue
|
||||
Write-Host ""
|
||||
Read-Host " Press Enter to try again"
|
||||
} else {
|
||||
Write-Color -Text "--" -Color Yellow
|
||||
Write-Color -Text " Could not verify key (network issue). The key has been saved." -Color DarkGray
|
||||
break
|
||||
}
|
||||
} catch {
|
||||
Write-Color -Text "--" -Color Yellow
|
||||
break
|
||||
}
|
||||
} elseif (-not $existingHive) {
|
||||
Write-Host ""
|
||||
Write-Warn "Skipped. Add your Hive API key later:"
|
||||
Write-Color -Text " [System.Environment]::SetEnvironmentVariable('HIVE_API_KEY', 'your-key', 'User')" -Color Cyan
|
||||
$SelectedEnvVar = ""
|
||||
$SelectedProviderId = ""
|
||||
$SubscriptionMode = ""
|
||||
break
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Prompt for model if not already selected (manual provider path)
|
||||
if ($SelectedProviderId -and -not $SelectedModel) {
|
||||
$modelSel = Get-ModelSelection $SelectedProviderId
|
||||
@@ -1375,6 +1567,12 @@ if ($SelectedProviderId) {
|
||||
} elseif ($SubscriptionMode -eq "kimi_code") {
|
||||
$config.llm["api_base"] = "https://api.kimi.com/coding"
|
||||
$config.llm["api_key_env_var"] = $SelectedEnvVar
|
||||
} elseif ($SubscriptionMode -eq "hive_llm") {
|
||||
$config.llm["api_base"] = $HiveLlmEndpoint
|
||||
$config.llm["api_key_env_var"] = $SelectedEnvVar
|
||||
} elseif ($SelectedProviderId -eq "openrouter") {
|
||||
$config.llm["api_base"] = "https://openrouter.ai/api/v1"
|
||||
$config.llm["api_key_env_var"] = $SelectedEnvVar
|
||||
} else {
|
||||
$config.llm["api_key_env_var"] = $SelectedEnvVar
|
||||
}
|
||||
@@ -1674,6 +1872,9 @@ if ($SelectedProviderId) {
|
||||
Write-Color -Text " API: api.z.ai (OpenAI-compatible)" -Color DarkGray
|
||||
} elseif ($SubscriptionMode -eq "codex") {
|
||||
Write-Ok "OpenAI Codex Subscription -> $SelectedModel"
|
||||
} elseif ($SelectedProviderId -eq "openrouter") {
|
||||
Write-Ok "OpenRouter API Key -> $SelectedModel"
|
||||
Write-Color -Text " API: openrouter.ai/api/v1 (OpenAI-compatible)" -Color DarkGray
|
||||
} else {
|
||||
Write-Color -Text " $SelectedProviderId" -Color Cyan -NoNewline
|
||||
Write-Host " -> " -NoNewline
|
||||
@@ -1704,14 +1905,39 @@ if ($CodexAvailable) {
|
||||
Write-Host ""
|
||||
}
|
||||
|
||||
# Auto-launch dashboard or show manual instructions
|
||||
# Setup-only mode: show manual instructions
|
||||
if ($FrontendBuilt) {
|
||||
Write-Color -Text "Launching dashboard..." -Color White
|
||||
Write-Color -Text "â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•" -Color Yellow
|
||||
Write-Host ""
|
||||
Write-Color -Text " Starting server on http://localhost:8787" -Color DarkGray
|
||||
Write-Color -Text " Press Ctrl+C to stop" -Color DarkGray
|
||||
Write-Color -Text " IMPORTANT: Restart your terminal now!" -Color Yellow
|
||||
Write-Host ""
|
||||
Write-Color -Text "â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•â•" -Color Yellow
|
||||
Write-Host ""
|
||||
Write-Host 'Environment variables (uv, API keys) are now configured, but you need to'
|
||||
Write-Host 'restart your terminal for them to take effect in new sessions.'
|
||||
Write-Host ""
|
||||
|
||||
Write-Color -Text "Run an Agent:" -Color White
|
||||
Write-Host ""
|
||||
Write-Host " Quickstart only sets things up. Launch the dashboard when you're ready:"
|
||||
Write-Color -Text " hive open" -Color Cyan
|
||||
Write-Host ""
|
||||
|
||||
if ($SelectedProviderId -or $credKey) {
|
||||
Write-Color -Text "Note:" -Color White
|
||||
Write-Host "- uv has been added to your User PATH"
|
||||
if ($SelectedProviderId -and $SelectedEnvVar) {
|
||||
Write-Host "- $SelectedEnvVar is set for LLM access"
|
||||
}
|
||||
if ($credKey) {
|
||||
Write-Host "- HIVE_CREDENTIAL_KEY is set for credential encryption"
|
||||
}
|
||||
Write-Host "- All variables will persist across reboots"
|
||||
Write-Host ""
|
||||
}
|
||||
|
||||
Write-Color -Text 'Run .\quickstart.ps1 again to reconfigure.' -Color DarkGray
|
||||
Write-Host ""
|
||||
& (Join-Path $ScriptDir "hive.ps1") open
|
||||
} else {
|
||||
Write-Color -Text "═══════════════════════════════════════════════════════" -Color Yellow
|
||||
Write-Host ""
|
||||
@@ -1725,9 +1951,8 @@ if ($FrontendBuilt) {
|
||||
|
||||
Write-Color -Text "Run an Agent:" -Color White
|
||||
Write-Host ""
|
||||
Write-Host " Launch the interactive dashboard to browse and run agents:"
|
||||
Write-Host " You can start an example agent or an agent built by yourself:"
|
||||
Write-Color -Text " .\hive.ps1 tui" -Color Cyan
|
||||
Write-Host " Frontend build was skipped or failed. Once the dashboard is available, launch it with:"
|
||||
Write-Color -Text " hive open" -Color Cyan
|
||||
Write-Host ""
|
||||
|
||||
if ($SelectedProviderId -or $credKey) {
|
||||
|
||||
+329
-119
@@ -32,6 +32,9 @@ NC='\033[0m' # No Color
|
||||
# Get the directory where this script is located
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
|
||||
# Hive LLM router endpoint
|
||||
HIVE_LLM_ENDPOINT="https://api.adenhq.com"
|
||||
|
||||
# Helper function for prompts
|
||||
prompt_yes_no() {
|
||||
local prompt="$1"
|
||||
@@ -43,7 +46,6 @@ prompt_yes_no() {
|
||||
else
|
||||
prompt="$prompt [y/N] "
|
||||
fi
|
||||
|
||||
read -r -p "$prompt" response
|
||||
response="${response:-$default}"
|
||||
[[ "$response" =~ ^[Yy] ]]
|
||||
@@ -371,6 +373,7 @@ if [ "$USE_ASSOC_ARRAYS" = true ]; then
|
||||
["GOOGLE_API_KEY"]="Google AI"
|
||||
["GROQ_API_KEY"]="Groq"
|
||||
["CEREBRAS_API_KEY"]="Cerebras"
|
||||
["OPENROUTER_API_KEY"]="OpenRouter"
|
||||
["MISTRAL_API_KEY"]="Mistral"
|
||||
["TOGETHER_API_KEY"]="Together AI"
|
||||
["DEEPSEEK_API_KEY"]="DeepSeek"
|
||||
@@ -384,6 +387,7 @@ if [ "$USE_ASSOC_ARRAYS" = true ]; then
|
||||
["GOOGLE_API_KEY"]="google"
|
||||
["GROQ_API_KEY"]="groq"
|
||||
["CEREBRAS_API_KEY"]="cerebras"
|
||||
["OPENROUTER_API_KEY"]="openrouter"
|
||||
["MISTRAL_API_KEY"]="mistral"
|
||||
["TOGETHER_API_KEY"]="together"
|
||||
["DEEPSEEK_API_KEY"]="deepseek"
|
||||
@@ -507,9 +511,9 @@ if [ "$USE_ASSOC_ARRAYS" = true ]; then
|
||||
}
|
||||
else
|
||||
# Bash 3.2 - use parallel indexed arrays
|
||||
PROVIDER_ENV_VARS=(ANTHROPIC_API_KEY OPENAI_API_KEY MINIMAX_API_KEY GEMINI_API_KEY GOOGLE_API_KEY GROQ_API_KEY CEREBRAS_API_KEY MISTRAL_API_KEY TOGETHER_API_KEY DEEPSEEK_API_KEY)
|
||||
PROVIDER_DISPLAY_NAMES=("Anthropic (Claude)" "OpenAI (GPT)" "MiniMax" "Google Gemini" "Google AI" "Groq" "Cerebras" "Mistral" "Together AI" "DeepSeek")
|
||||
PROVIDER_ID_LIST=(anthropic openai minimax gemini google groq cerebras mistral together deepseek)
|
||||
PROVIDER_ENV_VARS=(ANTHROPIC_API_KEY OPENAI_API_KEY MINIMAX_API_KEY GEMINI_API_KEY GOOGLE_API_KEY GROQ_API_KEY CEREBRAS_API_KEY OPENROUTER_API_KEY MISTRAL_API_KEY TOGETHER_API_KEY DEEPSEEK_API_KEY)
|
||||
PROVIDER_DISPLAY_NAMES=("Anthropic (Claude)" "OpenAI (GPT)" "MiniMax" "Google Gemini" "Google AI" "Groq" "Cerebras" "OpenRouter" "Mistral" "Together AI" "DeepSeek")
|
||||
PROVIDER_ID_LIST=(anthropic openai minimax gemini google groq cerebras openrouter mistral together deepseek)
|
||||
|
||||
# Default models by provider id (parallel arrays)
|
||||
MODEL_PROVIDER_IDS=(anthropic openai minimax gemini groq cerebras mistral together_ai deepseek)
|
||||
@@ -687,10 +691,91 @@ detect_shell_rc() {
|
||||
SHELL_RC_FILE=$(detect_shell_rc)
|
||||
SHELL_NAME=$(basename "$SHELL")
|
||||
|
||||
# Normalize user-pasted OpenRouter model IDs:
|
||||
# - trim whitespace
|
||||
# - strip leading "openrouter/" if present
|
||||
normalize_openrouter_model_id() {
|
||||
local raw="$1"
|
||||
# Trim leading/trailing whitespace
|
||||
raw="${raw#"${raw%%[![:space:]]*}"}"
|
||||
raw="${raw%"${raw##*[![:space:]]}"}"
|
||||
if [[ "$raw" =~ ^[Oo][Pp][Ee][Nn][Rr][Oo][Uu][Tt][Ee][Rr]/(.+)$ ]]; then
|
||||
raw="${BASH_REMATCH[1]}"
|
||||
fi
|
||||
printf '%s' "$raw"
|
||||
}
|
||||
|
||||
# Prompt the user to choose a model for their selected provider.
|
||||
# Sets SELECTED_MODEL, SELECTED_MAX_TOKENS, and SELECTED_MAX_CONTEXT_TOKENS.
|
||||
prompt_model_selection() {
|
||||
local provider_id="$1"
|
||||
|
||||
if [ "$provider_id" = "openrouter" ]; then
|
||||
local default_model=""
|
||||
if [ -n "$PREV_MODEL" ] && [ "$provider_id" = "$PREV_PROVIDER" ]; then
|
||||
default_model="$(normalize_openrouter_model_id "$PREV_MODEL")"
|
||||
fi
|
||||
echo ""
|
||||
echo -e "${BOLD}Enter your OpenRouter model id:${NC}"
|
||||
echo -e " ${DIM}Paste from openrouter.ai (example: x-ai/grok-4.20-beta)${NC}"
|
||||
echo -e " ${DIM}If calls fail with guardrail/privacy errors: openrouter.ai/settings/privacy${NC}"
|
||||
echo ""
|
||||
local input_model=""
|
||||
while true; do
|
||||
if [ -n "$default_model" ]; then
|
||||
read -r -p "Model id [$default_model]: " input_model || true
|
||||
input_model="${input_model:-$default_model}"
|
||||
else
|
||||
read -r -p "Model id: " input_model || true
|
||||
fi
|
||||
local normalized_model
|
||||
normalized_model="$(normalize_openrouter_model_id "$input_model")"
|
||||
if [ -n "$normalized_model" ]; then
|
||||
local openrouter_key=""
|
||||
if [ -n "${SELECTED_ENV_VAR:-}" ]; then
|
||||
openrouter_key="${!SELECTED_ENV_VAR:-}"
|
||||
fi
|
||||
|
||||
if [ -n "$openrouter_key" ]; then
|
||||
local model_hc_result=""
|
||||
local model_hc_valid=""
|
||||
local model_hc_msg=""
|
||||
local model_hc_canonical=""
|
||||
local model_hc_base="${SELECTED_API_BASE:-https://openrouter.ai/api/v1}"
|
||||
echo -n " Verifying model id... "
|
||||
model_hc_result="$(uv run python "$SCRIPT_DIR/scripts/check_llm_key.py" "openrouter" "$openrouter_key" "$model_hc_base" "$normalized_model" 2>/dev/null)" || true
|
||||
model_hc_valid="$(echo "$model_hc_result" | $PYTHON_CMD -c "import json,sys; print(json.loads(sys.stdin.read()).get('valid',''))" 2>/dev/null)" || true
|
||||
model_hc_msg="$(echo "$model_hc_result" | $PYTHON_CMD -c "import json,sys; print(json.loads(sys.stdin.read()).get('message',''))" 2>/dev/null)" || true
|
||||
model_hc_canonical="$(echo "$model_hc_result" | $PYTHON_CMD -c "import json,sys; print(json.loads(sys.stdin.read()).get('model',''))" 2>/dev/null)" || true
|
||||
if [ "$model_hc_valid" = "True" ]; then
|
||||
if [ -n "$model_hc_canonical" ]; then
|
||||
normalized_model="$model_hc_canonical"
|
||||
fi
|
||||
echo -e "${GREEN}ok${NC}"
|
||||
elif [ "$model_hc_valid" = "False" ]; then
|
||||
echo -e "${RED}failed${NC}"
|
||||
echo -e " ${YELLOW}⚠ $model_hc_msg${NC}"
|
||||
echo ""
|
||||
continue
|
||||
else
|
||||
echo -e "${YELLOW}--${NC}"
|
||||
echo -e " ${DIM}Could not verify model id (network issue). Continuing with your selection.${NC}"
|
||||
fi
|
||||
else
|
||||
echo -e " ${DIM}Skipping model verification (OpenRouter key not available in current shell).${NC}"
|
||||
fi
|
||||
|
||||
SELECTED_MODEL="$normalized_model"
|
||||
SELECTED_MAX_TOKENS=8192
|
||||
SELECTED_MAX_CONTEXT_TOKENS=120000
|
||||
echo ""
|
||||
echo -e "${GREEN}⬢${NC} Model: ${DIM}$SELECTED_MODEL${NC}"
|
||||
return
|
||||
fi
|
||||
echo -e "${RED}Model id cannot be empty.${NC}"
|
||||
done
|
||||
fi
|
||||
|
||||
local count
|
||||
count="$(get_model_choice_count "$provider_id")"
|
||||
|
||||
@@ -784,34 +869,73 @@ save_configuration() {
|
||||
max_context_tokens=120000
|
||||
fi
|
||||
|
||||
mkdir -p "$HIVE_CONFIG_DIR"
|
||||
|
||||
$PYTHON_CMD -c "
|
||||
uv run python - \
|
||||
"$provider_id" \
|
||||
"$env_var" \
|
||||
"$model" \
|
||||
"$max_tokens" \
|
||||
"$max_context_tokens" \
|
||||
"$use_claude_code_sub" \
|
||||
"$api_base" \
|
||||
"$use_codex_sub" \
|
||||
"$(date -u +"%Y-%m-%dT%H:%M:%S+00:00")" 2>/dev/null <<'PY'
|
||||
import json
|
||||
config = {
|
||||
'llm': {
|
||||
'provider': '$provider_id',
|
||||
'model': '$model',
|
||||
'max_tokens': $max_tokens,
|
||||
'max_context_tokens': $max_context_tokens,
|
||||
'api_key_env_var': '$env_var'
|
||||
},
|
||||
'created_at': '$(date -u +"%Y-%m-%dT%H:%M:%S+00:00")'
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
(
|
||||
provider_id,
|
||||
env_var,
|
||||
model,
|
||||
max_tokens,
|
||||
max_context_tokens,
|
||||
use_claude_code_sub,
|
||||
api_base,
|
||||
use_codex_sub,
|
||||
created_at,
|
||||
) = sys.argv[1:10]
|
||||
|
||||
cfg_path = Path.home() / ".hive" / "configuration.json"
|
||||
cfg_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
with open(cfg_path, encoding="utf-8-sig") as f:
|
||||
config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
config = {}
|
||||
|
||||
config["llm"] = {
|
||||
"provider": provider_id,
|
||||
"model": model,
|
||||
"max_tokens": int(max_tokens),
|
||||
"max_context_tokens": int(max_context_tokens),
|
||||
"api_key_env_var": env_var,
|
||||
}
|
||||
if '$use_claude_code_sub' == 'true':
|
||||
config['llm']['use_claude_code_subscription'] = True
|
||||
# No api_key_env_var needed for Claude Code subscription
|
||||
config['llm'].pop('api_key_env_var', None)
|
||||
if '$use_codex_sub' == 'true':
|
||||
config['llm']['use_codex_subscription'] = True
|
||||
# No api_key_env_var needed for Codex subscription
|
||||
config['llm'].pop('api_key_env_var', None)
|
||||
if '$api_base':
|
||||
config['llm']['api_base'] = '$api_base'
|
||||
with open('$HIVE_CONFIG_FILE', 'w') as f:
|
||||
config["created_at"] = created_at
|
||||
|
||||
if use_claude_code_sub == "true":
|
||||
config["llm"]["use_claude_code_subscription"] = True
|
||||
config["llm"].pop("api_key_env_var", None)
|
||||
else:
|
||||
config["llm"].pop("use_claude_code_subscription", None)
|
||||
|
||||
if use_codex_sub == "true":
|
||||
config["llm"]["use_codex_subscription"] = True
|
||||
config["llm"].pop("api_key_env_var", None)
|
||||
else:
|
||||
config["llm"].pop("use_codex_subscription", None)
|
||||
|
||||
if api_base:
|
||||
config["llm"]["api_base"] = api_base
|
||||
else:
|
||||
config["llm"].pop("api_base", None)
|
||||
|
||||
tmp_path = cfg_path.with_name(cfg_path.name + ".tmp")
|
||||
with open(tmp_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
tmp_path.replace(cfg_path)
|
||||
print(json.dumps(config, indent=2))
|
||||
" 2>/dev/null
|
||||
PY
|
||||
}
|
||||
|
||||
# Source shell rc file to pick up existing env vars (temporarily disable set -e)
|
||||
@@ -864,6 +988,11 @@ elif [ -n "${KIMI_API_KEY:-}" ]; then
|
||||
KIMI_CRED_DETECTED=true
|
||||
fi
|
||||
|
||||
HIVE_CRED_DETECTED=false
|
||||
if [ -n "${HIVE_API_KEY:-}" ]; then
|
||||
HIVE_CRED_DETECTED=true
|
||||
fi
|
||||
|
||||
# Detect API key providers
|
||||
if [ "$USE_ASSOC_ARRAYS" = true ]; then
|
||||
for env_var in "${!PROVIDER_NAMES[@]}"; do
|
||||
@@ -887,25 +1016,36 @@ PREV_MODEL=""
|
||||
PREV_ENV_VAR=""
|
||||
PREV_SUB_MODE=""
|
||||
if [ -f "$HIVE_CONFIG_FILE" ]; then
|
||||
eval "$($PYTHON_CMD -c "
|
||||
import json, sys
|
||||
eval "$(uv run python - 2>/dev/null <<'PY'
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
cfg_path = Path.home() / ".hive" / "configuration.json"
|
||||
try:
|
||||
with open('$HIVE_CONFIG_FILE') as f:
|
||||
with open(cfg_path, encoding="utf-8-sig") as f:
|
||||
c = json.load(f)
|
||||
llm = c.get('llm', {})
|
||||
print(f'PREV_PROVIDER={llm.get(\"provider\", \"\")}')
|
||||
print(f'PREV_MODEL={llm.get(\"model\", \"\")}')
|
||||
print(f'PREV_ENV_VAR={llm.get(\"api_key_env_var\", \"\")}')
|
||||
sub = ''
|
||||
if llm.get('use_claude_code_subscription'): sub = 'claude_code'
|
||||
elif llm.get('use_codex_subscription'): sub = 'codex'
|
||||
elif llm.get('use_kimi_code_subscription'): sub = 'kimi_code'
|
||||
elif llm.get('provider', '') == 'minimax' or 'api.minimax.io' in llm.get('api_base', ''): sub = 'minimax_code'
|
||||
elif 'api.z.ai' in llm.get('api_base', ''): sub = 'zai_code'
|
||||
print(f'PREV_SUB_MODE={sub}')
|
||||
llm = c.get("llm", {})
|
||||
print(f"PREV_PROVIDER={llm.get(\"provider\", \"\")}")
|
||||
print(f"PREV_MODEL={llm.get(\"model\", \"\")}")
|
||||
print(f"PREV_ENV_VAR={llm.get(\"api_key_env_var\", \"\")}")
|
||||
sub = ""
|
||||
if llm.get("use_claude_code_subscription"):
|
||||
sub = "claude_code"
|
||||
elif llm.get("use_codex_subscription"):
|
||||
sub = "codex"
|
||||
elif llm.get("use_kimi_code_subscription"):
|
||||
sub = "kimi_code"
|
||||
elif llm.get("provider", "") == "minimax" or "api.minimax.io" in llm.get("api_base", ""):
|
||||
sub = "minimax_code"
|
||||
elif llm.get("provider", "") == "hive" or "adenhq.com" in llm.get("api_base", ""):
|
||||
sub = "hive_llm"
|
||||
elif "api.z.ai" in llm.get("api_base", ""):
|
||||
sub = "zai_code"
|
||||
print(f"PREV_SUB_MODE={sub}")
|
||||
except Exception:
|
||||
pass
|
||||
" 2>/dev/null)" || true
|
||||
PY
|
||||
)" || true
|
||||
fi
|
||||
|
||||
# Compute default menu number from previous config (only if credential is still valid)
|
||||
@@ -917,6 +1057,7 @@ if [ -n "$PREV_SUB_MODE" ] || [ -n "$PREV_PROVIDER" ]; then
|
||||
zai_code) [ "$ZAI_CRED_DETECTED" = true ] && PREV_CRED_VALID=true ;;
|
||||
codex) [ "$CODEX_CRED_DETECTED" = true ] && PREV_CRED_VALID=true ;;
|
||||
kimi_code) [ "$KIMI_CRED_DETECTED" = true ] && PREV_CRED_VALID=true ;;
|
||||
hive_llm) [ "$HIVE_CRED_DETECTED" = true ] && PREV_CRED_VALID=true ;;
|
||||
*)
|
||||
# API key provider — check if the env var is set
|
||||
if [ -n "$PREV_ENV_VAR" ] && [ -n "${!PREV_ENV_VAR}" ]; then
|
||||
@@ -932,16 +1073,19 @@ if [ -n "$PREV_SUB_MODE" ] || [ -n "$PREV_PROVIDER" ]; then
|
||||
codex) DEFAULT_CHOICE=3 ;;
|
||||
minimax_code) DEFAULT_CHOICE=4 ;;
|
||||
kimi_code) DEFAULT_CHOICE=5 ;;
|
||||
hive_llm) DEFAULT_CHOICE=6 ;;
|
||||
esac
|
||||
if [ -z "$DEFAULT_CHOICE" ]; then
|
||||
case "$PREV_PROVIDER" in
|
||||
anthropic) DEFAULT_CHOICE=6 ;;
|
||||
openai) DEFAULT_CHOICE=7 ;;
|
||||
gemini) DEFAULT_CHOICE=8 ;;
|
||||
groq) DEFAULT_CHOICE=9 ;;
|
||||
cerebras) DEFAULT_CHOICE=10 ;;
|
||||
anthropic) DEFAULT_CHOICE=7 ;;
|
||||
openai) DEFAULT_CHOICE=8 ;;
|
||||
gemini) DEFAULT_CHOICE=9 ;;
|
||||
groq) DEFAULT_CHOICE=10 ;;
|
||||
cerebras) DEFAULT_CHOICE=11 ;;
|
||||
openrouter) DEFAULT_CHOICE=12 ;;
|
||||
minimax) DEFAULT_CHOICE=4 ;;
|
||||
kimi) DEFAULT_CHOICE=5 ;;
|
||||
hive) DEFAULT_CHOICE=6 ;;
|
||||
esac
|
||||
fi
|
||||
fi
|
||||
@@ -987,14 +1131,21 @@ else
|
||||
echo -e " ${CYAN}5)${NC} Kimi Code Subscription ${DIM}(use your Kimi Code plan)${NC}"
|
||||
fi
|
||||
|
||||
# 6) Hive LLM
|
||||
if [ "$HIVE_CRED_DETECTED" = true ]; then
|
||||
echo -e " ${CYAN}6)${NC} Hive LLM ${DIM}(use your Hive API key)${NC} ${GREEN}(credential detected)${NC}"
|
||||
else
|
||||
echo -e " ${CYAN}6)${NC} Hive LLM ${DIM}(use your Hive API key)${NC}"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo -e " ${CYAN}${BOLD}API key providers:${NC}"
|
||||
|
||||
# 6-10) API key providers — show (credential detected) if key already set
|
||||
PROVIDER_MENU_ENVS=(ANTHROPIC_API_KEY OPENAI_API_KEY GEMINI_API_KEY GROQ_API_KEY CEREBRAS_API_KEY)
|
||||
PROVIDER_MENU_NAMES=("Anthropic (Claude) - Recommended" "OpenAI (GPT)" "Google Gemini - Free tier available" "Groq - Fast, free tier" "Cerebras - Fast, free tier")
|
||||
for idx in 0 1 2 3 4; do
|
||||
num=$((idx + 6))
|
||||
# 7-12) API key providers — show (credential detected) if key already set
|
||||
PROVIDER_MENU_ENVS=(ANTHROPIC_API_KEY OPENAI_API_KEY GEMINI_API_KEY GROQ_API_KEY CEREBRAS_API_KEY OPENROUTER_API_KEY)
|
||||
PROVIDER_MENU_NAMES=("Anthropic (Claude) - Recommended" "OpenAI (GPT)" "Google Gemini - Free tier available" "Groq - Fast, free tier" "Cerebras - Fast, free tier" "OpenRouter - Bring any OpenRouter model")
|
||||
for idx in "${!PROVIDER_MENU_ENVS[@]}"; do
|
||||
num=$((idx + 7))
|
||||
env_var="${PROVIDER_MENU_ENVS[$idx]}"
|
||||
if [ -n "${!env_var}" ]; then
|
||||
echo -e " ${CYAN}$num)${NC} ${PROVIDER_MENU_NAMES[$idx]} ${GREEN}(credential detected)${NC}"
|
||||
@@ -1003,7 +1154,8 @@ for idx in 0 1 2 3 4; do
|
||||
fi
|
||||
done
|
||||
|
||||
echo -e " ${CYAN}11)${NC} Skip for now"
|
||||
SKIP_CHOICE=$((7 + ${#PROVIDER_MENU_ENVS[@]}))
|
||||
echo -e " ${CYAN}$SKIP_CHOICE)${NC} Skip for now"
|
||||
echo ""
|
||||
|
||||
if [ -n "$DEFAULT_CHOICE" ]; then
|
||||
@@ -1013,15 +1165,15 @@ fi
|
||||
|
||||
while true; do
|
||||
if [ -n "$DEFAULT_CHOICE" ]; then
|
||||
read -r -p "Enter choice (1-11) [$DEFAULT_CHOICE]: " choice || true
|
||||
read -r -p "Enter choice (1-$SKIP_CHOICE) [$DEFAULT_CHOICE]: " choice || true
|
||||
choice="${choice:-$DEFAULT_CHOICE}"
|
||||
else
|
||||
read -r -p "Enter choice (1-11): " choice || true
|
||||
read -r -p "Enter choice (1-$SKIP_CHOICE): " choice || true
|
||||
fi
|
||||
if [[ "$choice" =~ ^[0-9]+$ ]] && [ "$choice" -ge 1 ] && [ "$choice" -le 11 ]; then
|
||||
if [[ "$choice" =~ ^[0-9]+$ ]] && [ "$choice" -ge 1 ] && [ "$choice" -le "$SKIP_CHOICE" ]; then
|
||||
break
|
||||
fi
|
||||
echo -e "${RED}Invalid choice. Please enter 1-11${NC}"
|
||||
echo -e "${RED}Invalid choice. Please enter 1-$SKIP_CHOICE${NC}"
|
||||
done
|
||||
|
||||
case $choice in
|
||||
@@ -1039,7 +1191,7 @@ case $choice in
|
||||
SELECTED_PROVIDER_ID="anthropic"
|
||||
SELECTED_MODEL="claude-opus-4-6"
|
||||
SELECTED_MAX_TOKENS=32768
|
||||
SELECTED_MAX_CONTEXT_TOKENS=180000 # Claude — 200k context window
|
||||
SELECTED_MAX_CONTEXT_TOKENS=960000 # Claude — 1M context window
|
||||
echo ""
|
||||
echo -e "${GREEN}⬢${NC} Using Claude Code subscription"
|
||||
fi
|
||||
@@ -1051,7 +1203,7 @@ case $choice in
|
||||
SELECTED_ENV_VAR="ZAI_API_KEY"
|
||||
SELECTED_MODEL="glm-5"
|
||||
SELECTED_MAX_TOKENS=32768
|
||||
SELECTED_MAX_CONTEXT_TOKENS=120000 # GLM-5 — 128k context window
|
||||
SELECTED_MAX_CONTEXT_TOKENS=180000 # GLM-5 — 200k context window
|
||||
PROVIDER_NAME="ZAI"
|
||||
echo ""
|
||||
echo -e "${GREEN}⬢${NC} Using ZAI Code subscription"
|
||||
@@ -1109,7 +1261,7 @@ case $choice in
|
||||
SELECTED_ENV_VAR="KIMI_API_KEY"
|
||||
SELECTED_MODEL="kimi-k2.5"
|
||||
SELECTED_MAX_TOKENS=32768
|
||||
SELECTED_MAX_CONTEXT_TOKENS=120000 # Kimi K2.5 — 128k context window
|
||||
SELECTED_MAX_CONTEXT_TOKENS=240000 # Kimi K2.5 — 256k context window
|
||||
SELECTED_API_BASE="https://api.kimi.com/coding"
|
||||
PROVIDER_NAME="Kimi"
|
||||
SIGNUP_URL="https://www.kimi.com/code"
|
||||
@@ -1118,36 +1270,70 @@ case $choice in
|
||||
echo -e " ${DIM}Model: kimi-k2.5 | API: api.kimi.com/coding${NC}"
|
||||
;;
|
||||
6)
|
||||
# Hive LLM
|
||||
SUBSCRIPTION_MODE="hive_llm"
|
||||
SELECTED_PROVIDER_ID="hive"
|
||||
SELECTED_ENV_VAR="HIVE_API_KEY"
|
||||
SELECTED_MAX_TOKENS=32768
|
||||
SELECTED_MAX_CONTEXT_TOKENS=180000
|
||||
SELECTED_API_BASE="$HIVE_LLM_ENDPOINT"
|
||||
PROVIDER_NAME="Hive"
|
||||
SIGNUP_URL="https://discord.com/invite/hQdU7QDkgR"
|
||||
echo ""
|
||||
echo -e "${GREEN}⬢${NC} Using Hive LLM"
|
||||
echo ""
|
||||
echo -e " Select a model:"
|
||||
echo -e " ${CYAN}1)${NC} queen ${DIM}(default — Hive flagship)${NC}"
|
||||
echo -e " ${CYAN}2)${NC} kimi-2.5"
|
||||
echo -e " ${CYAN}3)${NC} GLM-5"
|
||||
echo ""
|
||||
read -r -p " Enter model choice (1-3) [1]: " hive_model_choice || true
|
||||
hive_model_choice="${hive_model_choice:-1}"
|
||||
case "$hive_model_choice" in
|
||||
2) SELECTED_MODEL="kimi-2.5" ;;
|
||||
3) SELECTED_MODEL="GLM-5" ;;
|
||||
*) SELECTED_MODEL="queen" ;;
|
||||
esac
|
||||
echo -e " ${DIM}Model: $SELECTED_MODEL | API: ${HIVE_LLM_ENDPOINT}${NC}"
|
||||
;;
|
||||
7)
|
||||
SELECTED_ENV_VAR="ANTHROPIC_API_KEY"
|
||||
SELECTED_PROVIDER_ID="anthropic"
|
||||
PROVIDER_NAME="Anthropic"
|
||||
SIGNUP_URL="https://console.anthropic.com/settings/keys"
|
||||
;;
|
||||
7)
|
||||
8)
|
||||
SELECTED_ENV_VAR="OPENAI_API_KEY"
|
||||
SELECTED_PROVIDER_ID="openai"
|
||||
PROVIDER_NAME="OpenAI"
|
||||
SIGNUP_URL="https://platform.openai.com/api-keys"
|
||||
;;
|
||||
8)
|
||||
9)
|
||||
SELECTED_ENV_VAR="GEMINI_API_KEY"
|
||||
SELECTED_PROVIDER_ID="gemini"
|
||||
PROVIDER_NAME="Google Gemini"
|
||||
SIGNUP_URL="https://aistudio.google.com/apikey"
|
||||
;;
|
||||
9)
|
||||
10)
|
||||
SELECTED_ENV_VAR="GROQ_API_KEY"
|
||||
SELECTED_PROVIDER_ID="groq"
|
||||
PROVIDER_NAME="Groq"
|
||||
SIGNUP_URL="https://console.groq.com/keys"
|
||||
;;
|
||||
10)
|
||||
11)
|
||||
SELECTED_ENV_VAR="CEREBRAS_API_KEY"
|
||||
SELECTED_PROVIDER_ID="cerebras"
|
||||
PROVIDER_NAME="Cerebras"
|
||||
SIGNUP_URL="https://cloud.cerebras.ai/"
|
||||
;;
|
||||
11)
|
||||
12)
|
||||
SELECTED_ENV_VAR="OPENROUTER_API_KEY"
|
||||
SELECTED_PROVIDER_ID="openrouter"
|
||||
SELECTED_API_BASE="https://openrouter.ai/api/v1"
|
||||
PROVIDER_NAME="OpenRouter"
|
||||
SIGNUP_URL="https://openrouter.ai/keys"
|
||||
;;
|
||||
"$SKIP_CHOICE")
|
||||
echo ""
|
||||
echo -e "${YELLOW}Skipped.${NC} An LLM API key is required to test and use worker agents."
|
||||
echo -e "Add your API key later by running:"
|
||||
@@ -1160,7 +1346,7 @@ case $choice in
|
||||
esac
|
||||
|
||||
# For API-key providers: prompt for key (allow replacement if already set)
|
||||
if { [ -z "$SUBSCRIPTION_MODE" ] || [ "$SUBSCRIPTION_MODE" = "minimax_code" ] || [ "$SUBSCRIPTION_MODE" = "kimi_code" ]; } && [ -n "$SELECTED_ENV_VAR" ]; then
|
||||
if { [ -z "$SUBSCRIPTION_MODE" ] || [ "$SUBSCRIPTION_MODE" = "minimax_code" ] || [ "$SUBSCRIPTION_MODE" = "kimi_code" ] || [ "$SUBSCRIPTION_MODE" = "hive_llm" ]; } && [ -n "$SELECTED_ENV_VAR" ]; then
|
||||
while true; do
|
||||
CURRENT_KEY="${!SELECTED_ENV_VAR}"
|
||||
if [ -n "$CURRENT_KEY" ]; then
|
||||
@@ -1188,7 +1374,7 @@ if { [ -z "$SUBSCRIPTION_MODE" ] || [ "$SUBSCRIPTION_MODE" = "minimax_code" ] ||
|
||||
echo -e "${GREEN}⬢${NC} API key saved to $SHELL_RC_FILE"
|
||||
# Health check the new key
|
||||
echo -n " Verifying API key... "
|
||||
if { [ "$SUBSCRIPTION_MODE" = "minimax_code" ] || [ "$SUBSCRIPTION_MODE" = "kimi_code" ]; } && [ -n "${SELECTED_API_BASE:-}" ]; then
|
||||
if [ -n "${SELECTED_API_BASE:-}" ]; then
|
||||
HC_RESULT=$(uv run python "$SCRIPT_DIR/scripts/check_llm_key.py" "$SELECTED_PROVIDER_ID" "$API_KEY" "$SELECTED_API_BASE" 2>/dev/null) || true
|
||||
else
|
||||
HC_RESULT=$(uv run python "$SCRIPT_DIR/scripts/check_llm_key.py" "$SELECTED_PROVIDER_ID" "$API_KEY" 2>/dev/null) || true
|
||||
@@ -1300,18 +1486,28 @@ fi
|
||||
if [ -n "$SELECTED_PROVIDER_ID" ]; then
|
||||
echo ""
|
||||
echo -n " Saving configuration... "
|
||||
SAVE_OK=true
|
||||
if [ "$SUBSCRIPTION_MODE" = "claude_code" ]; then
|
||||
save_configuration "$SELECTED_PROVIDER_ID" "" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "true" "" > /dev/null
|
||||
save_configuration "$SELECTED_PROVIDER_ID" "" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "true" "" > /dev/null || SAVE_OK=false
|
||||
elif [ "$SUBSCRIPTION_MODE" = "codex" ]; then
|
||||
save_configuration "$SELECTED_PROVIDER_ID" "" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "" "true" > /dev/null
|
||||
save_configuration "$SELECTED_PROVIDER_ID" "" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "" "true" > /dev/null || SAVE_OK=false
|
||||
elif [ "$SUBSCRIPTION_MODE" = "zai_code" ]; then
|
||||
save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "https://api.z.ai/api/coding/paas/v4" > /dev/null
|
||||
save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "https://api.z.ai/api/coding/paas/v4" > /dev/null || SAVE_OK=false
|
||||
elif [ "$SUBSCRIPTION_MODE" = "minimax_code" ]; then
|
||||
save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "$SELECTED_API_BASE" > /dev/null
|
||||
save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "$SELECTED_API_BASE" > /dev/null || SAVE_OK=false
|
||||
elif [ "$SUBSCRIPTION_MODE" = "kimi_code" ]; then
|
||||
save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "$SELECTED_API_BASE" > /dev/null
|
||||
save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "$SELECTED_API_BASE" > /dev/null || SAVE_OK=false
|
||||
elif [ "$SUBSCRIPTION_MODE" = "hive_llm" ]; then
|
||||
save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "$SELECTED_API_BASE" > /dev/null || SAVE_OK=false
|
||||
elif [ "$SELECTED_PROVIDER_ID" = "openrouter" ]; then
|
||||
save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "$SELECTED_API_BASE" > /dev/null || SAVE_OK=false
|
||||
else
|
||||
save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" > /dev/null
|
||||
save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" > /dev/null || SAVE_OK=false
|
||||
fi
|
||||
if [ "$SAVE_OK" = false ]; then
|
||||
echo -e "${RED}failed${NC}"
|
||||
echo -e "${YELLOW} Could not write ~/.hive/configuration.json. Please rerun quickstart.${NC}"
|
||||
exit 1
|
||||
fi
|
||||
echo -e "${GREEN}⬢${NC}"
|
||||
echo -e " ${DIM}~/.hive/configuration.json${NC}"
|
||||
@@ -1327,22 +1523,44 @@ echo -e "${GREEN}⬢${NC} Browser automation enabled"
|
||||
|
||||
# Patch gcu_enabled into configuration.json
|
||||
if [ -f "$HIVE_CONFIG_FILE" ]; then
|
||||
uv run python -c "
|
||||
if ! uv run python - <<'PY'
|
||||
import json
|
||||
with open('$HIVE_CONFIG_FILE') as f:
|
||||
from pathlib import Path
|
||||
|
||||
cfg_path = Path.home() / ".hive" / "configuration.json"
|
||||
with open(cfg_path, encoding="utf-8-sig") as f:
|
||||
config = json.load(f)
|
||||
config['gcu_enabled'] = True
|
||||
with open('$HIVE_CONFIG_FILE', 'w') as f:
|
||||
config["gcu_enabled"] = True
|
||||
tmp_path = cfg_path.with_name(cfg_path.name + ".tmp")
|
||||
with open(tmp_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
"
|
||||
tmp_path.replace(cfg_path)
|
||||
PY
|
||||
then
|
||||
echo -e "${RED}failed${NC}"
|
||||
echo -e "${YELLOW} Could not update ~/.hive/configuration.json with browser automation settings.${NC}"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
mkdir -p "$HIVE_CONFIG_DIR"
|
||||
uv run python -c "
|
||||
if ! uv run python - "$(date -u +"%Y-%m-%dT%H:%M:%S+00:00")" <<'PY'
|
||||
import json
|
||||
config = {'gcu_enabled': True, 'created_at': '$(date -u +"%Y-%m-%dT%H:%M:%S+00:00")'}
|
||||
with open('$HIVE_CONFIG_FILE', 'w') as f:
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
cfg_path = Path.home() / ".hive" / "configuration.json"
|
||||
cfg_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
config = {
|
||||
"gcu_enabled": True,
|
||||
"created_at": sys.argv[1],
|
||||
}
|
||||
with open(cfg_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
"
|
||||
PY
|
||||
then
|
||||
echo -e "${RED}failed${NC}"
|
||||
echo -e "${YELLOW} Could not create ~/.hive/configuration.json for browser automation settings.${NC}"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo ""
|
||||
@@ -1543,6 +1761,9 @@ if [ -n "$SELECTED_PROVIDER_ID" ]; then
|
||||
elif [ "$SUBSCRIPTION_MODE" = "minimax_code" ]; then
|
||||
echo -e " ${GREEN}⬢${NC} MiniMax Coding Key → ${DIM}$SELECTED_MODEL${NC}"
|
||||
echo -e " ${DIM}API: api.minimax.io/v1 (OpenAI-compatible)${NC}"
|
||||
elif [ "$SELECTED_PROVIDER_ID" = "openrouter" ]; then
|
||||
echo -e " ${GREEN}⬢${NC} OpenRouter API Key → ${DIM}$SELECTED_MODEL${NC}"
|
||||
echo -e " ${DIM}API: openrouter.ai/api/v1 (OpenAI-compatible)${NC}"
|
||||
else
|
||||
echo -e " ${CYAN}$SELECTED_PROVIDER_ID${NC} → ${DIM}$SELECTED_MODEL${NC}"
|
||||
fi
|
||||
@@ -1587,40 +1808,29 @@ if [ "$CODEX_AVAILABLE" = true ]; then
|
||||
echo ""
|
||||
fi
|
||||
|
||||
# Auto-launch dashboard if frontend was built
|
||||
if [ "$FRONTEND_BUILT" = true ]; then
|
||||
echo -e "${BOLD}Launching dashboard...${NC}"
|
||||
echo ""
|
||||
echo -e " ${DIM}Starting server on http://localhost:8787${NC}"
|
||||
echo -e " ${DIM}Press Ctrl+C to stop${NC}"
|
||||
echo ""
|
||||
echo -e " ${DIM}Tip: You can restart the dashboard anytime with:${NC} ${CYAN}hive open${NC}"
|
||||
echo ""
|
||||
# exec replaces the quickstart process with hive open
|
||||
exec "$SCRIPT_DIR/hive" open
|
||||
else
|
||||
# No frontend — show manual instructions
|
||||
echo -e "${YELLOW}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
||||
echo -e "${BOLD}⚠️ IMPORTANT: Load your new configuration${NC}"
|
||||
echo -e "${YELLOW}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
||||
echo ""
|
||||
echo -e " Your API keys have been saved to ${CYAN}$SHELL_RC_FILE${NC}"
|
||||
echo -e " To use them, either:"
|
||||
echo ""
|
||||
echo -e " ${GREEN}Option 1:${NC} Source your shell config now:"
|
||||
echo -e " ${CYAN}source $SHELL_RC_FILE${NC}"
|
||||
echo ""
|
||||
echo -e " ${GREEN}Option 2:${NC} Open a new terminal window"
|
||||
echo ""
|
||||
echo -e "${YELLOW}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
||||
echo ""
|
||||
echo -e "${YELLOW}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
||||
echo -e "${BOLD}IMPORTANT: Load your new configuration${NC}"
|
||||
echo -e "${YELLOW}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
||||
echo ""
|
||||
echo -e " Your API keys have been saved to ${CYAN}$SHELL_RC_FILE${NC}"
|
||||
echo -e " To use them, either:"
|
||||
echo ""
|
||||
echo -e " ${GREEN}Option 1:${NC} Source your shell config now:"
|
||||
echo -e " ${CYAN}source $SHELL_RC_FILE${NC}"
|
||||
echo ""
|
||||
echo -e " ${GREEN}Option 2:${NC} Open a new terminal window"
|
||||
echo ""
|
||||
echo -e "${YELLOW}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
||||
echo ""
|
||||
|
||||
echo -e "${BOLD}Run an Agent:${NC}"
|
||||
echo ""
|
||||
echo -e " Launch the interactive dashboard to browse and run agents:"
|
||||
echo -e " You can start an example agent or an agent built by yourself:"
|
||||
echo -e " ${CYAN}hive open${NC}"
|
||||
echo ""
|
||||
echo -e "${DIM}Run ./quickstart.sh again to reconfigure.${NC}"
|
||||
echo ""
|
||||
echo -e "${BOLD}Run an Agent:${NC}"
|
||||
echo ""
|
||||
if [ "$FRONTEND_BUILT" = true ]; then
|
||||
echo -e " Quickstart only sets things up. Launch the dashboard when you're ready:"
|
||||
else
|
||||
echo -e " Frontend build was skipped or failed. Once the dashboard is available, launch it with:"
|
||||
fi
|
||||
echo -e " ${CYAN}hive open${NC}"
|
||||
echo ""
|
||||
echo -e "${DIM}Run ./quickstart.sh again to reconfigure.${NC}"
|
||||
echo ""
|
||||
|
||||
+208
-3
@@ -1,7 +1,7 @@
|
||||
"""Validate an LLM API key without consuming tokens.
|
||||
|
||||
Usage:
|
||||
python scripts/check_llm_key.py <provider_id> <api_key> [api_base]
|
||||
python scripts/check_llm_key.py <provider_id> <api_key> [api_base] [model]
|
||||
|
||||
Exit codes:
|
||||
0 = valid key
|
||||
@@ -12,11 +12,125 @@ Output: single JSON line {"valid": bool, "message": str}
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import unicodedata
|
||||
from difflib import get_close_matches
|
||||
|
||||
import httpx
|
||||
|
||||
from framework.config import HIVE_LLM_ENDPOINT
|
||||
|
||||
TIMEOUT = 10.0
|
||||
OPENROUTER_SEPARATOR_TRANSLATION = str.maketrans(
|
||||
{
|
||||
"\u2010": "-",
|
||||
"\u2011": "-",
|
||||
"\u2012": "-",
|
||||
"\u2013": "-",
|
||||
"\u2014": "-",
|
||||
"\u2015": "-",
|
||||
"\u2212": "-",
|
||||
"\u2044": "/",
|
||||
"\u2215": "/",
|
||||
"\u29F8": "/",
|
||||
"\uFF0F": "/",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _extract_error_message(response: httpx.Response) -> str:
|
||||
"""Best-effort extraction of a provider error message."""
|
||||
try:
|
||||
payload = response.json()
|
||||
except Exception:
|
||||
text = (response.text or "").strip()
|
||||
return text[:240] if text else ""
|
||||
|
||||
if isinstance(payload, dict):
|
||||
error_value = payload.get("error")
|
||||
if isinstance(error_value, dict):
|
||||
message = error_value.get("message")
|
||||
if isinstance(message, str) and message.strip():
|
||||
return message.strip()
|
||||
if isinstance(error_value, str) and error_value.strip():
|
||||
return error_value.strip()
|
||||
message = payload.get("message")
|
||||
if isinstance(message, str) and message.strip():
|
||||
return message.strip()
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def _sanitize_openrouter_model_id(value: str) -> str:
|
||||
"""Sanitize pasted OpenRouter model IDs into a comparable slug."""
|
||||
normalized = unicodedata.normalize("NFKC", value or "")
|
||||
normalized = "".join(
|
||||
ch
|
||||
for ch in normalized
|
||||
if unicodedata.category(ch) not in {"Cc", "Cf"}
|
||||
)
|
||||
normalized = normalized.translate(OPENROUTER_SEPARATOR_TRANSLATION)
|
||||
normalized = re.sub(r"\s+", "", normalized)
|
||||
if normalized.casefold().startswith("openrouter/"):
|
||||
normalized = normalized.split("/", 1)[1]
|
||||
return normalized
|
||||
|
||||
|
||||
def _normalize_openrouter_model_id(value: str) -> str:
|
||||
"""Normalize OpenRouter model IDs for exact/alias matching."""
|
||||
return _sanitize_openrouter_model_id(value).casefold()
|
||||
|
||||
|
||||
def _extract_openrouter_model_lookup(payload: object) -> dict[str, str]:
|
||||
"""Map normalized model IDs/aliases to a preferred canonical display slug."""
|
||||
if not isinstance(payload, dict):
|
||||
return {}
|
||||
|
||||
data = payload.get("data")
|
||||
if not isinstance(data, list):
|
||||
return {}
|
||||
|
||||
lookup: dict[str, str] = {}
|
||||
for item in data:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
model_id = item.get("id")
|
||||
canonical_slug = item.get("canonical_slug")
|
||||
candidates = [
|
||||
_sanitize_openrouter_model_id(value)
|
||||
for value in (model_id, canonical_slug)
|
||||
if isinstance(value, str) and _sanitize_openrouter_model_id(value)
|
||||
]
|
||||
if not candidates:
|
||||
continue
|
||||
|
||||
preferred_slug = candidates[-1]
|
||||
for candidate in candidates:
|
||||
lookup[_normalize_openrouter_model_id(candidate)] = preferred_slug
|
||||
|
||||
return lookup
|
||||
|
||||
|
||||
def _format_openrouter_model_unavailable_message(
|
||||
model: str, available_model_lookup: dict[str, str]
|
||||
) -> str:
|
||||
"""Return a helpful not-found message with close-match suggestions."""
|
||||
suggestions = [
|
||||
available_model_lookup[key]
|
||||
for key in get_close_matches(
|
||||
_normalize_openrouter_model_id(model),
|
||||
list(available_model_lookup),
|
||||
n=1,
|
||||
cutoff=0.6,
|
||||
)
|
||||
]
|
||||
|
||||
base = f"OpenRouter model is not available for this key/settings: {model}"
|
||||
if suggestions:
|
||||
return f"{base}. Closest matches: {', '.join(suggestions)}"
|
||||
return base
|
||||
|
||||
|
||||
def check_anthropic(api_key: str, **_: str) -> dict:
|
||||
@@ -56,6 +170,79 @@ def check_openai_compatible(api_key: str, endpoint: str, name: str) -> dict:
|
||||
return {"valid": False, "message": f"{name} API returned status {r.status_code}"}
|
||||
|
||||
|
||||
def check_openrouter(
|
||||
api_key: str, api_base: str = "https://openrouter.ai/api/v1", **_: str
|
||||
) -> dict:
|
||||
"""Validate OpenRouter key against GET /models."""
|
||||
endpoint = f"{api_base.rstrip('/')}/models"
|
||||
with httpx.Client(timeout=TIMEOUT) as client:
|
||||
r = client.get(endpoint, headers={"Authorization": f"Bearer {api_key}"})
|
||||
if r.status_code in (200, 429):
|
||||
return {"valid": True, "message": "OpenRouter API key valid"}
|
||||
if r.status_code == 401:
|
||||
return {"valid": False, "message": "Invalid OpenRouter API key"}
|
||||
if r.status_code == 403:
|
||||
return {"valid": False, "message": "OpenRouter API key lacks permissions"}
|
||||
return {"valid": False, "message": f"OpenRouter API returned status {r.status_code}"}
|
||||
|
||||
|
||||
def check_openrouter_model(
|
||||
api_key: str,
|
||||
model: str,
|
||||
api_base: str = "https://openrouter.ai/api/v1",
|
||||
**_: str,
|
||||
) -> dict:
|
||||
"""Validate that an OpenRouter model ID is available to this key/settings."""
|
||||
requested_model = _sanitize_openrouter_model_id(model)
|
||||
endpoint = f"{api_base.rstrip('/')}/models/user"
|
||||
with httpx.Client(timeout=TIMEOUT) as client:
|
||||
r = client.get(
|
||||
endpoint,
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
if r.status_code == 200:
|
||||
available_model_lookup = _extract_openrouter_model_lookup(r.json())
|
||||
matched_model = available_model_lookup.get(
|
||||
_normalize_openrouter_model_id(requested_model)
|
||||
)
|
||||
if matched_model:
|
||||
return {
|
||||
"valid": True,
|
||||
"message": f"OpenRouter model is available: {matched_model}",
|
||||
"model": matched_model,
|
||||
}
|
||||
|
||||
return {
|
||||
"valid": False,
|
||||
"message": _format_openrouter_model_unavailable_message(
|
||||
requested_model, available_model_lookup
|
||||
),
|
||||
}
|
||||
if r.status_code == 429:
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "OpenRouter model check rate-limited; assuming model is reachable",
|
||||
}
|
||||
if r.status_code == 401:
|
||||
return {"valid": False, "message": "Invalid OpenRouter API key"}
|
||||
if r.status_code == 403:
|
||||
return {"valid": False, "message": "OpenRouter API key lacks permissions"}
|
||||
|
||||
detail = _extract_error_message(r)
|
||||
if r.status_code in (400, 404, 422):
|
||||
base = (
|
||||
"OpenRouter model is not available for this key/settings: "
|
||||
f"{requested_model}"
|
||||
)
|
||||
return {"valid": False, "message": f"{base}. {detail}" if detail else base}
|
||||
|
||||
suffix = f": {detail}" if detail else ""
|
||||
return {
|
||||
"valid": False,
|
||||
"message": f"OpenRouter model check returned status {r.status_code}{suffix}",
|
||||
}
|
||||
|
||||
|
||||
def check_minimax(
|
||||
api_key: str, api_base: str = "https://api.minimax.io/v1", **_: str
|
||||
) -> dict:
|
||||
@@ -129,12 +316,17 @@ PROVIDERS = {
|
||||
"cerebras": lambda key, **kw: check_openai_compatible(
|
||||
key, "https://api.cerebras.ai/v1/models", "Cerebras"
|
||||
),
|
||||
"openrouter": lambda key, **kw: check_openrouter(key, **kw),
|
||||
"minimax": lambda key, **kw: check_minimax(key),
|
||||
# Kimi For Coding uses an Anthropic-compatible endpoint; check via /v1/messages
|
||||
# with empty messages (same as check_anthropic, triggers 400 not 401).
|
||||
"kimi": lambda key, **kw: check_anthropic_compatible(
|
||||
key, "https://api.kimi.com/coding/v1/messages", "Kimi"
|
||||
),
|
||||
# Hive LLM uses an Anthropic-compatible endpoint
|
||||
"hive": lambda key, **kw: check_anthropic_compatible(
|
||||
key, f"{HIVE_LLM_ENDPOINT}/v1/messages", "Hive"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -144,7 +336,7 @@ def main() -> None:
|
||||
json.dumps(
|
||||
{
|
||||
"valid": False,
|
||||
"message": "Usage: check_llm_key.py <provider> <key> [api_base]",
|
||||
"message": "Usage: check_llm_key.py <provider> <key> [api_base] [model]",
|
||||
}
|
||||
)
|
||||
)
|
||||
@@ -153,15 +345,28 @@ def main() -> None:
|
||||
provider_id = sys.argv[1]
|
||||
api_key = sys.argv[2]
|
||||
api_base = sys.argv[3] if len(sys.argv) > 3 else ""
|
||||
model = sys.argv[4] if len(sys.argv) > 4 else ""
|
||||
|
||||
try:
|
||||
if api_base and provider_id == "minimax":
|
||||
if provider_id == "openrouter" and model:
|
||||
result = check_openrouter_model(
|
||||
api_key,
|
||||
model=model,
|
||||
api_base=(api_base or "https://openrouter.ai/api/v1"),
|
||||
)
|
||||
elif api_base and provider_id == "minimax":
|
||||
result = check_minimax(api_key, api_base)
|
||||
elif api_base and provider_id == "openrouter":
|
||||
result = check_openrouter(api_key, api_base)
|
||||
elif api_base and provider_id == "kimi":
|
||||
# Kimi uses an Anthropic-compatible endpoint; check via /v1/messages
|
||||
result = check_anthropic_compatible(
|
||||
api_key, api_base.rstrip("/") + "/v1/messages", "Kimi"
|
||||
)
|
||||
elif api_base and provider_id == "hive":
|
||||
result = check_anthropic_compatible(
|
||||
api_key, api_base.rstrip("/") + "/v1/messages", "Hive"
|
||||
)
|
||||
elif api_base:
|
||||
# Custom API base (ZAI or other OpenAI-compatible)
|
||||
endpoint = api_base.rstrip("/") + "/models"
|
||||
|
||||
@@ -67,7 +67,7 @@ SLACK_CREDENTIALS = {
|
||||
help_url="https://api.slack.com/apps",
|
||||
description="Slack Bot Token (starts with xoxb-)",
|
||||
# Auth method support
|
||||
aden_supported=True,
|
||||
aden_supported=False,
|
||||
aden_provider_name="slack",
|
||||
direct_api_key_supported=True,
|
||||
api_key_instructions="""To get a Slack Bot Token:
|
||||
|
||||
@@ -12,7 +12,7 @@ import zlib
|
||||
|
||||
# Files beyond this size are skipped/rejected in hashline mode because
|
||||
# hashline anchors are not practical on files this large (minified
|
||||
# bundles, logs, data dumps). Shared by view_file, grep_search, and
|
||||
# bundles, logs, data dumps). Shared by read_file, grep_search, and
|
||||
# hashline_edit.
|
||||
HASHLINE_MAX_FILE_BYTES = 10 * 1024 * 1024 # 10 MB
|
||||
|
||||
|
||||
@@ -70,8 +70,6 @@ from .file_system_toolkits.list_dir import register_tools as register_list_dir
|
||||
from .file_system_toolkits.replace_file_content import (
|
||||
register_tools as register_replace_file_content,
|
||||
)
|
||||
from .file_system_toolkits.view_file import register_tools as register_view_file
|
||||
from .file_system_toolkits.write_to_file import register_tools as register_write_to_file
|
||||
from .github_tool import register_tools as register_github
|
||||
from .gitlab_tool import register_tools as register_gitlab
|
||||
from .gmail_tool import register_tools as register_gmail
|
||||
@@ -186,14 +184,12 @@ def _register_verified(
|
||||
register_account_info(mcp, credentials=credentials)
|
||||
|
||||
# --- File system toolkits ---
|
||||
register_view_file(mcp)
|
||||
register_write_to_file(mcp)
|
||||
register_list_dir(mcp)
|
||||
register_replace_file_content(mcp)
|
||||
register_apply_diff(mcp)
|
||||
register_apply_patch(mcp)
|
||||
register_grep_search(mcp)
|
||||
# hashline_edit: anchor-based editing, pairs with view_file/grep_search hashline mode
|
||||
# hashline_edit: anchor-based editing, pairs with read_file/grep_search hashline mode
|
||||
register_hashline_edit(mcp)
|
||||
register_execute_command(mcp)
|
||||
register_data_tools(mcp)
|
||||
|
||||
@@ -75,7 +75,7 @@ def register_tools(mcp: FastMCP) -> None:
|
||||
try:
|
||||
if hashline:
|
||||
# Use splitlines() for anchor consistency with
|
||||
# view_file/hashline_edit (handles Unicode line
|
||||
# read_file/hashline_edit (handles Unicode line
|
||||
# separators like \u2028, \x85).
|
||||
# Skip files > 10MB to avoid excessive memory use.
|
||||
file_size = os.path.getsize(file_path)
|
||||
|
||||
@@ -6,11 +6,11 @@ Edit files using anchor-based line references for precise, hash-validated edits.
|
||||
|
||||
The `hashline_edit` tool enables file editing using short content-hash anchors (`N:hhhh`) instead of requiring exact text reproduction. Each line's anchor includes a 4-character hash of its content. If the file has changed since the model last read it, the hash won't match and the edit is cleanly rejected.
|
||||
|
||||
Use this tool together with `view_file(hashline=True)` and `grep_search(hashline=True)`, which return anchors for each line.
|
||||
Use this tool together with `read_file(hashline=True)` and `grep_search(hashline=True)`, which return anchors for each line.
|
||||
|
||||
## Use Cases
|
||||
|
||||
- Making targeted edits after reading a file with `view_file(hashline=True)`
|
||||
- Making targeted edits after reading a file with `read_file(hashline=True)`
|
||||
- Replacing single lines, line ranges, or inserting new lines by anchor
|
||||
- Batch editing multiple locations in a single atomic call
|
||||
- Falling back to string replacement when anchors are not available
|
||||
@@ -21,7 +21,7 @@ Use this tool together with `view_file(hashline=True)` and `grep_search(hashline
|
||||
import json
|
||||
|
||||
# First, read the file with hashline mode to get anchors
|
||||
content = view_file(path="app.py", hashline=True, workspace_id="ws-1", agent_id="a-1", session_id="s-1")
|
||||
content = read_file(path="app.py", hashline=True)
|
||||
# Returns lines like: 1:a3b1|def main(): 2:f1c2| print("hello") ...
|
||||
|
||||
# Then edit using the anchors
|
||||
@@ -29,25 +29,10 @@ hashline_edit(
|
||||
path="app.py",
|
||||
edits=json.dumps([
|
||||
{"op": "set_line", "anchor": "2:f1c2", "content": ' print("goodbye")'}
|
||||
]),
|
||||
workspace_id="ws-1",
|
||||
agent_id="a-1",
|
||||
session_id="s-1"
|
||||
])
|
||||
)
|
||||
```
|
||||
|
||||
## Arguments
|
||||
|
||||
| Argument | Type | Required | Default | Description |
|
||||
|----------|------|----------|---------|-------------|
|
||||
| `path` | str | Yes | - | The path to the file (relative to session root) |
|
||||
| `edits` | str | Yes | - | JSON string containing a list of edit operations (see Operations below) |
|
||||
| `workspace_id` | str | Yes | - | The ID of the workspace |
|
||||
| `agent_id` | str | Yes | - | The ID of the agent |
|
||||
| `session_id` | str | Yes | - | The ID of the current session |
|
||||
| `auto_cleanup` | bool | No | `True` | Strip hashline prefixes and echoed context from content. Set to `False` to write content exactly as provided. |
|
||||
| `encoding` | str | No | `"utf-8"` | File encoding. Must match the file's actual encoding. |
|
||||
|
||||
## Operations
|
||||
|
||||
The `edits` parameter is a JSON array of operation objects. Each object must have an `"op"` field:
|
||||
@@ -61,62 +46,6 @@ The `edits` parameter is a JSON array of operation objects. Each object must hav
|
||||
| `replace` | `old_content`, `new_content`, `allow_multiple` (optional) | Fallback string replacement; errors if 0 or 2+ matches (unless `allow_multiple: true`) |
|
||||
| `append` | `content` | Append new lines to end of file (works for empty files too) |
|
||||
|
||||
## Returns
|
||||
|
||||
**Success:**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"path": "app.py",
|
||||
"edits_applied": 2,
|
||||
"content": "1:b2c4|def main():\n2:c4a1| print(\"goodbye\")\n..."
|
||||
}
|
||||
```
|
||||
|
||||
**Success (noop, content unchanged after applying edits):**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"path": "app.py",
|
||||
"edits_applied": 0,
|
||||
"note": "Content unchanged after applying edits",
|
||||
"content": "1:b2c4|def main():\n..."
|
||||
}
|
||||
```
|
||||
|
||||
**Success (with auto-cleanup applied):**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"path": "app.py",
|
||||
"edits_applied": 1,
|
||||
"content": "...",
|
||||
"cleanup_applied": ["prefix_strip"]
|
||||
}
|
||||
```
|
||||
|
||||
The `cleanup_applied` field is only present when cleanup actually modified content. Possible values: `prefix_strip`, `boundary_echo_strip`, `insert_echo_strip`.
|
||||
|
||||
**Success (replace with allow_multiple):**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"path": "app.py",
|
||||
"edits_applied": 1,
|
||||
"content": "...",
|
||||
"replacements": {"edit_1": 3}
|
||||
}
|
||||
```
|
||||
|
||||
The `replacements` field is only present when `allow_multiple: true` was used, showing the count per replace op.
|
||||
|
||||
**Error:**
|
||||
```python
|
||||
{
|
||||
"error": "Edit #1 (set_line): Hash mismatch at line 2: expected 'f1c2', got 'a3b1'. Re-read the file to get current anchors."
|
||||
}
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
- Returns an error if the file doesn't exist
|
||||
@@ -127,90 +56,11 @@ The `replacements` field is only present when `allow_multiple: true` was used, s
|
||||
- Returns an error for unknown op types or invalid JSON
|
||||
- All edits are validated before any writes occur (atomic): on any error the file is unchanged
|
||||
|
||||
## Examples
|
||||
|
||||
### Replacing a single line
|
||||
```python
|
||||
edits = json.dumps([
|
||||
{"op": "set_line", "anchor": "5:a3b1", "content": " return result"}
|
||||
])
|
||||
result = hashline_edit(path="app.py", edits=edits, workspace_id="ws-1", agent_id="a-1", session_id="s-1")
|
||||
# Returns: {"success": True, "path": "app.py", "edits_applied": 1, "content": "..."}
|
||||
```
|
||||
|
||||
### Replacing a range of lines
|
||||
```python
|
||||
edits = json.dumps([{
|
||||
"op": "replace_lines",
|
||||
"start_anchor": "10:b1c2",
|
||||
"end_anchor": "15:c2d3",
|
||||
"content": " # simplified\n return x + y"
|
||||
}])
|
||||
result = hashline_edit(path="math.py", edits=edits, workspace_id="ws-1", agent_id="a-1", session_id="s-1")
|
||||
```
|
||||
|
||||
### Inserting new lines after
|
||||
```python
|
||||
edits = json.dumps([
|
||||
{"op": "insert_after", "anchor": "3:d4e5", "content": "import os\nimport sys"}
|
||||
])
|
||||
result = hashline_edit(path="app.py", edits=edits, workspace_id="ws-1", agent_id="a-1", session_id="s-1")
|
||||
```
|
||||
|
||||
### Inserting new lines before
|
||||
```python
|
||||
edits = json.dumps([
|
||||
{"op": "insert_before", "anchor": "1:a1b2", "content": "#!/usr/bin/env python3"}
|
||||
])
|
||||
result = hashline_edit(path="app.py", edits=edits, workspace_id="ws-1", agent_id="a-1", session_id="s-1")
|
||||
```
|
||||
|
||||
### Batch editing
|
||||
```python
|
||||
edits = json.dumps([
|
||||
{"op": "set_line", "anchor": "1:a1b2", "content": "#!/usr/bin/env python3"},
|
||||
{"op": "insert_after", "anchor": "2:b2c3", "content": "import logging"},
|
||||
{"op": "set_line", "anchor": "10:c3d4", "content": " logging.info('done')"},
|
||||
])
|
||||
result = hashline_edit(path="app.py", edits=edits, workspace_id="ws-1", agent_id="a-1", session_id="s-1")
|
||||
```
|
||||
|
||||
### Replace all occurrences
|
||||
```python
|
||||
edits = json.dumps([
|
||||
{"op": "replace", "old_content": "old_name", "new_content": "new_name", "allow_multiple": True}
|
||||
])
|
||||
result = hashline_edit(path="app.py", edits=edits, workspace_id="ws-1", agent_id="a-1", session_id="s-1")
|
||||
# Returns: {..., "replacements": {"edit_1": 5}}
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- Anchors are generated by `view_file(hashline=True)` and `grep_search(hashline=True)`
|
||||
- Anchors are generated by `read_file(hashline=True)` and `grep_search(hashline=True)`
|
||||
- The hash is a CRC32-based 4-char hex digest of the line content (with trailing spaces and tabs stripped; leading whitespace is included so indentation changes invalidate anchors). Collision probability is ~0.0015% per changed line.
|
||||
- All anchor-based ops are validated before any writes occur; if any op fails validation, the file is left unchanged
|
||||
- String `replace` ops are applied after all anchor-based splices, so they match against post-splice content
|
||||
- Original line endings (LF or CRLF) are preserved
|
||||
- The response includes the updated file content in hashline format, so subsequent edits can use the new anchors without re-reading
|
||||
|
||||
## Auto-Cleanup Details
|
||||
|
||||
When `auto_cleanup=True` (the default), the tool strips hashline prefixes and echoed context that LLMs frequently include in edit content. Prefix stripping uses a **2+ non-empty line threshold** to avoid false positives. The prefix regex matches the `N:hhhh|` pattern (4-char hex hash).
|
||||
|
||||
**Why the threshold matters:** Single-line content matching the `N:hhhh|` pattern is ambiguous. It could be literal content (CSV data, config values, log format strings) that happens to match the pattern. With 2+ lines all matching, the probability of a false positive drops dramatically.
|
||||
|
||||
**Single-line example (NOT stripped):**
|
||||
```python
|
||||
# set_line with content "5:a3b1|hello" writes literally "5:a3b1|hello"
|
||||
{"op": "set_line", "anchor": "2:f1c2", "content": "5:a3b1|hello"}
|
||||
```
|
||||
|
||||
**Multi-line example (stripped):**
|
||||
```python
|
||||
# replace_lines where all lines match N:hhhh| pattern gets stripped
|
||||
{"op": "replace_lines", "start_anchor": "2:f1c2", "end_anchor": "3:b2d3",
|
||||
"content": "2:a3b1|BBB\n3:c4d2|CCC"}
|
||||
# Writes "BBB\nCCC" (prefixes removed)
|
||||
```
|
||||
|
||||
**Escape hatch:** Set `auto_cleanup=False` to write content exactly as provided, bypassing all cleanup heuristics.
|
||||
|
||||
@@ -39,7 +39,7 @@ def register_tools(mcp: FastMCP) -> None:
|
||||
Edit a file using anchor-based line references (N:hash) for precise edits.
|
||||
|
||||
When to use
|
||||
After reading a file with view_file(hashline=True), use the anchors to make
|
||||
After reading a file with read_file(hashline=True), use the anchors to make
|
||||
targeted edits without reproducing exact file content.
|
||||
|
||||
Rules & Constraints
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
# View File Tool
|
||||
|
||||
Reads the content of a file within the secure session sandbox.
|
||||
|
||||
## Description
|
||||
|
||||
The `view_file` tool allows you to read and retrieve the complete content of files within a sandboxed session environment. It provides metadata about the file along with its content.
|
||||
|
||||
## Use Cases
|
||||
|
||||
- Reading configuration files
|
||||
- Viewing source code
|
||||
- Inspecting log files
|
||||
- Retrieving data files for processing
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
view_file(
|
||||
path="config/settings.json",
|
||||
workspace_id="workspace-123",
|
||||
agent_id="agent-456",
|
||||
session_id="session-789"
|
||||
)
|
||||
```
|
||||
|
||||
## Arguments
|
||||
|
||||
| Argument | Type | Required | Default | Description |
|
||||
|----------|------|----------|---------|-------------|
|
||||
| `path` | str | Yes | - | The path to the file (relative to session root) |
|
||||
| `workspace_id` | str | Yes | - | The ID of the workspace |
|
||||
| `agent_id` | str | Yes | - | The ID of the agent |
|
||||
| `session_id` | str | Yes | - | The ID of the current session |
|
||||
| `encoding` | str | No | `"utf-8"` | The encoding to use for reading the file |
|
||||
| `max_size` | int | No | `10485760` | Maximum size of file content to return in bytes (10 MB) |
|
||||
| `hashline` | bool | No | `False` | If True, return content with `N:hhhh\|content` anchors for use with `hashline_edit` |
|
||||
| `offset` | int | No | `1` | 1-indexed start line (only used when `hashline=True`) |
|
||||
| `limit` | int | No | `0` | Max lines to return, 0 = all (only used when `hashline=True`) |
|
||||
|
||||
## Returns
|
||||
|
||||
Returns a dictionary with the following structure:
|
||||
|
||||
**Success (default mode):**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"path": "config/settings.json",
|
||||
"content": "{\"debug\": true}",
|
||||
"size_bytes": 16,
|
||||
"lines": 1
|
||||
}
|
||||
```
|
||||
|
||||
**Success (hashline mode):**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"path": "app.py",
|
||||
"content": "1:a3f2|def main():\n2:f1c4| print(\"hello\")",
|
||||
"hashline": True,
|
||||
"offset": 1,
|
||||
"limit": 0,
|
||||
"total_lines": 2,
|
||||
"shown_lines": 2,
|
||||
"size_bytes": 35
|
||||
}
|
||||
```
|
||||
|
||||
**Error:**
|
||||
```python
|
||||
{
|
||||
"error": "File not found at config/settings.json"
|
||||
}
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
- Returns an error dict if the file doesn't exist
|
||||
- Returns an error dict if the file cannot be read (permission issues, encoding errors, etc.)
|
||||
- Handles binary files gracefully by returning appropriate error messages
|
||||
|
||||
## Examples
|
||||
|
||||
### Reading a text file
|
||||
```python
|
||||
result = view_file(
|
||||
path="README.md",
|
||||
workspace_id="ws-1",
|
||||
agent_id="agent-1",
|
||||
session_id="session-1"
|
||||
)
|
||||
# Returns: {"success": True, "path": "README.md", "content": "# My Project\n...", "size_bytes": 1024, "lines": 42}
|
||||
```
|
||||
|
||||
### Handling missing files
|
||||
```python
|
||||
result = view_file(
|
||||
path="nonexistent.txt",
|
||||
workspace_id="ws-1",
|
||||
agent_id="agent-1",
|
||||
session_id="session-1"
|
||||
)
|
||||
# Returns: {"error": "File not found at nonexistent.txt"}
|
||||
```
|
||||
@@ -1,3 +0,0 @@
|
||||
from .view_file import register_tools
|
||||
|
||||
__all__ = ["register_tools"]
|
||||
@@ -1,134 +0,0 @@
|
||||
import os
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from aden_tools.hashline import HASHLINE_MAX_FILE_BYTES, format_hashlines
|
||||
|
||||
from ..security import get_secure_path
|
||||
|
||||
|
||||
def register_tools(mcp: FastMCP) -> None:
|
||||
"""Register file view tools with the MCP server."""
|
||||
if getattr(mcp, "_file_tools_registered", False):
|
||||
return
|
||||
mcp._file_tools_registered = True
|
||||
|
||||
@mcp.tool()
|
||||
def view_file(
|
||||
path: str,
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
encoding: str = "utf-8",
|
||||
max_size: int = HASHLINE_MAX_FILE_BYTES,
|
||||
hashline: bool = False,
|
||||
offset: int = 1,
|
||||
limit: int = 0,
|
||||
) -> dict:
|
||||
"""
|
||||
Purpose
|
||||
Read the content of a file within the session sandbox.
|
||||
|
||||
When to use
|
||||
Inspect file contents before making changes
|
||||
Retrieve stored data or configuration
|
||||
Review logs or artifacts
|
||||
|
||||
Rules & Constraints
|
||||
File must exist at the specified path
|
||||
Returns full content with size and line count
|
||||
Always read before patching or modifying
|
||||
|
||||
Args:
|
||||
path: The path to the file (relative to session root)
|
||||
workspace_id: The ID of workspace
|
||||
agent_id: The ID of agent
|
||||
session_id: The ID of the current session
|
||||
encoding: The encoding to use for reading the file (default: "utf-8")
|
||||
max_size: The maximum size of file content to return in bytes (default: 10MB)
|
||||
hashline: If True, return content with N:hhhh|content anchors
|
||||
for use with hashline_edit (default: False)
|
||||
offset: 1-indexed start line, only used when hashline=True (default: 1)
|
||||
limit: Max lines to return, 0 = all, only used when hashline=True (default: 0)
|
||||
|
||||
Returns:
|
||||
Dict with file content and metadata, or error dict
|
||||
"""
|
||||
try:
|
||||
if max_size < 0:
|
||||
return {"error": f"max_size must be non-negative, got {max_size}"}
|
||||
|
||||
secure_path = get_secure_path(path, workspace_id, agent_id, session_id)
|
||||
if not os.path.exists(secure_path):
|
||||
return {"error": f"File not found at {path}"}
|
||||
|
||||
if not os.path.isfile(secure_path):
|
||||
return {"error": f"Path is not a file: {path}"}
|
||||
|
||||
with open(secure_path, encoding=encoding) as f:
|
||||
content_raw = f.read()
|
||||
|
||||
if not hashline and (offset != 1 or limit != 0):
|
||||
return {
|
||||
"error": "offset and limit are only supported when hashline=True. "
|
||||
"Set hashline=True to use paging."
|
||||
}
|
||||
|
||||
if hashline:
|
||||
if offset < 1:
|
||||
return {"error": f"offset must be >= 1, got {offset}"}
|
||||
if limit < 0:
|
||||
return {"error": f"limit must be >= 0, got {limit}"}
|
||||
|
||||
all_lines = content_raw.splitlines()
|
||||
total_lines = len(all_lines)
|
||||
raw_size = len(content_raw.encode(encoding))
|
||||
|
||||
if offset > max(total_lines, 1):
|
||||
return {"error": f"offset {offset} is beyond end of file ({total_lines} lines)"}
|
||||
|
||||
# Check size after considering offset/limit. When paging
|
||||
# (offset or limit set), only check the formatted output size.
|
||||
# When reading the full file, check the raw size.
|
||||
is_paging = offset > 1 or limit > 0
|
||||
if not is_paging and raw_size > max_size:
|
||||
return {
|
||||
"error": f"File too large for hashline mode ({raw_size} bytes, "
|
||||
f"max {max_size}). Use offset and limit to read a section at a time."
|
||||
}
|
||||
|
||||
formatted = format_hashlines(all_lines, offset=offset, limit=limit)
|
||||
shown_lines = len(formatted.splitlines()) if formatted else 0
|
||||
|
||||
if is_paging and len(formatted.encode(encoding)) > max_size:
|
||||
return {
|
||||
"error": f"Requested section too large ({shown_lines} lines). "
|
||||
f"Reduce limit to read a smaller section."
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"path": path,
|
||||
"content": formatted,
|
||||
"hashline": True,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"total_lines": total_lines,
|
||||
"shown_lines": shown_lines,
|
||||
"size_bytes": raw_size,
|
||||
}
|
||||
|
||||
content = content_raw
|
||||
if len(content.encode(encoding)) > max_size:
|
||||
content = content[:max_size]
|
||||
content += "\n\n[... Content truncated due to size limit ...]"
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"path": path,
|
||||
"content": content,
|
||||
"size_bytes": len(content.encode(encoding)),
|
||||
"lines": len(content.splitlines()),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to read file: {str(e)}"}
|
||||
@@ -1,92 +0,0 @@
|
||||
# Write to File Tool
|
||||
|
||||
Writes content to a file within the secure session sandbox. Supports both overwriting and appending modes.
|
||||
|
||||
## Description
|
||||
|
||||
The `write_to_file` tool allows you to create new files or modify existing files within a sandboxed session environment. It automatically creates parent directories if they don't exist and provides flexible write modes.
|
||||
|
||||
## Use Cases
|
||||
|
||||
- Creating new configuration files
|
||||
- Writing generated code or data
|
||||
- Appending logs or output to existing files
|
||||
- Saving processed results to disk
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
write_to_file(
|
||||
path="config/settings.json",
|
||||
content='{"debug": true}',
|
||||
workspace_id="workspace-123",
|
||||
agent_id="agent-456",
|
||||
session_id="session-789",
|
||||
append=False
|
||||
)
|
||||
```
|
||||
|
||||
## Arguments
|
||||
|
||||
| Argument | Type | Required | Default | Description |
|
||||
|----------|------|----------|---------|-------------|
|
||||
| `path` | str | Yes | - | The path to the file (relative to session root) |
|
||||
| `content` | str | Yes | - | The content to write to the file |
|
||||
| `workspace_id` | str | Yes | - | The ID of the workspace |
|
||||
| `agent_id` | str | Yes | - | The ID of the agent |
|
||||
| `session_id` | str | Yes | - | The ID of the current session |
|
||||
| `append` | bool | No | False | Whether to append to the file instead of overwriting |
|
||||
|
||||
## Returns
|
||||
|
||||
Returns a dictionary with the following structure:
|
||||
|
||||
**Success:**
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"path": "config/settings.json",
|
||||
"mode": "written", # or "appended"
|
||||
"bytes_written": 18
|
||||
}
|
||||
```
|
||||
|
||||
**Error:**
|
||||
```python
|
||||
{
|
||||
"error": "Failed to write to file: [error message]"
|
||||
}
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
- Returns an error dict if the file cannot be written (permission issues, invalid path, etc.)
|
||||
- Automatically creates parent directories if they don't exist
|
||||
- Handles encoding errors gracefully
|
||||
|
||||
## Examples
|
||||
|
||||
### Creating a new file
|
||||
```python
|
||||
result = write_to_file(
|
||||
path="data/output.txt",
|
||||
content="Hello, world!",
|
||||
workspace_id="ws-1",
|
||||
agent_id="agent-1",
|
||||
session_id="session-1"
|
||||
)
|
||||
# Returns: {"success": True, "path": "data/output.txt", "mode": "written", "bytes_written": 13}
|
||||
```
|
||||
|
||||
### Appending to a file
|
||||
```python
|
||||
result = write_to_file(
|
||||
path="logs/activity.log",
|
||||
content="\n[INFO] Task completed",
|
||||
workspace_id="ws-1",
|
||||
agent_id="agent-1",
|
||||
session_id="session-1",
|
||||
append=True
|
||||
)
|
||||
# Returns: {"success": True, "path": "logs/activity.log", "mode": "appended", "bytes_written": 24}
|
||||
```
|
||||
@@ -1,3 +0,0 @@
|
||||
from .write_to_file import register_tools
|
||||
|
||||
__all__ = ["register_tools"]
|
||||
@@ -1,61 +0,0 @@
|
||||
import os
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from ..security import get_secure_path
|
||||
|
||||
|
||||
def register_tools(mcp: FastMCP) -> None:
|
||||
"""Register file write tools with the MCP server."""
|
||||
|
||||
@mcp.tool()
|
||||
def write_to_file(
|
||||
path: str,
|
||||
content: str,
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
append: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Purpose
|
||||
Create a new file or append content to an existing file.
|
||||
|
||||
When to use
|
||||
Append new events to append-only logs
|
||||
Create new artifacts or summaries
|
||||
Initialize new canonical memory files
|
||||
|
||||
Rules & Constraints
|
||||
Must not overwrite canonical memory unless explicitly allowed
|
||||
Should include structured data (JSON, Markdown with headers)
|
||||
Every write must be intentional and minimal
|
||||
|
||||
Anti-pattern
|
||||
Do NOT dump raw conversation transcripts without structure or reason.
|
||||
|
||||
Args:
|
||||
path: The path to the file (relative to session root)
|
||||
content: The content to write to the file
|
||||
workspace_id: The ID of the workspace
|
||||
agent_id: The ID of the agent
|
||||
session_id: The ID of the current session
|
||||
append: Whether to append to the file instead of overwriting (default: False)
|
||||
|
||||
Returns:
|
||||
Dict with success status and path, or error dict
|
||||
"""
|
||||
try:
|
||||
secure_path = get_secure_path(path, workspace_id, agent_id, session_id)
|
||||
os.makedirs(os.path.dirname(secure_path), exist_ok=True)
|
||||
mode = "a" if append else "w"
|
||||
with open(secure_path, mode, encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
return {
|
||||
"success": True,
|
||||
"path": path,
|
||||
"mode": "appended" if append else "written",
|
||||
"bytes_written": len(content.encode("utf-8")),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to write to file: {str(e)}"}
|
||||
@@ -296,6 +296,7 @@ def register_tools(
|
||||
include_grid_data: bool = False,
|
||||
# Tracking parameters (injected by framework, ignored by tool)
|
||||
workspace_id: str | None = None,
|
||||
account: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict:
|
||||
@@ -325,6 +326,7 @@ def register_tools(
|
||||
sheet_titles: list[str] | None = None,
|
||||
# Tracking parameters (injected by framework, ignored by tool)
|
||||
workspace_id: str | None = None,
|
||||
account: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict:
|
||||
@@ -357,6 +359,7 @@ def register_tools(
|
||||
value_render_option: str = "FORMATTED_VALUE",
|
||||
# Tracking parameters (injected by framework, ignored by tool)
|
||||
workspace_id: str | None = None,
|
||||
account: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict:
|
||||
@@ -392,6 +395,7 @@ def register_tools(
|
||||
value_input_option: str = "USER_ENTERED",
|
||||
# Tracking parameters (injected by framework, ignored by tool)
|
||||
workspace_id: str | None = None,
|
||||
account: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict:
|
||||
@@ -426,6 +430,7 @@ def register_tools(
|
||||
value_input_option: str = "USER_ENTERED",
|
||||
# Tracking parameters (injected by framework, ignored by tool)
|
||||
workspace_id: str | None = None,
|
||||
account: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict:
|
||||
@@ -458,6 +463,7 @@ def register_tools(
|
||||
range_name: str,
|
||||
# Tracking parameters (injected by framework, ignored by tool)
|
||||
workspace_id: str | None = None,
|
||||
account: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict:
|
||||
@@ -490,6 +496,7 @@ def register_tools(
|
||||
value_input_option: str = "USER_ENTERED",
|
||||
# Tracking parameters (injected by framework, ignored by tool)
|
||||
workspace_id: str | None = None,
|
||||
account: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict:
|
||||
@@ -521,6 +528,7 @@ def register_tools(
|
||||
ranges: list[str],
|
||||
# Tracking parameters (injected by framework, ignored by tool)
|
||||
workspace_id: str | None = None,
|
||||
account: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict:
|
||||
@@ -554,6 +562,7 @@ def register_tools(
|
||||
column_count: int = 26,
|
||||
# Tracking parameters (injected by framework, ignored by tool)
|
||||
workspace_id: str | None = None,
|
||||
account: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict:
|
||||
@@ -585,6 +594,7 @@ def register_tools(
|
||||
sheet_id: int,
|
||||
# Tracking parameters (injected by framework, ignored by tool)
|
||||
workspace_id: str | None = None,
|
||||
account: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict:
|
||||
|
||||
@@ -409,6 +409,8 @@ class BrowserSession:
|
||||
We're already inside ``self._lock`` so we can't call ``stop()``.
|
||||
This mirrors the teardown logic without re-acquiring the lock.
|
||||
"""
|
||||
_CLOSE_TIMEOUT = 10.0 # seconds
|
||||
|
||||
if self.cdp_port:
|
||||
from .port_manager import release_port
|
||||
|
||||
@@ -417,21 +419,21 @@ class BrowserSession:
|
||||
|
||||
if self.context:
|
||||
try:
|
||||
await self.context.close()
|
||||
await asyncio.wait_for(self.context.close(), timeout=_CLOSE_TIMEOUT)
|
||||
except Exception:
|
||||
pass
|
||||
self.context = None
|
||||
|
||||
if self.browser:
|
||||
try:
|
||||
await self.browser.close()
|
||||
await asyncio.wait_for(self.browser.close(), timeout=_CLOSE_TIMEOUT)
|
||||
except Exception:
|
||||
pass
|
||||
self.browser = None
|
||||
|
||||
if self._playwright:
|
||||
try:
|
||||
await self._playwright.stop()
|
||||
await asyncio.wait_for(self._playwright.stop(), timeout=_CLOSE_TIMEOUT)
|
||||
except Exception:
|
||||
pass
|
||||
self._playwright = None
|
||||
@@ -588,6 +590,10 @@ class BrowserSession:
|
||||
|
||||
async def stop(self) -> dict:
|
||||
"""Stop the browser and clean up resources."""
|
||||
# Timeout for each Playwright teardown call — prevents hanging when
|
||||
# the browser process is crashed or unresponsive.
|
||||
_CLOSE_TIMEOUT = 10.0 # seconds
|
||||
|
||||
async with self._lock:
|
||||
# Release CDP port if allocated
|
||||
if self.cdp_port:
|
||||
@@ -598,23 +604,35 @@ class BrowserSession:
|
||||
|
||||
# Close context (works for both persistent and ephemeral)
|
||||
if self.context:
|
||||
await self.context.close()
|
||||
try:
|
||||
await asyncio.wait_for(self.context.close(), timeout=_CLOSE_TIMEOUT)
|
||||
except Exception as exc:
|
||||
logger.warning("context.close() failed for profile %r: %s", self.profile, exc)
|
||||
self.context = None
|
||||
|
||||
# Agent sessions share a browser — don't close it (other agents depend on it).
|
||||
# Only standard sessions own their browser and playwright instances.
|
||||
if self.session_type != "agent":
|
||||
if self.browser:
|
||||
await self.browser.close()
|
||||
try:
|
||||
await asyncio.wait_for(self.browser.close(), timeout=_CLOSE_TIMEOUT)
|
||||
except Exception as exc:
|
||||
logger.warning("browser.close() failed for profile %r: %s", self.profile, exc)
|
||||
self.browser = None
|
||||
|
||||
if self._playwright:
|
||||
await self._playwright.stop()
|
||||
try:
|
||||
await asyncio.wait_for(self._playwright.stop(), timeout=_CLOSE_TIMEOUT)
|
||||
except Exception as exc:
|
||||
logger.warning("playwright.stop() failed for profile %r: %s", self.profile, exc)
|
||||
self._playwright = None
|
||||
|
||||
# Kill the Chrome subprocess
|
||||
if self._chrome_process:
|
||||
await self._chrome_process.kill()
|
||||
try:
|
||||
await self._chrome_process.kill()
|
||||
except Exception as exc:
|
||||
logger.warning("chrome_process.kill() failed for profile %r: %s", self.profile, exc)
|
||||
self._chrome_process = None
|
||||
else:
|
||||
self.browser = None # Drop reference to shared browser
|
||||
|
||||
@@ -32,290 +32,42 @@ def mock_secure_path(tmp_path):
|
||||
return os.path.join(tmp_path, path)
|
||||
|
||||
with patch(
|
||||
"aden_tools.tools.file_system_toolkits.view_file.view_file.get_secure_path",
|
||||
"aden_tools.tools.file_system_toolkits.list_dir.list_dir.get_secure_path",
|
||||
side_effect=_get_secure_path,
|
||||
):
|
||||
with patch(
|
||||
"aden_tools.tools.file_system_toolkits.write_to_file.write_to_file.get_secure_path",
|
||||
"aden_tools.tools.file_system_toolkits.replace_file_content.replace_file_content.get_secure_path",
|
||||
side_effect=_get_secure_path,
|
||||
):
|
||||
with patch(
|
||||
"aden_tools.tools.file_system_toolkits.list_dir.list_dir.get_secure_path",
|
||||
"aden_tools.tools.file_system_toolkits.apply_diff.apply_diff.get_secure_path",
|
||||
side_effect=_get_secure_path,
|
||||
):
|
||||
with patch(
|
||||
"aden_tools.tools.file_system_toolkits.replace_file_content.replace_file_content.get_secure_path",
|
||||
"aden_tools.tools.file_system_toolkits.apply_patch.apply_patch.get_secure_path",
|
||||
side_effect=_get_secure_path,
|
||||
):
|
||||
with patch(
|
||||
"aden_tools.tools.file_system_toolkits.apply_diff.apply_diff.get_secure_path",
|
||||
"aden_tools.tools.file_system_toolkits.grep_search.grep_search.get_secure_path",
|
||||
side_effect=_get_secure_path,
|
||||
):
|
||||
with patch(
|
||||
"aden_tools.tools.file_system_toolkits.apply_patch.apply_patch.get_secure_path",
|
||||
side_effect=_get_secure_path,
|
||||
"aden_tools.tools.file_system_toolkits.grep_search.grep_search.WORKSPACES_DIR",
|
||||
str(tmp_path),
|
||||
):
|
||||
with patch(
|
||||
"aden_tools.tools.file_system_toolkits.grep_search.grep_search.get_secure_path",
|
||||
"aden_tools.tools.file_system_toolkits.execute_command_tool.execute_command_tool.get_secure_path",
|
||||
side_effect=_get_secure_path,
|
||||
):
|
||||
with patch(
|
||||
"aden_tools.tools.file_system_toolkits.grep_search.grep_search.WORKSPACES_DIR",
|
||||
"aden_tools.tools.file_system_toolkits.execute_command_tool.execute_command_tool.WORKSPACES_DIR",
|
||||
str(tmp_path),
|
||||
):
|
||||
with patch(
|
||||
"aden_tools.tools.file_system_toolkits.execute_command_tool.execute_command_tool.get_secure_path",
|
||||
"aden_tools.tools.file_system_toolkits.hashline_edit.hashline_edit.get_secure_path",
|
||||
side_effect=_get_secure_path,
|
||||
):
|
||||
with patch(
|
||||
"aden_tools.tools.file_system_toolkits.execute_command_tool.execute_command_tool.WORKSPACES_DIR",
|
||||
str(tmp_path),
|
||||
):
|
||||
with patch(
|
||||
"aden_tools.tools.file_system_toolkits.hashline_edit.hashline_edit.get_secure_path",
|
||||
side_effect=_get_secure_path,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestViewFileTool:
|
||||
"""Tests for view_file tool."""
|
||||
|
||||
@pytest.fixture
|
||||
def view_file_fn(self, mcp):
|
||||
from aden_tools.tools.file_system_toolkits.view_file import register_tools
|
||||
|
||||
register_tools(mcp)
|
||||
return mcp._tool_manager._tools["view_file"].fn
|
||||
|
||||
def test_view_existing_file(self, view_file_fn, mock_workspace, mock_secure_path, tmp_path):
|
||||
"""Viewing an existing file returns content and metadata."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("Hello, World!", encoding="utf-8")
|
||||
|
||||
result = view_file_fn(path="test.txt", **mock_workspace)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["content"] == "Hello, World!"
|
||||
assert result["size_bytes"] == len(b"Hello, World!")
|
||||
assert result["lines"] == 1
|
||||
|
||||
def test_view_nonexistent_file(self, view_file_fn, mock_workspace, mock_secure_path):
|
||||
"""Viewing a non-existent file returns an error."""
|
||||
result = view_file_fn(path="nonexistent.txt", **mock_workspace)
|
||||
|
||||
assert "error" in result
|
||||
assert "not found" in result["error"].lower()
|
||||
|
||||
def test_view_multiline_file(self, view_file_fn, mock_workspace, mock_secure_path, tmp_path):
|
||||
"""Viewing a multiline file returns correct line count."""
|
||||
test_file = tmp_path / "multiline.txt"
|
||||
content = "Line 1\nLine 2\nLine 3\nLine 4\n"
|
||||
test_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = view_file_fn(path="multiline.txt", **mock_workspace)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["content"] == content
|
||||
assert result["lines"] == 4
|
||||
|
||||
def test_view_empty_file(self, view_file_fn, mock_workspace, mock_secure_path, tmp_path):
|
||||
"""Viewing an empty file returns empty content."""
|
||||
test_file = tmp_path / "empty.txt"
|
||||
test_file.write_text("", encoding="utf-8")
|
||||
|
||||
result = view_file_fn(path="empty.txt", **mock_workspace)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["content"] == ""
|
||||
assert result["size_bytes"] == 0
|
||||
assert result["lines"] == 0
|
||||
|
||||
def test_view_file_with_unicode(self, view_file_fn, mock_workspace, mock_secure_path, tmp_path):
|
||||
"""Viewing a file with unicode characters works correctly."""
|
||||
test_file = tmp_path / "unicode.txt"
|
||||
content = "Hello 世界! 🌍 émoji"
|
||||
test_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = view_file_fn(path="unicode.txt", **mock_workspace)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["content"] == content
|
||||
assert result["size_bytes"] == len(content.encode("utf-8"))
|
||||
|
||||
def test_view_nested_file(self, view_file_fn, mock_workspace, mock_secure_path, tmp_path):
|
||||
"""Viewing a file in a nested directory works correctly."""
|
||||
nested = tmp_path / "nested" / "dir"
|
||||
nested.mkdir(parents=True)
|
||||
test_file = nested / "file.txt"
|
||||
test_file.write_text("nested content", encoding="utf-8")
|
||||
|
||||
result = view_file_fn(path="nested/dir/file.txt", **mock_workspace)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["content"] == "nested content"
|
||||
|
||||
def test_view_file_with_max_size_truncation(
|
||||
self, view_file_fn, mock_workspace, mock_secure_path, tmp_path
|
||||
):
|
||||
"""Viewing a file with max_size truncates content when exceeding limit."""
|
||||
test_file = tmp_path / "large.txt"
|
||||
content = "x" * 1000
|
||||
test_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = view_file_fn(path="large.txt", max_size=100, **mock_workspace)
|
||||
|
||||
assert result["success"] is True
|
||||
assert len(result["content"]) <= 100 + len(
|
||||
"\n\n[... Content truncated due to size limit ...]"
|
||||
)
|
||||
assert "[... Content truncated due to size limit ...]" in result["content"]
|
||||
|
||||
def test_view_file_with_negative_max_size(
|
||||
self, view_file_fn, mock_workspace, mock_secure_path, tmp_path
|
||||
):
|
||||
"""Viewing a file with negative max_size returns error."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("content", encoding="utf-8")
|
||||
|
||||
result = view_file_fn(path="test.txt", max_size=-1, **mock_workspace)
|
||||
|
||||
assert "error" in result
|
||||
assert "max_size must be non-negative" in result["error"]
|
||||
|
||||
def test_view_file_with_custom_encoding(
|
||||
self, view_file_fn, mock_workspace, mock_secure_path, tmp_path
|
||||
):
|
||||
"""Viewing a file with custom encoding works correctly."""
|
||||
test_file = tmp_path / "encoded.txt"
|
||||
content = "Hello 世界"
|
||||
test_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = view_file_fn(path="encoded.txt", encoding="utf-8", **mock_workspace)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["content"] == content
|
||||
|
||||
def test_view_file_with_invalid_encoding(
|
||||
self, view_file_fn, mock_workspace, mock_secure_path, tmp_path
|
||||
):
|
||||
"""Viewing a file with invalid encoding returns error."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("content", encoding="utf-8")
|
||||
|
||||
result = view_file_fn(path="test.txt", encoding="invalid-encoding", **mock_workspace)
|
||||
|
||||
assert "error" in result
|
||||
assert "Failed to read file" in result["error"]
|
||||
|
||||
def test_offset_without_hashline_returns_error(
|
||||
self, view_file_fn, mock_workspace, mock_secure_path, tmp_path
|
||||
):
|
||||
"""Passing offset without hashline=True returns error."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("aaa\nbbb\nccc\n")
|
||||
|
||||
result = view_file_fn(path="test.txt", offset=5, **mock_workspace)
|
||||
|
||||
assert "error" in result
|
||||
assert "hashline=True" in result["error"]
|
||||
|
||||
def test_limit_without_hashline_returns_error(
|
||||
self, view_file_fn, mock_workspace, mock_secure_path, tmp_path
|
||||
):
|
||||
"""Passing limit without hashline=True returns error."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("aaa\nbbb\nccc\n")
|
||||
|
||||
result = view_file_fn(path="test.txt", limit=10, **mock_workspace)
|
||||
|
||||
assert "error" in result
|
||||
assert "hashline=True" in result["error"]
|
||||
|
||||
def test_offset_and_limit_without_hashline_returns_error(
|
||||
self, view_file_fn, mock_workspace, mock_secure_path, tmp_path
|
||||
):
|
||||
"""Passing both offset and limit without hashline=True returns error."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("aaa\nbbb\nccc\n")
|
||||
|
||||
result = view_file_fn(path="test.txt", offset=2, limit=5, **mock_workspace)
|
||||
|
||||
assert "error" in result
|
||||
assert "hashline=True" in result["error"]
|
||||
|
||||
|
||||
class TestWriteToFileTool:
|
||||
"""Tests for write_to_file tool."""
|
||||
|
||||
@pytest.fixture
|
||||
def write_to_file_fn(self, mcp):
|
||||
from aden_tools.tools.file_system_toolkits.write_to_file import register_tools
|
||||
|
||||
register_tools(mcp)
|
||||
return mcp._tool_manager._tools["write_to_file"].fn
|
||||
|
||||
def test_write_new_file(self, write_to_file_fn, mock_workspace, mock_secure_path, tmp_path):
|
||||
"""Writing to a new file creates it successfully."""
|
||||
result = write_to_file_fn(path="new_file.txt", content="Test content", **mock_workspace)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["mode"] == "written"
|
||||
assert result["bytes_written"] > 0
|
||||
|
||||
# Verify file was created
|
||||
created_file = tmp_path / "new_file.txt"
|
||||
assert created_file.exists()
|
||||
assert created_file.read_text(encoding="utf-8") == "Test content"
|
||||
|
||||
def test_write_append_mode(self, write_to_file_fn, mock_workspace, mock_secure_path, tmp_path):
|
||||
"""Writing with append=True appends to existing file."""
|
||||
test_file = tmp_path / "append_test.txt"
|
||||
test_file.write_text("Line 1\n", encoding="utf-8")
|
||||
|
||||
result = write_to_file_fn(
|
||||
path="append_test.txt", content="Line 2\n", append=True, **mock_workspace
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["mode"] == "appended"
|
||||
assert test_file.read_text(encoding="utf-8") == "Line 1\nLine 2\n"
|
||||
|
||||
def test_write_overwrite_existing(
|
||||
self, write_to_file_fn, mock_workspace, mock_secure_path, tmp_path
|
||||
):
|
||||
"""Writing to existing file overwrites it by default."""
|
||||
test_file = tmp_path / "overwrite.txt"
|
||||
test_file.write_text("Original content", encoding="utf-8")
|
||||
|
||||
result = write_to_file_fn(path="overwrite.txt", content="New content", **mock_workspace)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["mode"] == "written"
|
||||
assert test_file.read_text(encoding="utf-8") == "New content"
|
||||
|
||||
def test_write_creates_parent_directories(
|
||||
self, write_to_file_fn, mock_workspace, mock_secure_path, tmp_path
|
||||
):
|
||||
"""Writing creates parent directories if they don't exist."""
|
||||
result = write_to_file_fn(path="nested/dir/file.txt", content="Test", **mock_workspace)
|
||||
|
||||
assert result["success"] is True
|
||||
created_file = tmp_path / "nested" / "dir" / "file.txt"
|
||||
assert created_file.exists()
|
||||
assert created_file.read_text(encoding="utf-8") == "Test"
|
||||
|
||||
def test_write_empty_content(
|
||||
self, write_to_file_fn, mock_workspace, mock_secure_path, tmp_path
|
||||
):
|
||||
"""Writing empty content creates empty file."""
|
||||
result = write_to_file_fn(path="empty.txt", content="", **mock_workspace)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["bytes_written"] == 0
|
||||
created_file = tmp_path / "empty.txt"
|
||||
assert created_file.exists()
|
||||
assert created_file.read_text(encoding="utf-8") == ""
|
||||
yield
|
||||
|
||||
|
||||
class TestListDirTool:
|
||||
@@ -805,167 +557,6 @@ class TestApplyPatchTool:
|
||||
assert test_file.read_text(encoding="utf-8") == modified
|
||||
|
||||
|
||||
class TestViewFileHashlineMode:
|
||||
"""Tests for view_file hashline mode."""
|
||||
|
||||
@pytest.fixture
|
||||
def view_file_fn(self, mcp):
|
||||
from aden_tools.tools.file_system_toolkits.view_file import register_tools
|
||||
|
||||
register_tools(mcp)
|
||||
return mcp._tool_manager._tools["view_file"].fn
|
||||
|
||||
def test_hashline_format(self, view_file_fn, mock_workspace, mock_secure_path, tmp_path):
|
||||
"""hashline=True returns N:hhhh|content format."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("hello\nworld\n")
|
||||
|
||||
result = view_file_fn(path="test.txt", hashline=True, **mock_workspace)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["hashline"] is True
|
||||
lines = result["content"].split("\n")
|
||||
assert lines[0].startswith("1:")
|
||||
assert "|hello" in lines[0]
|
||||
assert lines[1].startswith("2:")
|
||||
assert "|world" in lines[1]
|
||||
|
||||
def test_hashline_offset(self, view_file_fn, mock_workspace, mock_secure_path, tmp_path):
|
||||
"""hashline with offset skips initial lines."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("aaa\nbbb\nccc\n")
|
||||
|
||||
result = view_file_fn(path="test.txt", hashline=True, offset=2, **mock_workspace)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["offset"] == 2
|
||||
lines = result["content"].split("\n")
|
||||
assert lines[0].startswith("2:")
|
||||
assert "|bbb" in lines[0]
|
||||
|
||||
def test_hashline_limit(self, view_file_fn, mock_workspace, mock_secure_path, tmp_path):
|
||||
"""hashline with limit restricts number of lines."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("aaa\nbbb\nccc\nddd\n")
|
||||
|
||||
result = view_file_fn(path="test.txt", hashline=True, limit=2, **mock_workspace)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["limit"] == 2
|
||||
assert result["shown_lines"] == 2
|
||||
assert result["total_lines"] == 4
|
||||
|
||||
def test_hashline_total_and_shown_lines(
|
||||
self, view_file_fn, mock_workspace, mock_secure_path, tmp_path
|
||||
):
|
||||
"""total_lines and shown_lines are reported correctly."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("a\nb\nc\nd\ne\n")
|
||||
|
||||
result = view_file_fn(path="test.txt", hashline=True, offset=2, limit=2, **mock_workspace)
|
||||
|
||||
assert result["total_lines"] == 5
|
||||
assert result["shown_lines"] == 2
|
||||
|
||||
def test_default_mode_unchanged(self, view_file_fn, mock_workspace, mock_secure_path, tmp_path):
|
||||
"""Default mode (hashline=False) returns the same format as before."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("hello\n")
|
||||
|
||||
result = view_file_fn(path="test.txt", **mock_workspace)
|
||||
|
||||
assert result["success"] is True
|
||||
assert "hashline" not in result
|
||||
assert result["content"] == "hello\n"
|
||||
assert result["lines"] == 1
|
||||
|
||||
def test_hashline_offset_zero_returns_error(
|
||||
self, view_file_fn, mock_workspace, mock_secure_path, tmp_path
|
||||
):
|
||||
"""hashline with offset=0 returns error (must be >= 1)."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("aaa\nbbb\n")
|
||||
|
||||
result = view_file_fn(path="test.txt", hashline=True, offset=0, **mock_workspace)
|
||||
|
||||
assert "error" in result
|
||||
assert "offset" in result["error"].lower()
|
||||
|
||||
def test_hashline_negative_offset_returns_error(
|
||||
self, view_file_fn, mock_workspace, mock_secure_path, tmp_path
|
||||
):
|
||||
"""hashline with negative offset returns error."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("aaa\nbbb\n")
|
||||
|
||||
result = view_file_fn(path="test.txt", hashline=True, offset=-1, **mock_workspace)
|
||||
|
||||
assert "error" in result
|
||||
assert "offset" in result["error"].lower()
|
||||
|
||||
def test_hashline_negative_limit_returns_error(
|
||||
self, view_file_fn, mock_workspace, mock_secure_path, tmp_path
|
||||
):
|
||||
"""hashline with negative limit returns error."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("aaa\nbbb\n")
|
||||
|
||||
result = view_file_fn(path="test.txt", hashline=True, limit=-1, **mock_workspace)
|
||||
|
||||
assert "error" in result
|
||||
assert "limit" in result["error"].lower()
|
||||
|
||||
def test_hashline_truncated_file_returns_error(
|
||||
self, view_file_fn, mock_workspace, mock_secure_path, tmp_path
|
||||
):
|
||||
"""Large file with hashline=True and no offset/limit returns error directing to paginate."""
|
||||
test_file = tmp_path / "large.txt"
|
||||
# Create a file larger than the max_size we'll pass
|
||||
content = "line\n" * 100 # 500 bytes
|
||||
test_file.write_text(content)
|
||||
|
||||
result = view_file_fn(path="large.txt", hashline=True, max_size=50, **mock_workspace)
|
||||
|
||||
assert "error" in result
|
||||
assert "too large" in result["error"].lower()
|
||||
assert "offset" in result["error"].lower()
|
||||
assert "limit" in result["error"].lower()
|
||||
|
||||
def test_hashline_offset_beyond_end_returns_error(
|
||||
self, view_file_fn, mock_workspace, mock_secure_path, tmp_path
|
||||
):
|
||||
"""hashline with offset beyond total lines returns error."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("aaa\nbbb\n")
|
||||
|
||||
result = view_file_fn(path="test.txt", hashline=True, offset=50, **mock_workspace)
|
||||
|
||||
assert "error" in result
|
||||
assert "beyond" in result["error"].lower()
|
||||
assert "2 lines" in result["error"]
|
||||
|
||||
def test_hashline_large_file_with_offset_limit_works(
|
||||
self, view_file_fn, mock_workspace, mock_secure_path, tmp_path
|
||||
):
|
||||
"""Large file using offset/limit bypasses full-file size check."""
|
||||
test_file = tmp_path / "large.txt"
|
||||
lines = [f"line {i}" for i in range(1, 101)]
|
||||
test_file.write_text("\n".join(lines) + "\n")
|
||||
|
||||
# File is large (> max_size=200), but offset/limit lets us page through it
|
||||
result = view_file_fn(
|
||||
path="large.txt", hashline=True, offset=10, limit=5, max_size=200, **mock_workspace
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["shown_lines"] == 5
|
||||
assert result["total_lines"] == 100
|
||||
# First shown line should be line 10
|
||||
first_line = result["content"].split("\n")[0]
|
||||
assert first_line.startswith("10:")
|
||||
assert "|line 10" in first_line
|
||||
|
||||
|
||||
class TestGrepSearchHashlineMode:
|
||||
"""Tests for grep_search hashline mode."""
|
||||
|
||||
@@ -1047,13 +638,6 @@ class TestGrepSearchHashlineMode:
|
||||
class TestHashlineCrossToolConsistency:
|
||||
"""Cross-tool consistency tests for hashline workflows."""
|
||||
|
||||
@pytest.fixture
|
||||
def view_file_fn(self, mcp):
|
||||
from aden_tools.tools.file_system_toolkits.view_file import register_tools
|
||||
|
||||
register_tools(mcp)
|
||||
return mcp._tool_manager._tools["view_file"].fn
|
||||
|
||||
@pytest.fixture
|
||||
def grep_search_fn(self, mcp):
|
||||
from aden_tools.tools.file_system_toolkits.grep_search import register_tools
|
||||
@@ -1070,7 +654,6 @@ class TestHashlineCrossToolConsistency:
|
||||
|
||||
def test_unicode_line_separator_anchor_roundtrip(
|
||||
self,
|
||||
view_file_fn,
|
||||
grep_search_fn,
|
||||
hashline_edit_fn,
|
||||
mock_workspace,
|
||||
@@ -1081,11 +664,6 @@ class TestHashlineCrossToolConsistency:
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("A\u2028B\nC\n", encoding="utf-8")
|
||||
|
||||
# Hashline view sees U+2028 as a line boundary via splitlines()
|
||||
view_res = view_file_fn(path="test.txt", hashline=True, **mock_workspace)
|
||||
assert view_res["success"] is True
|
||||
assert view_res["total_lines"] == 3
|
||||
|
||||
# grep_search line iteration treats U+2028 as in-line content
|
||||
grep_res = grep_search_fn(path="test.txt", pattern="B", hashline=True, **mock_workspace)
|
||||
assert grep_res["success"] is True
|
||||
|
||||
Reference in New Issue
Block a user