Compare commits
159 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 736756b257 | |||
| 90efe7009d | |||
| 4adb369bde | |||
| d4a30eb2f3 | |||
| 94bb4a2984 | |||
| 648bad26ed | |||
| f0c7470f3d | |||
| fe533b72a6 | |||
| e581767cab | |||
| 0663ee5950 | |||
| 4b97baa34b | |||
| a89296d397 | |||
| d568912ba2 | |||
| c4d7980058 | |||
| 8549fe8238 | |||
| 2b8d85bb95 | |||
| 07f7801166 | |||
| 1f12a45151 | |||
| 936e02e8e6 | |||
| d59fe1e109 | |||
| 274318d3e5 | |||
| 0f0884c2e0 | |||
| 764012c598 | |||
| fd4dc1a69a | |||
| 377cd39c2a | |||
| e92caeef24 | |||
| b7e6226478 | |||
| a995818db2 | |||
| 0772b4d300 | |||
| 684e0d8dc6 | |||
| d284c5d790 | |||
| 7a9b9666c4 | |||
| a852cb91bf | |||
| 2f21e9eb4b | |||
| 8390ef8731 | |||
| 8d21479c24 | |||
| 965dec3ba1 | |||
| d4b54446be | |||
| 7992b862c2 | |||
| 44b3e0eaa2 | |||
| f480fc2b94 | |||
| 2844dbf19f | |||
| 22b7e4b0c3 | |||
| 5413833a69 | |||
| 02e1a4584a | |||
| 520840b1dd | |||
| ee96147336 | |||
| 705cef4dc1 | |||
| ab26e64122 | |||
| f365e219cb | |||
| 01621881c2 | |||
| f7639f8572 | |||
| fc643060ce | |||
| 9aebeb181e | |||
| acbbfaaa79 | |||
| bf170bce10 | |||
| 0a090d058b | |||
| 47bfadaad9 | |||
| d968dcd44c | |||
| 6fdaa9ea50 | |||
| 4d251fbdc2 | |||
| 6acceed288 | |||
| 8dd1d6e3aa | |||
| 1da28644a6 | |||
| 6452fe7fef | |||
| acff008bd2 | |||
| 651d6850a1 | |||
| c7fdc92594 | |||
| 43602a8801 | |||
| 3da04265a6 | |||
| 4c98f0d2d0 | |||
| d84c3364d0 | |||
| ae921f6cee | |||
| 6b506a1c08 | |||
| 0c9f4fa97e | |||
| 95e30bc607 | |||
| 0f1f0090b0 | |||
| c0da3bec02 | |||
| 9dadb5264d | |||
| e39e6a75cc | |||
| 23c66d1059 | |||
| b9d529d94e | |||
| 1c9b09fb78 | |||
| 9fb14f23d2 | |||
| 4795dc4f68 | |||
| acf0f804c5 | |||
| 4e2951854b | |||
| 80dfb429d7 | |||
| 9c0ba77e22 | |||
| 46b4651073 | |||
| 86dd5246c6 | |||
| a1227c88ee | |||
| 535d7ab568 | |||
| af10494b31 | |||
| 39c1042827 | |||
| 16e7dc11f4 | |||
| 7a27babefd | |||
| d53ae9d51d | |||
| 910cf7727d | |||
| 1698605f15 | |||
| eda124a123 | |||
| 15e9ce8d2f | |||
| c01dd603d7 | |||
| 9d5157d69f | |||
| d78795bdf5 | |||
| ff2b7f473e | |||
| 73c9a91811 | |||
| 27b765d902 | |||
| fddba419be | |||
| f42d6308e8 | |||
| c167002754 | |||
| ea26ee7d0c | |||
| 5280e908b2 | |||
| 1c5dd8c664 | |||
| 3aca153be5 | |||
| 65c8e1653c | |||
| 58e4fa918c | |||
| 3af13d3f90 | |||
| b799789dbe | |||
| 2cd73dfccc | |||
| 57d77d5479 | |||
| 5814021773 | |||
| 4f4cc9c8ce | |||
| d9c840eee5 | |||
| d2eb86e534 | |||
| 03842353e4 | |||
| 48747e20af | |||
| 58af593af6 | |||
| 450575a927 | |||
| eac2bb19b2 | |||
| 756a815bf0 | |||
| 23a7b080eb | |||
| bf39bcdec9 | |||
| 0276632491 | |||
| ae2993d0d1 | |||
| d14d71f760 | |||
| 738641d35f | |||
| 22f5534f08 | |||
| b79e7eca73 | |||
| 28250dc45e | |||
| fe5df6a87a | |||
| 88253883a3 | |||
| ff7b5c7e27 | |||
| 6ed6e5b286 | |||
| 30bb0ad5d8 | |||
| cb0845f5ba | |||
| ce2525b59c | |||
| 1f77ec3831 | |||
| 6ab5aa8004 | |||
| 4449cd8ee8 | |||
| 8b60c03a0a | |||
| 0e98023e40 | |||
| 22bb07f00e | |||
| 660f883197 | |||
| 988de80b66 | |||
| dc6aa226ee | |||
| 48a54b4ee2 | |||
| a7b6b080ab | |||
| 9202cbd4d4 |
@@ -2,14 +2,22 @@ name: Bounty completed
|
||||
description: Awards points and notifies Discord when a bounty PR is merged
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
pull_request_target:
|
||||
types: [closed]
|
||||
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
pr_number:
|
||||
description: "PR number to process (for missed bounties)"
|
||||
required: true
|
||||
type: number
|
||||
|
||||
jobs:
|
||||
bounty-notify:
|
||||
if: >
|
||||
github.event.pull_request.merged == true &&
|
||||
contains(join(github.event.pull_request.labels.*.name, ','), 'bounty:')
|
||||
github.event_name == 'workflow_dispatch' ||
|
||||
(github.event.pull_request.merged == true &&
|
||||
contains(join(github.event.pull_request.labels.*.name, ','), 'bounty:'))
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5
|
||||
permissions:
|
||||
@@ -32,6 +40,8 @@ jobs:
|
||||
GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
|
||||
GITHUB_REPOSITORY_NAME: ${{ github.event.repository.name }}
|
||||
DISCORD_WEBHOOK_URL: ${{ secrets.DISCORD_BOUNTY_WEBHOOK_URL }}
|
||||
BOT_API_URL: ${{ secrets.BOT_API_URL }}
|
||||
BOT_API_KEY: ${{ secrets.BOT_API_KEY }}
|
||||
LURKR_API_KEY: ${{ secrets.LURKR_API_KEY }}
|
||||
LURKR_GUILD_ID: ${{ secrets.LURKR_GUILD_ID }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
PR_NUMBER: ${{ inputs.pr_number || github.event.pull_request.number }}
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
name: Link Discord account
|
||||
description: Auto-creates a PR to add contributor to contributors.yml when a link-discord issue is opened
|
||||
|
||||
on:
|
||||
issues:
|
||||
types: [opened]
|
||||
|
||||
jobs:
|
||||
link-discord:
|
||||
if: contains(github.event.issue.labels.*.name, 'link-discord')
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 2
|
||||
permissions:
|
||||
contents: write
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Parse issue and update contributors.yml
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
|
||||
const issue = context.payload.issue;
|
||||
const githubUsername = issue.user.login;
|
||||
|
||||
// Parse the issue body for form fields
|
||||
const body = issue.body || '';
|
||||
|
||||
// Extract Discord ID — look for the numeric value after the "Discord User ID" heading
|
||||
const discordMatch = body.match(/### Discord User ID\s*\n\s*(\d{17,20})/);
|
||||
if (!discordMatch) {
|
||||
await github.rest.issues.createComment({
|
||||
...context.repo,
|
||||
issue_number: issue.number,
|
||||
body: `Could not find a valid Discord ID in the issue body. Please make sure you entered a numeric ID (17-20 digits), not a username.\n\nExample: \`123456789012345678\``
|
||||
});
|
||||
await github.rest.issues.update({
|
||||
...context.repo,
|
||||
issue_number: issue.number,
|
||||
state: 'closed',
|
||||
state_reason: 'not_planned'
|
||||
});
|
||||
return;
|
||||
}
|
||||
const discordId = discordMatch[1];
|
||||
|
||||
// Extract display name (optional)
|
||||
const nameMatch = body.match(/### Display Name \(optional\)\s*\n\s*(.+)/);
|
||||
const displayName = nameMatch ? nameMatch[1].trim() : '';
|
||||
|
||||
// Check if user already exists
|
||||
const yml = fs.readFileSync('contributors.yml', 'utf-8');
|
||||
if (yml.includes(`github: ${githubUsername}`)) {
|
||||
await github.rest.issues.createComment({
|
||||
...context.repo,
|
||||
issue_number: issue.number,
|
||||
body: `@${githubUsername} is already in \`contributors.yml\`. If you need to update your Discord ID, please edit the file directly via PR.`
|
||||
});
|
||||
await github.rest.issues.update({
|
||||
...context.repo,
|
||||
issue_number: issue.number,
|
||||
state: 'closed',
|
||||
state_reason: 'completed'
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Append entry to contributors.yml
|
||||
let entry = ` - github: ${githubUsername}\n discord: "${discordId}"`;
|
||||
if (displayName && displayName !== '_No response_') {
|
||||
entry += `\n name: ${displayName}`;
|
||||
}
|
||||
entry += '\n';
|
||||
|
||||
const updated = yml.trimEnd() + '\n' + entry;
|
||||
fs.writeFileSync('contributors.yml', updated);
|
||||
|
||||
// Set outputs for commit step
|
||||
core.exportVariable('GITHUB_USERNAME', githubUsername);
|
||||
core.exportVariable('DISCORD_ID', discordId);
|
||||
core.exportVariable('ISSUE_NUMBER', issue.number.toString());
|
||||
|
||||
- name: Create PR
|
||||
run: |
|
||||
# Check if there are changes
|
||||
if git diff --quiet contributors.yml; then
|
||||
echo "No changes to contributors.yml"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
BRANCH="docs/link-discord-${GITHUB_USERNAME}"
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
git checkout -b "$BRANCH"
|
||||
git add contributors.yml
|
||||
git commit -m "docs: link @${GITHUB_USERNAME} to Discord"
|
||||
git push origin "$BRANCH"
|
||||
|
||||
gh pr create \
|
||||
--title "docs: link @${GITHUB_USERNAME} to Discord" \
|
||||
--body "Adds @${GITHUB_USERNAME} (Discord \`${DISCORD_ID}\`) to \`contributors.yml\` for bounty XP tracking.
|
||||
|
||||
Closes #${ISSUE_NUMBER}" \
|
||||
--base main \
|
||||
--head "$BRANCH" \
|
||||
--label "link-discord"
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Notify on issue
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const username = process.env.GITHUB_USERNAME;
|
||||
const issueNumber = parseInt(process.env.ISSUE_NUMBER);
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
...context.repo,
|
||||
issue_number: issueNumber,
|
||||
body: `A PR has been created to link your account. A maintainer will merge it shortly — once merged, you'll receive XP and Discord pings when your bounty PRs are merged.`
|
||||
});
|
||||
@@ -35,6 +35,8 @@ jobs:
|
||||
GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
|
||||
GITHUB_REPOSITORY_NAME: ${{ github.event.repository.name }}
|
||||
DISCORD_WEBHOOK_URL: ${{ secrets.DISCORD_BOUNTY_WEBHOOK_URL }}
|
||||
BOT_API_URL: ${{ secrets.BOT_API_URL }}
|
||||
BOT_API_KEY: ${{ secrets.BOT_API_KEY }}
|
||||
LURKR_API_KEY: ${{ secrets.LURKR_API_KEY }}
|
||||
LURKR_GUILD_ID: ${{ secrets.LURKR_GUILD_ID }}
|
||||
SINCE_DATE: ${{ github.event.inputs.since_date || '' }}
|
||||
|
||||
@@ -68,7 +68,6 @@ temp/
|
||||
exports/*
|
||||
|
||||
.claude/settings.local.json
|
||||
.claude/skills/ship-it/
|
||||
|
||||
.venv
|
||||
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
.PHONY: lint format check test install-hooks help frontend-install frontend-dev frontend-build
|
||||
.PHONY: lint format check test test-tools test-live test-all install-hooks help frontend-install frontend-dev frontend-build
|
||||
|
||||
# ── Ensure uv is findable in Git Bash on Windows ──────────────────────────────
|
||||
# uv installs to ~/.local/bin on Windows/Linux/macOS. Git Bash may not include
|
||||
# this in PATH by default, so we prepend it here.
|
||||
export PATH := $(HOME)/.local/bin:$(PATH)
|
||||
|
||||
# ── Targets ───────────────────────────────────────────────────────────────────
|
||||
|
||||
help: ## Show this help
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | \
|
||||
@@ -46,4 +53,4 @@ frontend-dev: ## Start frontend dev server
|
||||
cd core/frontend && npm run dev
|
||||
|
||||
frontend-build: ## Build frontend for production
|
||||
cd core/frontend && npm run build
|
||||
cd core/frontend && npm run build
|
||||
@@ -41,7 +41,9 @@ Generate a swarm of worker agents with a coding agent(queen) that control them.
|
||||
|
||||
Visit [adenhq.com](https://adenhq.com) for complete documentation, examples, and guides.
|
||||
|
||||
[](https://www.youtube.com/watch?v=XDOG9fOaLjU)
|
||||
|
||||
https://github.com/user-attachments/assets/bf10edc3-06ba-48b6-98ba-d069b15fb69d
|
||||
|
||||
|
||||
## Who Is Hive For?
|
||||
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
perf: reduce subprocess spawning in quickstart scripts (#4427)
|
||||
|
||||
## Problem
|
||||
Windows process creation (CreateProcess) is 10-100x slower than Linux fork/exec.
|
||||
The quickstart scripts were spawning 4+ separate `uv run python -c "import X"`
|
||||
processes to verify imports, adding ~600ms overhead on Windows.
|
||||
|
||||
## Solution
|
||||
Consolidated all import checks into a single batch script that checks multiple
|
||||
modules in one subprocess call, reducing spawn overhead by ~75%.
|
||||
|
||||
## Changes
|
||||
- **New**: `scripts/check_requirements.py` - Batched import checker
|
||||
- **New**: `scripts/test_check_requirements.py` - Test suite
|
||||
- **New**: `scripts/benchmark_quickstart.ps1` - Performance benchmark tool
|
||||
- **Modified**: `quickstart.ps1` - Updated import verification (2 sections)
|
||||
- **Modified**: `quickstart.sh` - Updated import verification
|
||||
|
||||
## Performance Impact
|
||||
**Benchmark results on Windows:**
|
||||
- Before: ~19.8 seconds for import checks
|
||||
- After: ~4.9 seconds for import checks
|
||||
- **Improvement: 14.9 seconds saved (75.2% faster)**
|
||||
|
||||
## Testing
|
||||
- ✅ All functional tests pass (`scripts/test_check_requirements.py`)
|
||||
- ✅ Quickstart scripts work correctly on Windows
|
||||
- ✅ Error handling verified (invalid imports reported correctly)
|
||||
- ✅ Performance benchmark confirms 75%+ improvement
|
||||
|
||||
Fixes #4427
|
||||
@@ -1,27 +0,0 @@
|
||||
# Identity mapping: GitHub username -> Discord ID
|
||||
#
|
||||
# This file links GitHub accounts to Discord accounts for the
|
||||
# Integration Bounty Program. When a bounty PR is merged, the
|
||||
# GitHub Action uses this file to ping the contributor on Discord.
|
||||
#
|
||||
# HOW TO ADD YOURSELF:
|
||||
# Open a "Link Discord Account" issue:
|
||||
# https://github.com/aden-hive/hive/issues/new?template=link-discord.yml
|
||||
# A GitHub Action will automatically add your entry here.
|
||||
#
|
||||
# To find your Discord ID:
|
||||
# 1. Open Discord Settings > Advanced > Enable Developer Mode
|
||||
# 2. Right-click your name > Copy User ID
|
||||
#
|
||||
# Format:
|
||||
# - github: your-github-username
|
||||
# discord: "your-discord-id" # quotes required (it's a number)
|
||||
# name: Your Display Name # optional
|
||||
|
||||
contributors:
|
||||
# - github: example-user
|
||||
# discord: "123456789012345678"
|
||||
# name: Example User
|
||||
- github: TimothyZhang7
|
||||
discord: "408460790061072384"
|
||||
name: Timothy@Aden
|
||||
@@ -0,0 +1,583 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Antigravity authentication CLI.
|
||||
|
||||
Implements OAuth2 flow for Google's Antigravity Code Assist gateway.
|
||||
Credentials are stored in ~/.hive/antigravity-accounts.json.
|
||||
|
||||
Usage:
|
||||
python -m antigravity_auth auth account add
|
||||
python -m antigravity_auth auth account list
|
||||
python -m antigravity_auth auth account remove <email>
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
import webbrowser
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OAuth endpoints
|
||||
_OAUTH_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
_OAUTH_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
# Scopes for Antigravity/Cloud Code Assist
|
||||
_OAUTH_SCOPES = [
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
]
|
||||
|
||||
# Credentials file path in ~/.hive/
|
||||
_ACCOUNTS_FILE = Path.home() / ".hive" / "antigravity-accounts.json"
|
||||
|
||||
# Default project ID
|
||||
_DEFAULT_PROJECT_ID = "rising-fact-p41fc"
|
||||
_DEFAULT_REDIRECT_PORT = 51121
|
||||
|
||||
# OAuth credentials fetched from the opencode-antigravity-auth project.
|
||||
# This project reverse-engineered and published the public OAuth credentials
|
||||
# for Google's Antigravity/Cloud Code Assist API.
|
||||
# Source: https://github.com/NoeFabris/opencode-antigravity-auth
|
||||
_CREDENTIALS_URL = (
|
||||
"https://raw.githubusercontent.com/NoeFabris/opencode-antigravity-auth/dev/src/constants.ts"
|
||||
)
|
||||
|
||||
# Cached credentials fetched from public source
|
||||
_cached_client_id: str | None = None
|
||||
_cached_client_secret: str | None = None
|
||||
|
||||
|
||||
def _fetch_credentials_from_public_source() -> tuple[str | None, str | None]:
|
||||
"""Fetch OAuth client ID and secret from the public npm package source on GitHub."""
|
||||
global _cached_client_id, _cached_client_secret
|
||||
if _cached_client_id and _cached_client_secret:
|
||||
return _cached_client_id, _cached_client_secret
|
||||
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
_CREDENTIALS_URL, headers={"User-Agent": "Hive-Antigravity-Auth/1.0"}
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
content = resp.read().decode("utf-8")
|
||||
import re
|
||||
|
||||
id_match = re.search(r'ANTIGRAVITY_CLIENT_ID\s*=\s*"([^"]+)"', content)
|
||||
secret_match = re.search(r'ANTIGRAVITY_CLIENT_SECRET\s*=\s*"([^"]+)"', content)
|
||||
if id_match:
|
||||
_cached_client_id = id_match.group(1)
|
||||
if secret_match:
|
||||
_cached_client_secret = secret_match.group(1)
|
||||
return _cached_client_id, _cached_client_secret
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to fetch credentials from public source: {e}")
|
||||
return None, None
|
||||
|
||||
|
||||
def get_client_id() -> str:
|
||||
"""Get OAuth client ID from env, config, or public source."""
|
||||
env_id = os.environ.get("ANTIGRAVITY_CLIENT_ID")
|
||||
if env_id:
|
||||
return env_id
|
||||
|
||||
# Try hive config
|
||||
hive_cfg = Path.home() / ".hive" / "configuration.json"
|
||||
if hive_cfg.exists():
|
||||
try:
|
||||
with open(hive_cfg) as f:
|
||||
cfg = json.load(f)
|
||||
cfg_id = cfg.get("llm", {}).get("antigravity_client_id")
|
||||
if cfg_id:
|
||||
return cfg_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fetch from public source
|
||||
client_id, _ = _fetch_credentials_from_public_source()
|
||||
if client_id:
|
||||
return client_id
|
||||
|
||||
raise RuntimeError("Could not obtain Antigravity OAuth client ID")
|
||||
|
||||
|
||||
def get_client_secret() -> str | None:
|
||||
"""Get OAuth client secret from env, config, or public source."""
|
||||
secret = os.environ.get("ANTIGRAVITY_CLIENT_SECRET")
|
||||
if secret:
|
||||
return secret
|
||||
|
||||
# Try to read from hive config
|
||||
hive_cfg = Path.home() / ".hive" / "configuration.json"
|
||||
if hive_cfg.exists():
|
||||
try:
|
||||
with open(hive_cfg) as f:
|
||||
cfg = json.load(f)
|
||||
secret = cfg.get("llm", {}).get("antigravity_client_secret")
|
||||
if secret:
|
||||
return secret
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fetch from public source (npm package on GitHub)
|
||||
_, secret = _fetch_credentials_from_public_source()
|
||||
return secret
|
||||
|
||||
|
||||
def find_free_port() -> int:
|
||||
"""Find an available local port."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
s.listen(1)
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
class OAuthCallbackHandler(BaseHTTPRequestHandler):
|
||||
"""Handle OAuth callback from browser."""
|
||||
|
||||
auth_code: str | None = None
|
||||
state: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
def log_message(self, format: str, *args: Any) -> None:
|
||||
pass # Suppress default logging
|
||||
|
||||
def do_GET(self) -> None:
|
||||
parsed = urllib.parse.urlparse(self.path)
|
||||
|
||||
if parsed.path == "/oauth-callback":
|
||||
query = urllib.parse.parse_qs(parsed.query)
|
||||
|
||||
if "error" in query:
|
||||
self.error = query["error"][0]
|
||||
self._send_response("Authentication failed. You can close this window.")
|
||||
return
|
||||
|
||||
if "code" in query and "state" in query:
|
||||
OAuthCallbackHandler.auth_code = query["code"][0]
|
||||
OAuthCallbackHandler.state = query["state"][0]
|
||||
self._send_response(
|
||||
"Authentication successful! You can close this window "
|
||||
"and return to the terminal."
|
||||
)
|
||||
return
|
||||
|
||||
self._send_response("Waiting for authentication...")
|
||||
|
||||
def _send_response(self, message: str) -> None:
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/html")
|
||||
self.end_headers()
|
||||
html = f"""<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>Antigravity Auth</title></head>
|
||||
<body style="font-family: system-ui; display: flex; align-items: center;
|
||||
justify-content: center; height: 100vh; margin: 0; background: #1a1a2e;
|
||||
color: #eee;">
|
||||
<div style="text-align: center;">
|
||||
<h2>{message}</h2>
|
||||
</div>
|
||||
</body>
|
||||
</html>"""
|
||||
self.wfile.write(html.encode())
|
||||
|
||||
|
||||
def wait_for_callback(port: int, timeout: int = 300) -> tuple[str | None, str | None, str | None]:
|
||||
"""Start local server and wait for OAuth callback."""
|
||||
server = HTTPServer(("localhost", port), OAuthCallbackHandler)
|
||||
server.timeout = 1
|
||||
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
if OAuthCallbackHandler.auth_code:
|
||||
return (
|
||||
OAuthCallbackHandler.auth_code,
|
||||
OAuthCallbackHandler.state,
|
||||
OAuthCallbackHandler.error,
|
||||
)
|
||||
server.handle_request()
|
||||
|
||||
return None, None, "timeout"
|
||||
|
||||
|
||||
def exchange_code_for_tokens(
|
||||
code: str, redirect_uri: str, client_id: str, client_secret: str | None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Exchange authorization code for tokens."""
|
||||
data = {
|
||||
"code": code,
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
}
|
||||
if client_secret:
|
||||
data["client_secret"] = client_secret
|
||||
|
||||
body = urllib.parse.urlencode(data).encode()
|
||||
|
||||
req = urllib.request.Request(
|
||||
_OAUTH_TOKEN_URL,
|
||||
data=body,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
method="POST",
|
||||
)
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||
return json.loads(resp.read())
|
||||
except Exception as e:
|
||||
logger.error(f"Token exchange failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_user_email(access_token: str) -> str | None:
|
||||
"""Get user email from Google API."""
|
||||
req = urllib.request.Request(
|
||||
"https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
data = json.loads(resp.read())
|
||||
return data.get("email")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def load_accounts() -> dict[str, Any]:
|
||||
"""Load existing accounts from file."""
|
||||
if not _ACCOUNTS_FILE.exists():
|
||||
return {"schemaVersion": 4, "accounts": []}
|
||||
try:
|
||||
with open(_ACCOUNTS_FILE) as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
return {"schemaVersion": 4, "accounts": []}
|
||||
|
||||
|
||||
def save_accounts(data: dict[str, Any]) -> None:
|
||||
"""Save accounts to file."""
|
||||
_ACCOUNTS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(_ACCOUNTS_FILE, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
logger.info(f"Saved credentials to {_ACCOUNTS_FILE}")
|
||||
|
||||
|
||||
def validate_credentials(access_token: str, project_id: str = _DEFAULT_PROJECT_ID) -> bool:
|
||||
"""Test if credentials work by making a simple API call to Antigravity.
|
||||
|
||||
Returns True if credentials are valid, False otherwise.
|
||||
"""
|
||||
endpoint = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
body = {
|
||||
"project": project_id,
|
||||
"model": "gemini-3-flash",
|
||||
"request": {
|
||||
"contents": [{"role": "user", "parts": [{"text": "hi"}]}],
|
||||
"generationConfig": {"maxOutputTokens": 10},
|
||||
},
|
||||
"requestType": "agent",
|
||||
"userAgent": "antigravity",
|
||||
"requestId": "validation-test",
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) Antigravity/1.18.3"
|
||||
),
|
||||
"X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
|
||||
}
|
||||
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
f"{endpoint}/v1internal:generateContent",
|
||||
data=json.dumps(body).encode("utf-8"),
|
||||
headers=headers,
|
||||
method="POST",
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||
json.loads(resp.read())
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def refresh_access_token(
|
||||
refresh_token: str, client_id: str, client_secret: str | None
|
||||
) -> dict | None:
|
||||
"""Refresh the access token using the refresh token."""
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": client_id,
|
||||
}
|
||||
if client_secret:
|
||||
data["client_secret"] = client_secret
|
||||
|
||||
body = urllib.parse.urlencode(data).encode()
|
||||
req = urllib.request.Request(
|
||||
_OAUTH_TOKEN_URL,
|
||||
data=body,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
method="POST",
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||
return json.loads(resp.read())
|
||||
except Exception as e:
|
||||
logger.debug(f"Token refresh failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def cmd_account_add(args: argparse.Namespace) -> int:
|
||||
"""Add a new Antigravity account via OAuth2.
|
||||
|
||||
First checks if valid credentials already exist. If so, validates them
|
||||
and skips OAuth if they work. Otherwise, proceeds with OAuth flow.
|
||||
"""
|
||||
client_id = get_client_id()
|
||||
client_secret = get_client_secret()
|
||||
|
||||
# Check if credentials already exist
|
||||
accounts_data = load_accounts()
|
||||
accounts = accounts_data.get("accounts", [])
|
||||
|
||||
if accounts:
|
||||
account = next((a for a in accounts if a.get("enabled", True) is not False), accounts[0])
|
||||
access_token = account.get("access")
|
||||
refresh_token_str = account.get("refresh", "")
|
||||
refresh_token = refresh_token_str.split("|")[0] if refresh_token_str else None
|
||||
project_id = (
|
||||
refresh_token_str.split("|")[1] if "|" in refresh_token_str else _DEFAULT_PROJECT_ID
|
||||
)
|
||||
email = account.get("email", "unknown")
|
||||
expires_ms = account.get("expires", 0)
|
||||
expires_at = expires_ms / 1000.0 if expires_ms else 0.0
|
||||
|
||||
# Check if token is expired or near expiry
|
||||
if access_token and expires_at and time.time() < expires_at - 60:
|
||||
# Token still valid, test it
|
||||
logger.info(f"Found existing credentials for: {email}")
|
||||
logger.info("Validating existing credentials...")
|
||||
if validate_credentials(access_token, project_id):
|
||||
logger.info("✓ Credentials valid! Skipping OAuth.")
|
||||
return 0
|
||||
else:
|
||||
logger.info("Credentials failed validation, refreshing...")
|
||||
elif refresh_token:
|
||||
logger.info(f"Found expired credentials for: {email}")
|
||||
logger.info("Attempting token refresh...")
|
||||
|
||||
tokens = refresh_access_token(refresh_token, client_id, client_secret)
|
||||
if tokens:
|
||||
new_access = tokens.get("access_token")
|
||||
expires_in = tokens.get("expires_in", 3600)
|
||||
if new_access:
|
||||
# Update the account
|
||||
account["access"] = new_access
|
||||
account["expires"] = int((time.time() + expires_in) * 1000)
|
||||
accounts_data["last_refresh"] = time.strftime(
|
||||
"%Y-%m-%dT%H:%M:%SZ", time.gmtime()
|
||||
)
|
||||
save_accounts(accounts_data)
|
||||
|
||||
# Validate the refreshed token
|
||||
logger.info("Validating refreshed credentials...")
|
||||
if validate_credentials(new_access, project_id):
|
||||
logger.info("✓ Credentials refreshed and validated!")
|
||||
return 0
|
||||
else:
|
||||
logger.info("Refreshed token failed validation, proceeding with OAuth...")
|
||||
else:
|
||||
logger.info("Token refresh failed, proceeding with OAuth...")
|
||||
|
||||
# No valid credentials, proceed with OAuth
|
||||
if not client_secret:
|
||||
logger.warning(
|
||||
"No client secret configured. Token refresh may fail.\n"
|
||||
"Set ANTIGRAVITY_CLIENT_SECRET env var or add "
|
||||
"'antigravity_client_secret' to ~/.hive/configuration.json"
|
||||
)
|
||||
|
||||
# Use fixed port and path matching Google's expected OAuth redirect URI
|
||||
port = _DEFAULT_REDIRECT_PORT
|
||||
redirect_uri = f"http://localhost:{port}/oauth-callback"
|
||||
|
||||
# Generate state for CSRF protection
|
||||
state = secrets.token_urlsafe(16)
|
||||
|
||||
# Build authorization URL
|
||||
params = {
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": " ".join(_OAUTH_SCOPES),
|
||||
"state": state,
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
}
|
||||
auth_url = f"{_OAUTH_AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
logger.info("Opening browser for authentication...")
|
||||
logger.info(f"If the browser doesn't open, visit: {auth_url}\n")
|
||||
|
||||
# Open browser
|
||||
webbrowser.open(auth_url)
|
||||
|
||||
# Wait for callback
|
||||
logger.info(f"Listening for callback on port {port}...")
|
||||
code, received_state, error = wait_for_callback(port)
|
||||
|
||||
if error:
|
||||
logger.error(f"Authentication failed: {error}")
|
||||
return 1
|
||||
|
||||
if not code:
|
||||
logger.error("No authorization code received")
|
||||
return 1
|
||||
|
||||
if received_state != state:
|
||||
logger.error("State mismatch - possible CSRF attack")
|
||||
return 1
|
||||
|
||||
# Exchange code for tokens
|
||||
logger.info("Exchanging authorization code for tokens...")
|
||||
tokens = exchange_code_for_tokens(code, redirect_uri, client_id, client_secret)
|
||||
|
||||
if not tokens:
|
||||
return 1
|
||||
|
||||
access_token = tokens.get("access_token")
|
||||
refresh_token = tokens.get("refresh_token")
|
||||
expires_in = tokens.get("expires_in", 3600)
|
||||
|
||||
if not access_token:
|
||||
logger.error("No access token in response")
|
||||
return 1
|
||||
|
||||
# Get user email
|
||||
email = get_user_email(access_token)
|
||||
if email:
|
||||
logger.info(f"Authenticated as: {email}")
|
||||
|
||||
# Load existing accounts and add/update
|
||||
accounts_data = load_accounts()
|
||||
accounts = accounts_data.get("accounts", [])
|
||||
|
||||
# Build new account entry (V4 schema)
|
||||
expires_ms = int((time.time() + expires_in) * 1000)
|
||||
refresh_entry = f"{refresh_token}|{_DEFAULT_PROJECT_ID}"
|
||||
|
||||
new_account = {
|
||||
"access": access_token,
|
||||
"refresh": refresh_entry,
|
||||
"expires": expires_ms,
|
||||
"email": email,
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
# Update existing account or add new one
|
||||
existing_idx = next((i for i, a in enumerate(accounts) if a.get("email") == email), None)
|
||||
if existing_idx is not None:
|
||||
accounts[existing_idx] = new_account
|
||||
logger.info(f"Updated existing account: {email}")
|
||||
else:
|
||||
accounts.append(new_account)
|
||||
logger.info(f"Added new account: {email}")
|
||||
|
||||
accounts_data["accounts"] = accounts
|
||||
accounts_data["schemaVersion"] = 4
|
||||
accounts_data["last_refresh"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||
|
||||
save_accounts(accounts_data)
|
||||
logger.info("\n✓ Authentication complete!")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_account_list(args: argparse.Namespace) -> int:
|
||||
"""List all stored accounts."""
|
||||
data = load_accounts()
|
||||
accounts = data.get("accounts", [])
|
||||
|
||||
if not accounts:
|
||||
logger.info("No accounts configured.")
|
||||
logger.info("Run 'antigravity auth account add' to add one.")
|
||||
return 0
|
||||
|
||||
logger.info("Configured accounts:\n")
|
||||
for i, account in enumerate(accounts, 1):
|
||||
email = account.get("email", "unknown")
|
||||
enabled = "enabled" if account.get("enabled", True) else "disabled"
|
||||
logger.info(f" {i}. {email} ({enabled})")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_account_remove(args: argparse.Namespace) -> int:
|
||||
"""Remove an account by email."""
|
||||
email = args.email
|
||||
data = load_accounts()
|
||||
accounts = data.get("accounts", [])
|
||||
|
||||
original_len = len(accounts)
|
||||
accounts = [a for a in accounts if a.get("email") != email]
|
||||
|
||||
if len(accounts) == original_len:
|
||||
logger.error(f"No account found with email: {email}")
|
||||
return 1
|
||||
|
||||
data["accounts"] = accounts
|
||||
save_accounts(data)
|
||||
logger.info(f"Removed account: {email}")
|
||||
return 0
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Antigravity authentication CLI",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", help="Commands")
|
||||
|
||||
# auth account add
|
||||
auth_parser = subparsers.add_parser("auth", help="Authentication commands")
|
||||
auth_subparsers = auth_parser.add_subparsers(dest="auth_command")
|
||||
|
||||
account_parser = auth_subparsers.add_parser("account", help="Account management")
|
||||
account_subparsers = account_parser.add_subparsers(dest="account_command")
|
||||
|
||||
add_parser = account_subparsers.add_parser("add", help="Add a new account via OAuth2")
|
||||
add_parser.set_defaults(func=cmd_account_add)
|
||||
|
||||
list_parser = account_subparsers.add_parser("list", help="List configured accounts")
|
||||
list_parser.set_defaults(func=cmd_account_list)
|
||||
|
||||
remove_parser = account_subparsers.add_parser("remove", help="Remove an account")
|
||||
remove_parser.add_argument("email", help="Email of account to remove")
|
||||
remove_parser.set_defaults(func=cmd_account_remove)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if hasattr(args, "func"):
|
||||
return args.func(args)
|
||||
|
||||
parser.print_help()
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -23,25 +23,56 @@ class AgentEntry:
|
||||
last_active: str | None = None
|
||||
|
||||
|
||||
def _get_last_active(agent_name: str) -> str | None:
|
||||
"""Return the most recent updated_at timestamp across all sessions."""
|
||||
sessions_dir = Path.home() / ".hive" / "agents" / agent_name / "sessions"
|
||||
if not sessions_dir.exists():
|
||||
return None
|
||||
def _get_last_active(agent_path: Path) -> str | None:
|
||||
"""Return the most recent updated_at timestamp across all sessions.
|
||||
|
||||
Checks both worker sessions (``~/.hive/agents/{name}/sessions/``) and
|
||||
queen sessions (``~/.hive/queen/session/``) whose ``meta.json`` references
|
||||
the same *agent_path*.
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
agent_name = agent_path.name
|
||||
latest: str | None = None
|
||||
for session_dir in sessions_dir.iterdir():
|
||||
if not session_dir.is_dir() or not session_dir.name.startswith("session_"):
|
||||
continue
|
||||
state_file = session_dir / "state.json"
|
||||
if not state_file.exists():
|
||||
continue
|
||||
try:
|
||||
data = json.loads(state_file.read_text(encoding="utf-8"))
|
||||
ts = data.get("timestamps", {}).get("updated_at")
|
||||
if ts and (latest is None or ts > latest):
|
||||
latest = ts
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 1. Worker sessions
|
||||
sessions_dir = Path.home() / ".hive" / "agents" / agent_name / "sessions"
|
||||
if sessions_dir.exists():
|
||||
for session_dir in sessions_dir.iterdir():
|
||||
if not session_dir.is_dir() or not session_dir.name.startswith("session_"):
|
||||
continue
|
||||
state_file = session_dir / "state.json"
|
||||
if not state_file.exists():
|
||||
continue
|
||||
try:
|
||||
data = json.loads(state_file.read_text(encoding="utf-8"))
|
||||
ts = data.get("timestamps", {}).get("updated_at")
|
||||
if ts and (latest is None or ts > latest):
|
||||
latest = ts
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 2. Queen sessions
|
||||
queen_sessions_dir = Path.home() / ".hive" / "queen" / "session"
|
||||
if queen_sessions_dir.exists():
|
||||
resolved = agent_path.resolve()
|
||||
for d in queen_sessions_dir.iterdir():
|
||||
if not d.is_dir():
|
||||
continue
|
||||
meta_file = d / "meta.json"
|
||||
if not meta_file.exists():
|
||||
continue
|
||||
try:
|
||||
meta = json.loads(meta_file.read_text(encoding="utf-8"))
|
||||
stored = meta.get("agent_path")
|
||||
if not stored or Path(stored).resolve() != resolved:
|
||||
continue
|
||||
ts = datetime.fromtimestamp(d.stat().st_mtime).isoformat()
|
||||
if latest is None or ts > latest:
|
||||
latest = ts
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return latest
|
||||
|
||||
|
||||
@@ -169,7 +200,7 @@ def discover_agents() -> dict[str, list[AgentEntry]]:
|
||||
node_count=node_count,
|
||||
tool_count=tool_count,
|
||||
tags=tags,
|
||||
last_active=_get_last_active(path.name),
|
||||
last_active=_get_last_active(path),
|
||||
)
|
||||
)
|
||||
if entries:
|
||||
|
||||
@@ -702,6 +702,15 @@ stop_worker() to return to STAGING phase.
|
||||
_queen_behavior_always = """
|
||||
# Behavior
|
||||
|
||||
## Images attached by the user
|
||||
|
||||
Users can attach images directly to their chat messages. When you see an \
|
||||
image in the conversation, analyze it using your native vision capability — \
|
||||
do NOT say you cannot see images or that you lack access to files. The image \
|
||||
is embedded in the message; no tool call is needed to view it. Describe what \
|
||||
you see, answer questions about it, and use the visual content to inform your \
|
||||
response just as you would text.
|
||||
|
||||
## CRITICAL RULE — ask_user / ask_user_multiple
|
||||
|
||||
Every response that ends with a question, a prompt, or expects user \
|
||||
@@ -1144,6 +1153,8 @@ Batch your response — do not call run_agent_with_input() once per trigger.
|
||||
config since last run), skip it and inform the user.
|
||||
- Never disable a trigger without telling the user. Use remove_trigger() only \
|
||||
when explicitly asked or when the trigger is clearly obsolete.
|
||||
- When the user asks to remove or disable a trigger, you MUST call remove_trigger(trigger_id). \
|
||||
Never just say "it's removed" without actually calling the tool.
|
||||
"""
|
||||
|
||||
# -- Backward-compatible composed versions (used by queen_node.system_prompt default) --
|
||||
|
||||
@@ -150,7 +150,7 @@ Call all three subagents in a single response to run them in parallel:
|
||||
|
||||
## GCU Anti-Patterns
|
||||
|
||||
- Using `browser_screenshot` to read text (use `browser_snapshot`)
|
||||
- Using `browser_screenshot` to read text (use `browser_snapshot` instead; screenshots are for visual context only)
|
||||
- Re-navigating after scrolling (resets scroll position)
|
||||
- Attempting login on auth walls
|
||||
- Forgetting `target_id` in multi-tab scenarios
|
||||
|
||||
@@ -0,0 +1,286 @@
|
||||
"""Worker per-run digest (run diary).
|
||||
|
||||
Storage layout:
|
||||
~/.hive/agents/{agent_name}/runs/{run_id}/digest.md
|
||||
|
||||
Each completed or failed worker run gets one digest file. The queen reads
|
||||
these via get_worker_status(focus='diary') before digging into live runtime
|
||||
logs — the diary is a cheap, persistent record that survives across sessions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from collections import Counter
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.runtime.event_bus import AgentEvent, EventBus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_DIGEST_SYSTEM = """\
|
||||
You maintain run digests for a worker agent.
|
||||
A run digest is a concise, factual record of a single task execution.
|
||||
|
||||
Write 3-6 sentences covering:
|
||||
- What the worker was asked to do (the task/goal)
|
||||
- What approach it took and what tools it used
|
||||
- What the outcome was (success, partial, or failure — and why if relevant)
|
||||
- Any notable issues, retries, or escalations to the queen
|
||||
|
||||
Write in third person past tense. Be direct and specific.
|
||||
Omit routine tool invocations unless the result matters.
|
||||
Output only the digest prose — no headings, no code fences.
|
||||
"""
|
||||
|
||||
|
||||
def _worker_runs_dir(agent_name: str) -> Path:
|
||||
return Path.home() / ".hive" / "agents" / agent_name / "runs"
|
||||
|
||||
|
||||
def digest_path(agent_name: str, run_id: str) -> Path:
|
||||
return _worker_runs_dir(agent_name) / run_id / "digest.md"
|
||||
|
||||
|
||||
def _collect_run_events(bus: EventBus, run_id: str, limit: int = 2000) -> list[AgentEvent]:
|
||||
"""Collect all events belonging to *run_id* from the bus history.
|
||||
|
||||
Strategy: find the EXECUTION_STARTED event that carries ``run_id``,
|
||||
extract its ``execution_id``, then query the bus by that execution_id.
|
||||
This works because TOOL_CALL_*, EDGE_TRAVERSED, NODE_STALLED etc. carry
|
||||
execution_id but not run_id.
|
||||
|
||||
Falls back to a full-scan run_id filter when EXECUTION_STARTED is not
|
||||
found (e.g. bus was rotated).
|
||||
"""
|
||||
from framework.runtime.event_bus import EventType
|
||||
|
||||
# Pass 1: find execution_id via EXECUTION_STARTED with matching run_id
|
||||
started = bus.get_history(event_type=EventType.EXECUTION_STARTED, limit=limit)
|
||||
exec_id: str | None = None
|
||||
for e in started:
|
||||
if getattr(e, "run_id", None) == run_id and e.execution_id:
|
||||
exec_id = e.execution_id
|
||||
break
|
||||
|
||||
if exec_id:
|
||||
return bus.get_history(execution_id=exec_id, limit=limit)
|
||||
|
||||
# Fallback: scan all events and match by run_id attribute
|
||||
return [e for e in bus.get_history(limit=limit) if getattr(e, "run_id", None) == run_id]
|
||||
|
||||
|
||||
def _build_run_context(
|
||||
events: list[AgentEvent],
|
||||
outcome_event: AgentEvent | None,
|
||||
) -> str:
|
||||
"""Assemble a plain-text run context string for the digest LLM call."""
|
||||
from framework.runtime.event_bus import EventType
|
||||
|
||||
# Reverse so events are in chronological order
|
||||
events_chron = list(reversed(events))
|
||||
|
||||
lines: list[str] = []
|
||||
|
||||
# Task input from EXECUTION_STARTED
|
||||
started = [e for e in events_chron if e.type == EventType.EXECUTION_STARTED]
|
||||
if started:
|
||||
inp = started[0].data.get("input", {})
|
||||
if inp:
|
||||
lines.append(f"Task input: {str(inp)[:400]}")
|
||||
|
||||
# Duration (elapsed so far if no outcome yet)
|
||||
ref_ts = outcome_event.timestamp if outcome_event else datetime.utcnow()
|
||||
if started:
|
||||
elapsed = (ref_ts - started[0].timestamp).total_seconds()
|
||||
m, s = divmod(int(elapsed), 60)
|
||||
lines.append(f"Duration so far: {m}m {s}s" if m else f"Duration so far: {s}s")
|
||||
|
||||
# Outcome
|
||||
if outcome_event is None:
|
||||
lines.append("Status: still running (mid-run snapshot)")
|
||||
elif outcome_event.type == EventType.EXECUTION_COMPLETED:
|
||||
out = outcome_event.data.get("output", {})
|
||||
out_str = f"Outcome: completed. Output: {str(out)[:300]}"
|
||||
lines.append(out_str if out else "Outcome: completed.")
|
||||
else:
|
||||
err = outcome_event.data.get("error", "")
|
||||
lines.append(f"Outcome: failed. Error: {str(err)[:300]}" if err else "Outcome: failed.")
|
||||
|
||||
# Node path (edge traversals)
|
||||
edges = [e for e in events_chron if e.type == EventType.EDGE_TRAVERSED]
|
||||
if edges:
|
||||
parts = [
|
||||
f"{e.data.get('source_node', '?')}->{e.data.get('target_node', '?')}"
|
||||
for e in edges[-20:]
|
||||
]
|
||||
lines.append(f"Node path: {', '.join(parts)}")
|
||||
|
||||
# Tools used
|
||||
tool_events = [e for e in events_chron if e.type == EventType.TOOL_CALL_COMPLETED]
|
||||
if tool_events:
|
||||
names = [e.data.get("tool_name", "?") for e in tool_events]
|
||||
counts = Counter(names)
|
||||
summary = ", ".join(f"{name}×{n}" if n > 1 else name for name, n in counts.most_common())
|
||||
lines.append(f"Tools used: {summary}")
|
||||
# Note any tool errors
|
||||
errors = [e for e in tool_events if e.data.get("is_error")]
|
||||
if errors:
|
||||
err_names = Counter(e.data.get("tool_name", "?") for e in errors)
|
||||
lines.append(f"Tool errors: {dict(err_names)}")
|
||||
|
||||
# Issues
|
||||
issue_map = {
|
||||
EventType.NODE_STALLED: "stall",
|
||||
EventType.NODE_TOOL_DOOM_LOOP: "doom loop",
|
||||
EventType.CONSTRAINT_VIOLATION: "constraint violation",
|
||||
EventType.NODE_RETRY: "retry",
|
||||
}
|
||||
issue_parts: list[str] = []
|
||||
for evt_type, label in issue_map.items():
|
||||
n = sum(1 for e in events_chron if e.type == evt_type)
|
||||
if n:
|
||||
issue_parts.append(f"{n} {label}(s)")
|
||||
if issue_parts:
|
||||
lines.append(f"Issues: {', '.join(issue_parts)}")
|
||||
|
||||
# Escalations to queen
|
||||
escalations = [e for e in events_chron if e.type == EventType.ESCALATION_REQUESTED]
|
||||
if escalations:
|
||||
lines.append(f"Escalations to queen: {len(escalations)}")
|
||||
|
||||
# Final LLM output snippet (last LLM_TEXT_DELTA snapshot)
|
||||
text_events = [e for e in reversed(events_chron) if e.type == EventType.LLM_TEXT_DELTA]
|
||||
if text_events:
|
||||
snapshot = text_events[0].data.get("snapshot", "") or ""
|
||||
if snapshot:
|
||||
lines.append(f"Final LLM output: {snapshot[-400:].strip()}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
async def consolidate_worker_run(
|
||||
agent_name: str,
|
||||
run_id: str,
|
||||
outcome_event: AgentEvent | None,
|
||||
bus: EventBus,
|
||||
llm: Any,
|
||||
) -> None:
|
||||
"""Write (or overwrite) the digest for a worker run.
|
||||
|
||||
Called fire-and-forget either:
|
||||
- After EXECUTION_COMPLETED / EXECUTION_FAILED (outcome_event set, final write)
|
||||
- Periodically during a run on a cooldown timer (outcome_event=None, mid-run snapshot)
|
||||
|
||||
The digest file is always overwritten so each call produces the freshest view.
|
||||
The final completion/failure call supersedes any mid-run snapshot.
|
||||
|
||||
Args:
|
||||
agent_name: Worker agent directory name (determines storage path).
|
||||
run_id: The run ID.
|
||||
outcome_event: EXECUTION_COMPLETED or EXECUTION_FAILED event, or None for
|
||||
a mid-run snapshot.
|
||||
bus: The session EventBus (shared queen + worker).
|
||||
llm: LLMProvider with an acomplete() method.
|
||||
"""
|
||||
try:
|
||||
events = _collect_run_events(bus, run_id)
|
||||
run_context = _build_run_context(events, outcome_event)
|
||||
if not run_context:
|
||||
logger.debug("worker_memory: no events for run %s, skipping digest", run_id)
|
||||
return
|
||||
|
||||
is_final = outcome_event is not None
|
||||
logger.info(
|
||||
"worker_memory: generating %s digest for run %s ...",
|
||||
"final" if is_final else "mid-run",
|
||||
run_id,
|
||||
)
|
||||
|
||||
from framework.agents.queen.config import default_config
|
||||
|
||||
resp = await llm.acomplete(
|
||||
messages=[{"role": "user", "content": run_context}],
|
||||
system=_DIGEST_SYSTEM,
|
||||
max_tokens=min(default_config.max_tokens, 512),
|
||||
)
|
||||
digest_text = (resp.content or "").strip()
|
||||
if not digest_text:
|
||||
logger.warning("worker_memory: LLM returned empty digest for run %s", run_id)
|
||||
return
|
||||
|
||||
path = digest_path(agent_name, run_id)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
from framework.runtime.event_bus import EventType
|
||||
|
||||
ts = (outcome_event.timestamp if outcome_event else datetime.utcnow()).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
if outcome_event is None:
|
||||
status = "running"
|
||||
elif outcome_event.type == EventType.EXECUTION_COMPLETED:
|
||||
status = "completed"
|
||||
else:
|
||||
status = "failed"
|
||||
|
||||
path.write_text(
|
||||
f"# {run_id}\n\n**{ts}** | {status}\n\n{digest_text}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
logger.info(
|
||||
"worker_memory: %s digest written for run %s (%d chars)",
|
||||
status,
|
||||
run_id,
|
||||
len(digest_text),
|
||||
)
|
||||
|
||||
except Exception:
|
||||
tb = traceback.format_exc()
|
||||
logger.exception("worker_memory: digest failed for run %s", run_id)
|
||||
# Persist the error so it's findable without log access
|
||||
error_path = _worker_runs_dir(agent_name) / run_id / "digest_error.txt"
|
||||
try:
|
||||
error_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
error_path.write_text(
|
||||
f"run_id: {run_id}\ntime: {datetime.now().isoformat()}\n\n{tb}",
|
||||
encoding="utf-8",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def read_recent_digests(agent_name: str, max_runs: int = 5) -> list[tuple[str, str]]:
|
||||
"""Return recent run digests as [(run_id, content), ...], newest first.
|
||||
|
||||
Args:
|
||||
agent_name: Worker agent directory name.
|
||||
max_runs: Maximum number of digests to return.
|
||||
|
||||
Returns:
|
||||
List of (run_id, digest_content) tuples, ordered newest first.
|
||||
"""
|
||||
runs_dir = _worker_runs_dir(agent_name)
|
||||
if not runs_dir.exists():
|
||||
return []
|
||||
|
||||
digest_files = sorted(
|
||||
runs_dir.glob("*/digest.md"),
|
||||
key=lambda p: p.stat().st_mtime,
|
||||
reverse=True,
|
||||
)[:max_runs]
|
||||
|
||||
result: list[tuple[str, str]] = []
|
||||
for f in digest_files:
|
||||
try:
|
||||
content = f.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
result.append((f.parent.name, content))
|
||||
except OSError:
|
||||
continue
|
||||
return result
|
||||
@@ -89,6 +89,16 @@ def main():
|
||||
|
||||
register_testing_commands(subparsers)
|
||||
|
||||
# Register skill commands (skill list, skill trust, ...)
|
||||
from framework.skills.cli import register_skill_commands
|
||||
|
||||
register_skill_commands(subparsers)
|
||||
|
||||
# Register debugger commands (debugger)
|
||||
from framework.debugger.cli import register_debugger_commands
|
||||
|
||||
register_debugger_commands(subparsers)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if hasattr(args, "func"):
|
||||
|
||||
+251
-2
@@ -51,16 +51,167 @@ def get_preferred_model() -> str:
|
||||
"""Return the user's preferred LLM model string (e.g. 'anthropic/claude-sonnet-4-20250514')."""
|
||||
llm = get_hive_config().get("llm", {})
|
||||
if llm.get("provider") and llm.get("model"):
|
||||
return f"{llm['provider']}/{llm['model']}"
|
||||
provider = str(llm["provider"])
|
||||
model = str(llm["model"]).strip()
|
||||
# OpenRouter quickstart stores raw model IDs; tolerate pasted "openrouter/<id>" too.
|
||||
if provider.lower() == "openrouter" and model.lower().startswith("openrouter/"):
|
||||
model = model[len("openrouter/") :]
|
||||
if model:
|
||||
return f"{provider}/{model}"
|
||||
return "anthropic/claude-sonnet-4-20250514"
|
||||
|
||||
|
||||
def get_preferred_worker_model() -> str | None:
|
||||
"""Return the user's preferred worker LLM model, or None if not configured.
|
||||
|
||||
Reads from the ``worker_llm`` section of ~/.hive/configuration.json.
|
||||
Returns None when no worker-specific model is set, so callers can
|
||||
fall back to the default (queen) model via ``get_preferred_model()``.
|
||||
"""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if worker_llm.get("provider") and worker_llm.get("model"):
|
||||
provider = str(worker_llm["provider"])
|
||||
model = str(worker_llm["model"]).strip()
|
||||
if provider.lower() == "openrouter" and model.lower().startswith("openrouter/"):
|
||||
model = model[len("openrouter/") :]
|
||||
if model:
|
||||
return f"{provider}/{model}"
|
||||
return None
|
||||
|
||||
|
||||
def get_worker_api_key() -> str | None:
|
||||
"""Return the API key for the worker LLM, falling back to the default key."""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if not worker_llm:
|
||||
return get_api_key()
|
||||
|
||||
# Worker-specific subscription / env var
|
||||
if worker_llm.get("use_claude_code_subscription"):
|
||||
try:
|
||||
from framework.runner.runner import get_claude_code_token
|
||||
|
||||
token = get_claude_code_token()
|
||||
if token:
|
||||
return token
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if worker_llm.get("use_codex_subscription"):
|
||||
try:
|
||||
from framework.runner.runner import get_codex_token
|
||||
|
||||
token = get_codex_token()
|
||||
if token:
|
||||
return token
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if worker_llm.get("use_kimi_code_subscription"):
|
||||
try:
|
||||
from framework.runner.runner import get_kimi_code_token
|
||||
|
||||
token = get_kimi_code_token()
|
||||
if token:
|
||||
return token
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if worker_llm.get("use_antigravity_subscription"):
|
||||
try:
|
||||
from framework.runner.runner import get_antigravity_token
|
||||
|
||||
token = get_antigravity_token()
|
||||
if token:
|
||||
return token
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
api_key_env_var = worker_llm.get("api_key_env_var")
|
||||
if api_key_env_var:
|
||||
return os.environ.get(api_key_env_var)
|
||||
|
||||
# Fall back to default key
|
||||
return get_api_key()
|
||||
|
||||
|
||||
def get_worker_api_base() -> str | None:
|
||||
"""Return the api_base for the worker LLM, falling back to the default."""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if not worker_llm:
|
||||
return get_api_base()
|
||||
|
||||
if worker_llm.get("use_codex_subscription"):
|
||||
return "https://chatgpt.com/backend-api/codex"
|
||||
if worker_llm.get("use_kimi_code_subscription"):
|
||||
return "https://api.kimi.com/coding"
|
||||
if worker_llm.get("use_antigravity_subscription"):
|
||||
# Antigravity uses AntigravityProvider directly — no api_base needed.
|
||||
return None
|
||||
if worker_llm.get("api_base"):
|
||||
return worker_llm["api_base"]
|
||||
if str(worker_llm.get("provider", "")).lower() == "openrouter":
|
||||
return OPENROUTER_API_BASE
|
||||
return None
|
||||
|
||||
|
||||
def get_worker_llm_extra_kwargs() -> dict[str, Any]:
|
||||
"""Return extra kwargs for the worker LLM provider."""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if not worker_llm:
|
||||
return get_llm_extra_kwargs()
|
||||
|
||||
if worker_llm.get("use_claude_code_subscription"):
|
||||
api_key = get_worker_api_key()
|
||||
if api_key:
|
||||
return {
|
||||
"extra_headers": {"authorization": f"Bearer {api_key}"},
|
||||
}
|
||||
if worker_llm.get("use_codex_subscription"):
|
||||
api_key = get_worker_api_key()
|
||||
if api_key:
|
||||
headers: dict[str, str] = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"User-Agent": "CodexBar",
|
||||
}
|
||||
try:
|
||||
from framework.runner.runner import get_codex_account_id
|
||||
|
||||
account_id = get_codex_account_id()
|
||||
if account_id:
|
||||
headers["ChatGPT-Account-Id"] = account_id
|
||||
except ImportError:
|
||||
pass
|
||||
return {
|
||||
"extra_headers": headers,
|
||||
"store": False,
|
||||
"allowed_openai_params": ["store"],
|
||||
}
|
||||
return {}
|
||||
|
||||
|
||||
def get_worker_max_tokens() -> int:
|
||||
"""Return max_tokens for the worker LLM, falling back to default."""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if worker_llm and "max_tokens" in worker_llm:
|
||||
return worker_llm["max_tokens"]
|
||||
return get_max_tokens()
|
||||
|
||||
|
||||
def get_worker_max_context_tokens() -> int:
|
||||
"""Return max_context_tokens for the worker LLM, falling back to default."""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if worker_llm and "max_context_tokens" in worker_llm:
|
||||
return worker_llm["max_context_tokens"]
|
||||
return get_max_context_tokens()
|
||||
|
||||
|
||||
def get_max_tokens() -> int:
|
||||
"""Return the configured max_tokens, falling back to DEFAULT_MAX_TOKENS."""
|
||||
return get_hive_config().get("llm", {}).get("max_tokens", DEFAULT_MAX_TOKENS)
|
||||
|
||||
|
||||
DEFAULT_MAX_CONTEXT_TOKENS = 32_000
|
||||
OPENROUTER_API_BASE = "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
def get_max_context_tokens() -> int:
|
||||
@@ -113,6 +264,17 @@ def get_api_key() -> str | None:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Antigravity subscription: read OAuth token from accounts JSON
|
||||
if llm.get("use_antigravity_subscription"):
|
||||
try:
|
||||
from framework.runner.runner import get_antigravity_token
|
||||
|
||||
token = get_antigravity_token()
|
||||
if token:
|
||||
return token
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Standard env-var path (covers ZAI Code and all API-key providers)
|
||||
api_key_env_var = llm.get("api_key_env_var")
|
||||
if api_key_env_var:
|
||||
@@ -120,6 +282,86 @@ def get_api_key() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
# OAuth credentials for Antigravity are fetched from the opencode-antigravity-auth project.
|
||||
# This project reverse-engineered and published the public OAuth credentials
|
||||
# for Google's Antigravity/Cloud Code Assist API.
|
||||
# Source: https://github.com/NoeFabris/opencode-antigravity-auth
|
||||
_ANTIGRAVITY_CREDENTIALS_URL = (
|
||||
"https://raw.githubusercontent.com/NoeFabris/opencode-antigravity-auth/dev/src/constants.ts"
|
||||
)
|
||||
_antigravity_credentials_cache: tuple[str | None, str | None] = (None, None)
|
||||
|
||||
|
||||
def _fetch_antigravity_credentials() -> tuple[str | None, str | None]:
|
||||
"""Fetch OAuth client ID and secret from the public npm package source on GitHub."""
|
||||
global _antigravity_credentials_cache
|
||||
if _antigravity_credentials_cache[0] and _antigravity_credentials_cache[1]:
|
||||
return _antigravity_credentials_cache
|
||||
|
||||
import re
|
||||
import urllib.request
|
||||
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
_ANTIGRAVITY_CREDENTIALS_URL, headers={"User-Agent": "Hive/1.0"}
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
content = resp.read().decode("utf-8")
|
||||
id_match = re.search(r'ANTIGRAVITY_CLIENT_ID\s*=\s*"([^"]+)"', content)
|
||||
secret_match = re.search(r'ANTIGRAVITY_CLIENT_SECRET\s*=\s*"([^"]+)"', content)
|
||||
client_id = id_match.group(1) if id_match else None
|
||||
client_secret = secret_match.group(1) if secret_match else None
|
||||
if client_id and client_secret:
|
||||
_antigravity_credentials_cache = (client_id, client_secret)
|
||||
return client_id, client_secret
|
||||
except Exception as e:
|
||||
logger.debug("Failed to fetch Antigravity credentials from public source: %s", e)
|
||||
return None, None
|
||||
|
||||
|
||||
def get_antigravity_client_id() -> str:
|
||||
"""Return the Antigravity OAuth application client ID.
|
||||
|
||||
Checked in order:
|
||||
1. ``ANTIGRAVITY_CLIENT_ID`` environment variable
|
||||
2. ``llm.antigravity_client_id`` in ~/.hive/configuration.json
|
||||
3. Fetch from public source (opencode-antigravity-auth project on GitHub)
|
||||
"""
|
||||
env = os.environ.get("ANTIGRAVITY_CLIENT_ID")
|
||||
if env:
|
||||
return env
|
||||
cfg_val = get_hive_config().get("llm", {}).get("antigravity_client_id")
|
||||
if cfg_val:
|
||||
return cfg_val
|
||||
# Fetch from public source
|
||||
client_id, _ = _fetch_antigravity_credentials()
|
||||
if client_id:
|
||||
return client_id
|
||||
raise RuntimeError("Could not obtain Antigravity OAuth client ID")
|
||||
|
||||
|
||||
def get_antigravity_client_secret() -> str | None:
|
||||
"""Return the Antigravity OAuth client secret.
|
||||
|
||||
Checked in order:
|
||||
1. ``ANTIGRAVITY_CLIENT_SECRET`` environment variable
|
||||
2. ``llm.antigravity_client_secret`` in ~/.hive/configuration.json
|
||||
3. Fetch from public source (opencode-antigravity-auth project on GitHub)
|
||||
|
||||
Returns None when not found — token refresh will be skipped and
|
||||
the caller must use whatever access token is already available.
|
||||
"""
|
||||
env = os.environ.get("ANTIGRAVITY_CLIENT_SECRET")
|
||||
if env:
|
||||
return env
|
||||
cfg_val = get_hive_config().get("llm", {}).get("antigravity_client_secret") or None
|
||||
if cfg_val:
|
||||
return cfg_val
|
||||
# Fetch from public source
|
||||
_, secret = _fetch_antigravity_credentials()
|
||||
return secret
|
||||
|
||||
|
||||
def get_gcu_enabled() -> bool:
|
||||
"""Return whether GCU (browser automation) is enabled in user config."""
|
||||
return get_hive_config().get("gcu_enabled", True)
|
||||
@@ -142,7 +384,14 @@ def get_api_base() -> str | None:
|
||||
if llm.get("use_kimi_code_subscription"):
|
||||
# Kimi Code uses an Anthropic-compatible endpoint (no /v1 suffix).
|
||||
return "https://api.kimi.com/coding"
|
||||
return llm.get("api_base")
|
||||
if llm.get("use_antigravity_subscription"):
|
||||
# Antigravity uses AntigravityProvider directly — no api_base needed.
|
||||
return None
|
||||
if llm.get("api_base"):
|
||||
return llm["api_base"]
|
||||
if str(llm.get("provider", "")).lower() == "openrouter":
|
||||
return OPENROUTER_API_BASE
|
||||
return None
|
||||
|
||||
|
||||
def get_llm_extra_kwargs() -> dict[str, Any]:
|
||||
|
||||
@@ -51,6 +51,16 @@ def ensure_credential_key_env() -> None:
|
||||
if found and value:
|
||||
os.environ[var_name] = value
|
||||
logger.debug("Loaded %s from shell config", var_name)
|
||||
# Also load the currently configured LLM env var even if it's not in CREDENTIAL_SPECS.
|
||||
# This keeps quickstart-written keys available to fresh processes on Unix shells.
|
||||
from framework.config import get_hive_config
|
||||
|
||||
llm_env_var = str(get_hive_config().get("llm", {}).get("api_key_env_var", "")).strip()
|
||||
if llm_env_var and not os.environ.get(llm_env_var):
|
||||
found, value = check_env_var_in_shell_config(llm_env_var)
|
||||
if found and value:
|
||||
os.environ[llm_env_var] = value
|
||||
logger.debug("Loaded configured LLM env var %s from shell config", llm_env_var)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
"""CLI command for the LLM debug log viewer."""
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
_SCRIPT = Path(__file__).resolve().parents[3] / "scripts" / "llm_debug_log_visualizer.py"
|
||||
|
||||
|
||||
def register_debugger_commands(subparsers: argparse._SubParsersAction) -> None:
|
||||
"""Register the ``hive debugger`` command."""
|
||||
parser = subparsers.add_parser(
|
||||
"debugger",
|
||||
help="Open the LLM debug log viewer",
|
||||
description=(
|
||||
"Start a local server that lets you browse LLM debug sessions "
|
||||
"recorded in ~/.hive/llm_logs. Sessions are loaded on demand so "
|
||||
"the browser stays responsive."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--session",
|
||||
help="Execution ID to select initially.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Port for the local server (0 = auto-pick a free port).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
help="Directory containing JSONL log files (default: ~/.hive/llm_logs).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--limit-files",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum number of newest log files to scan (default: 200).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
help="Write a static HTML file instead of starting a server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-open",
|
||||
action="store_true",
|
||||
help="Start the server but do not open a browser.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-tests",
|
||||
action="store_true",
|
||||
help="Show test/mock sessions (hidden by default).",
|
||||
)
|
||||
parser.set_defaults(func=cmd_debugger)
|
||||
|
||||
|
||||
def cmd_debugger(args: argparse.Namespace) -> int:
|
||||
"""Launch the LLM debug log visualizer."""
|
||||
cmd: list[str] = [sys.executable, str(_SCRIPT)]
|
||||
if args.session:
|
||||
cmd += ["--session", args.session]
|
||||
if args.port:
|
||||
cmd += ["--port", str(args.port)]
|
||||
if args.logs_dir:
|
||||
cmd += ["--logs-dir", args.logs_dir]
|
||||
if args.limit_files is not None:
|
||||
cmd += ["--limit-files", str(args.limit_files)]
|
||||
if args.output:
|
||||
cmd += ["--output", args.output]
|
||||
if args.no_open:
|
||||
cmd.append("--no-open")
|
||||
if args.include_tests:
|
||||
cmd.append("--include-tests")
|
||||
return subprocess.call(cmd)
|
||||
@@ -33,10 +33,20 @@ class Message:
|
||||
is_transition_marker: bool = False
|
||||
# True when this message is real human input (from /chat), not a system prompt
|
||||
is_client_input: bool = False
|
||||
# Optional image content blocks (e.g. from browser_screenshot)
|
||||
image_content: list[dict[str, Any]] | None = None
|
||||
# True when message contains an activated skill body (AS-10: never prune)
|
||||
is_skill_content: bool = False
|
||||
|
||||
def to_llm_dict(self) -> dict[str, Any]:
|
||||
"""Convert to OpenAI-format message dict."""
|
||||
if self.role == "user":
|
||||
if self.image_content:
|
||||
blocks: list[dict[str, Any]] = []
|
||||
if self.content:
|
||||
blocks.append({"type": "text", "text": self.content})
|
||||
blocks.extend(self.image_content)
|
||||
return {"role": "user", "content": blocks}
|
||||
return {"role": "user", "content": self.content}
|
||||
|
||||
if self.role == "assistant":
|
||||
@@ -47,6 +57,15 @@ class Message:
|
||||
|
||||
# role == "tool"
|
||||
content = f"ERROR: {self.content}" if self.is_error else self.content
|
||||
if self.image_content:
|
||||
# Multimodal tool result: text + image content blocks
|
||||
blocks: list[dict[str, Any]] = [{"type": "text", "text": content}]
|
||||
blocks.extend(self.image_content)
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": self.tool_use_id,
|
||||
"content": blocks,
|
||||
}
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": self.tool_use_id,
|
||||
@@ -72,6 +91,8 @@ class Message:
|
||||
d["is_transition_marker"] = self.is_transition_marker
|
||||
if self.is_client_input:
|
||||
d["is_client_input"] = self.is_client_input
|
||||
if self.image_content is not None:
|
||||
d["image_content"] = self.image_content
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
@@ -87,6 +108,7 @@ class Message:
|
||||
phase_id=data.get("phase_id"),
|
||||
is_transition_marker=data.get("is_transition_marker", False),
|
||||
is_client_input=data.get("is_client_input", False),
|
||||
image_content=data.get("image_content"),
|
||||
)
|
||||
|
||||
|
||||
@@ -373,6 +395,7 @@ class NodeConversation:
|
||||
*,
|
||||
is_transition_marker: bool = False,
|
||||
is_client_input: bool = False,
|
||||
image_content: list[dict[str, Any]] | None = None,
|
||||
) -> Message:
|
||||
msg = Message(
|
||||
seq=self._next_seq,
|
||||
@@ -381,6 +404,7 @@ class NodeConversation:
|
||||
phase_id=self._current_phase,
|
||||
is_transition_marker=is_transition_marker,
|
||||
is_client_input=is_client_input,
|
||||
image_content=image_content,
|
||||
)
|
||||
self._messages.append(msg)
|
||||
self._next_seq += 1
|
||||
@@ -409,6 +433,8 @@ class NodeConversation:
|
||||
tool_use_id: str,
|
||||
content: str,
|
||||
is_error: bool = False,
|
||||
image_content: list[dict[str, Any]] | None = None,
|
||||
is_skill_content: bool = False,
|
||||
) -> Message:
|
||||
msg = Message(
|
||||
seq=self._next_seq,
|
||||
@@ -417,6 +443,8 @@ class NodeConversation:
|
||||
tool_use_id=tool_use_id,
|
||||
is_error=is_error,
|
||||
phase_id=self._current_phase,
|
||||
image_content=image_content,
|
||||
is_skill_content=is_skill_content,
|
||||
)
|
||||
self._messages.append(msg)
|
||||
self._next_seq += 1
|
||||
@@ -610,8 +638,15 @@ class NodeConversation:
|
||||
continue
|
||||
if msg.is_error:
|
||||
continue # never prune errors
|
||||
if msg.is_skill_content:
|
||||
continue # never prune activated skill instructions (AS-10)
|
||||
if msg.content.startswith("[Pruned tool result"):
|
||||
continue # already pruned
|
||||
# Tiny results (set_output acks, confirmations) — pruning
|
||||
# saves negligible space but makes the LLM think the call
|
||||
# failed, causing costly retries.
|
||||
if len(msg.content) < 100:
|
||||
continue
|
||||
|
||||
# Phase-aware: protect current phase messages
|
||||
if self._current_phase and msg.phase_id == self._current_phase:
|
||||
@@ -901,8 +936,7 @@ class NodeConversation:
|
||||
full_path = str((spill_path / conv_filename).resolve())
|
||||
ref_parts.append(
|
||||
f"[Previous conversation saved to '{full_path}'. "
|
||||
f"Use load_data('{conv_filename}'), read_file('{full_path}'), "
|
||||
f"or run_command('cat \"{full_path}\"') to review if needed.]"
|
||||
f"Use load_data('{conv_filename}') to review if needed.]"
|
||||
)
|
||||
elif not collapsed_msgs:
|
||||
ref_parts.append("[Previous freeform messages compacted.]")
|
||||
|
||||
@@ -14,6 +14,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
@@ -24,6 +25,7 @@ from typing import Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from framework.graph.conversation import ConversationStore, NodeConversation
|
||||
from framework.graph.node import NodeContext, NodeProtocol, NodeResult
|
||||
from framework.llm.capabilities import supports_image_tool_results
|
||||
from framework.llm.provider import Tool, ToolResult, ToolUse
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
@@ -37,6 +39,56 @@ from framework.runtime.llm_debug_logger import log_llm_turn
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _describe_images_as_text(image_content: list[dict[str, Any]]) -> str | None:
|
||||
"""Describe images using the best available vision model.
|
||||
|
||||
Called when the queen's model lacks vision support. Tries vision-capable
|
||||
models in priority order based on available API keys and returns a bracketed
|
||||
description to inject into the message text, or None if no vision model is
|
||||
reachable.
|
||||
"""
|
||||
import litellm
|
||||
|
||||
# Build content blocks: prompt + all images
|
||||
blocks: list[dict[str, Any]] = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": (
|
||||
"Describe the following image(s) concisely but with enough detail "
|
||||
"that a text-only AI assistant can understand the content and context."
|
||||
),
|
||||
}
|
||||
]
|
||||
blocks.extend(image_content)
|
||||
|
||||
# Ordered candidates based on available env vars
|
||||
candidates: list[str] = []
|
||||
if os.environ.get("OPENAI_API_KEY"):
|
||||
candidates.append("gpt-4o-mini")
|
||||
if os.environ.get("ANTHROPIC_API_KEY"):
|
||||
candidates.append("claude-3-haiku-20240307")
|
||||
if os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY"):
|
||||
candidates.append("gemini/gemini-1.5-flash")
|
||||
|
||||
for model in candidates:
|
||||
try:
|
||||
response = await litellm.acompletion(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": blocks}],
|
||||
max_tokens=512,
|
||||
)
|
||||
description = (response.choices[0].message.content or "").strip()
|
||||
if description:
|
||||
count = len(image_content)
|
||||
label = "image" if count == 1 else f"{count} images"
|
||||
return f"[{label} attached — description: {description}]"
|
||||
except Exception as exc:
|
||||
logger.debug("Vision fallback model '%s' failed: %s", model, exc)
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TriggerEvent:
|
||||
"""A framework-level trigger signal (timer tick or webhook hit).
|
||||
@@ -90,7 +142,13 @@ class _EscalationReceiver:
|
||||
self._response: str | None = None
|
||||
self._awaiting_input = True # So inject_worker_message() can prefer us
|
||||
|
||||
async def inject_event(self, content: str, *, is_client_input: bool = False) -> None:
|
||||
async def inject_event(
|
||||
self,
|
||||
content: str,
|
||||
*,
|
||||
is_client_input: bool = False,
|
||||
image_content: list[dict] | None = None,
|
||||
) -> None:
|
||||
"""Called by ExecutionStream.inject_input() when the user responds."""
|
||||
self._response = content
|
||||
self._event.set()
|
||||
@@ -243,7 +301,7 @@ class LoopConfig:
|
||||
# Maximum seconds a delegate_to_sub_agent call may run before being
|
||||
# killed. Subagents run a full event-loop so they naturally take
|
||||
# longer than a single tool call — default is 10 minutes. 0 = no timeout.
|
||||
subagent_timeout_seconds: float = 300.0
|
||||
subagent_timeout_seconds: float = 600.0
|
||||
|
||||
# --- Lifecycle hooks ---
|
||||
# Hooks are async callables keyed by event name. Supported events:
|
||||
@@ -293,13 +351,26 @@ class OutputAccumulator:
|
||||
|
||||
Values are stored in memory and optionally written through to a
|
||||
ConversationStore's cursor data for crash recovery.
|
||||
|
||||
When *spillover_dir* and *max_value_chars* are set, large values are
|
||||
automatically saved to files and replaced with lightweight file
|
||||
references. This guarantees auto-spill fires on **every** ``set()``
|
||||
call regardless of code path (resume, checkpoint restore, etc.).
|
||||
"""
|
||||
|
||||
values: dict[str, Any] = field(default_factory=dict)
|
||||
store: ConversationStore | None = None
|
||||
spillover_dir: str | None = None
|
||||
max_value_chars: int = 0 # 0 = disabled
|
||||
|
||||
async def set(self, key: str, value: Any) -> None:
|
||||
"""Set a key-value pair, persisting immediately if store is available."""
|
||||
"""Set a key-value pair, auto-spilling large values to files.
|
||||
|
||||
When the serialised value exceeds *max_value_chars*, the data is
|
||||
saved to ``<spillover_dir>/output_<key>.<ext>`` and *value* is
|
||||
replaced with a compact file-reference string.
|
||||
"""
|
||||
value = self._auto_spill(key, value)
|
||||
self.values[key] = value
|
||||
if self.store:
|
||||
cursor = await self.store.read_cursor() or {}
|
||||
@@ -308,6 +379,39 @@ class OutputAccumulator:
|
||||
cursor["outputs"] = outputs
|
||||
await self.store.write_cursor(cursor)
|
||||
|
||||
def _auto_spill(self, key: str, value: Any) -> Any:
|
||||
"""Save large values to a file and return a reference string."""
|
||||
if self.max_value_chars <= 0 or not self.spillover_dir:
|
||||
return value
|
||||
|
||||
val_str = json.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value
|
||||
if len(val_str) <= self.max_value_chars:
|
||||
return value
|
||||
|
||||
spill_path = Path(self.spillover_dir)
|
||||
spill_path.mkdir(parents=True, exist_ok=True)
|
||||
ext = ".json" if isinstance(value, (dict, list)) else ".txt"
|
||||
filename = f"output_{key}{ext}"
|
||||
write_content = (
|
||||
json.dumps(value, indent=2, ensure_ascii=False)
|
||||
if isinstance(value, (dict, list))
|
||||
else str(value)
|
||||
)
|
||||
(spill_path / filename).write_text(write_content, encoding="utf-8")
|
||||
file_size = (spill_path / filename).stat().st_size
|
||||
logger.info(
|
||||
"set_output value auto-spilled: key=%s, %d chars → %s (%d bytes)",
|
||||
key,
|
||||
len(val_str),
|
||||
filename,
|
||||
file_size,
|
||||
)
|
||||
return (
|
||||
f"[Saved to '{filename}' ({file_size:,} bytes). "
|
||||
f"Use load_data(filename='{filename}') "
|
||||
f"to access full data.]"
|
||||
)
|
||||
|
||||
def get(self, key: str) -> Any | None:
|
||||
"""Get a value by key, or None if not present."""
|
||||
return self.values.get(key)
|
||||
@@ -380,7 +484,9 @@ class EventLoopNode(NodeProtocol):
|
||||
self._config = config or LoopConfig()
|
||||
self._tool_executor = tool_executor
|
||||
self._conversation_store = conversation_store
|
||||
self._injection_queue: asyncio.Queue[tuple[str, bool]] = asyncio.Queue()
|
||||
self._injection_queue: asyncio.Queue[tuple[str, bool, list[dict[str, Any]] | None]] = (
|
||||
asyncio.Queue()
|
||||
)
|
||||
self._trigger_queue: asyncio.Queue[TriggerEvent] = asyncio.Queue()
|
||||
# Client-facing input blocking state
|
||||
self._input_ready = asyncio.Event()
|
||||
@@ -421,6 +527,8 @@ class EventLoopNode(NodeProtocol):
|
||||
stream_id = ctx.stream_id or ctx.node_id
|
||||
node_id = ctx.node_id
|
||||
execution_id = ctx.execution_id or ""
|
||||
# Store skill dirs for AS-9 file-read interception in _execute_tool
|
||||
self._skill_dirs: list[str] = ctx.skill_dirs
|
||||
|
||||
# Verdict counters for runtime logging
|
||||
_accept_count = _retry_count = _escalate_count = _continue_count = 0
|
||||
@@ -467,7 +575,11 @@ class EventLoopNode(NodeProtocol):
|
||||
conversation._output_keys = (
|
||||
ctx.cumulative_output_keys or ctx.node_spec.output_keys or None
|
||||
)
|
||||
accumulator = OutputAccumulator(store=self._conversation_store)
|
||||
accumulator = OutputAccumulator(
|
||||
store=self._conversation_store,
|
||||
spillover_dir=self._config.spillover_dir,
|
||||
max_value_chars=self._config.max_output_value_chars,
|
||||
)
|
||||
start_iteration = 0
|
||||
_restored_recent_responses: list[str] = []
|
||||
_restored_tool_fingerprints: list[list[tuple[str, str]]] = []
|
||||
@@ -481,12 +593,28 @@ class EventLoopNode(NodeProtocol):
|
||||
_restored_recent_responses = restored.recent_responses
|
||||
_restored_tool_fingerprints = restored.recent_tool_fingerprints
|
||||
|
||||
# Refresh the system prompt with full 3-layer composition.
|
||||
# The stored prompt may be stale after code changes or when
|
||||
# runtime-injected context (e.g. worker identity) has changed.
|
||||
# On resume, we rebuild identity + narrative + focus so the LLM
|
||||
# understands the session history, not just the node directive.
|
||||
from framework.graph.prompt_composer import compose_system_prompt
|
||||
# Refresh the system prompt with full composition including
|
||||
# execution preamble and node-type preamble. The stored
|
||||
# prompt may be stale after code changes or when runtime-
|
||||
# injected context (e.g. worker identity) has changed.
|
||||
from framework.graph.prompt_composer import (
|
||||
EXECUTION_SCOPE_PREAMBLE,
|
||||
compose_system_prompt,
|
||||
)
|
||||
|
||||
_exec_preamble = None
|
||||
if (
|
||||
not ctx.is_subagent_mode
|
||||
and ctx.node_spec.node_type in ("event_loop", "gcu")
|
||||
and ctx.node_spec.output_keys
|
||||
):
|
||||
_exec_preamble = EXECUTION_SCOPE_PREAMBLE
|
||||
|
||||
_node_type_preamble = None
|
||||
if ctx.node_spec.node_type == "gcu":
|
||||
from framework.graph.gcu import GCU_BROWSER_SYSTEM_PROMPT
|
||||
|
||||
_node_type_preamble = GCU_BROWSER_SYSTEM_PROMPT
|
||||
|
||||
_current_prompt = compose_system_prompt(
|
||||
identity_prompt=ctx.identity_prompt or None,
|
||||
@@ -495,6 +623,8 @@ class EventLoopNode(NodeProtocol):
|
||||
accounts_prompt=ctx.accounts_prompt or None,
|
||||
skills_catalog_prompt=ctx.skills_catalog_prompt or None,
|
||||
protocols_prompt=ctx.protocols_prompt or None,
|
||||
execution_preamble=_exec_preamble,
|
||||
node_type_preamble=_node_type_preamble,
|
||||
)
|
||||
if conversation.system_prompt != _current_prompt:
|
||||
conversation.update_system_prompt(_current_prompt)
|
||||
@@ -504,9 +634,21 @@ class EventLoopNode(NodeProtocol):
|
||||
_restored_tool_fingerprints = []
|
||||
|
||||
# Fresh conversation: either isolated mode or first node in continuous mode.
|
||||
from framework.graph.prompt_composer import _with_datetime
|
||||
from framework.graph.prompt_composer import (
|
||||
EXECUTION_SCOPE_PREAMBLE,
|
||||
_with_datetime,
|
||||
)
|
||||
|
||||
system_prompt = _with_datetime(ctx.node_spec.system_prompt or "")
|
||||
# Prepend execution-scope preamble for worker nodes so the
|
||||
# LLM knows it is one step in a pipeline and should not try
|
||||
# to perform work that belongs to other nodes.
|
||||
if (
|
||||
not ctx.is_subagent_mode
|
||||
and ctx.node_spec.node_type in ("event_loop", "gcu")
|
||||
and ctx.node_spec.output_keys
|
||||
):
|
||||
system_prompt = f"{EXECUTION_SCOPE_PREAMBLE}\n\n{system_prompt}"
|
||||
# Prepend GCU browser best-practices prompt for gcu nodes
|
||||
if ctx.node_spec.node_type == "gcu":
|
||||
from framework.graph.gcu import GCU_BROWSER_SYSTEM_PROMPT
|
||||
@@ -573,7 +715,11 @@ class EventLoopNode(NodeProtocol):
|
||||
# Stamp phase for first node in continuous mode
|
||||
if _is_continuous:
|
||||
conversation.set_current_phase(ctx.node_id)
|
||||
accumulator = OutputAccumulator(store=self._conversation_store)
|
||||
accumulator = OutputAccumulator(
|
||||
store=self._conversation_store,
|
||||
spillover_dir=self._config.spillover_dir,
|
||||
max_value_chars=self._config.max_output_value_chars,
|
||||
)
|
||||
start_iteration = 0
|
||||
|
||||
# Add initial user message from input data
|
||||
@@ -698,7 +844,7 @@ class EventLoopNode(NodeProtocol):
|
||||
)
|
||||
|
||||
# 6b. Drain injection queue
|
||||
await self._drain_injection_queue(conversation)
|
||||
await self._drain_injection_queue(conversation, ctx)
|
||||
# 6b1. Drain trigger queue (framework-level signals)
|
||||
await self._drain_trigger_queue(conversation)
|
||||
|
||||
@@ -740,6 +886,13 @@ class EventLoopNode(NodeProtocol):
|
||||
execution_id,
|
||||
extra_data=_iter_meta,
|
||||
)
|
||||
# Sync max_context_tokens from live config so mid-session model
|
||||
# switches are reflected in compaction decisions and the UI bar.
|
||||
from framework.config import get_max_context_tokens as _live_mct
|
||||
|
||||
conversation._max_context_tokens = _live_mct()
|
||||
|
||||
await self._publish_context_usage(ctx, conversation, "iteration_start")
|
||||
|
||||
# 6d. Pre-turn compaction check (tiered)
|
||||
_compacted_this_iter = False
|
||||
@@ -756,6 +909,7 @@ class EventLoopNode(NodeProtocol):
|
||||
)
|
||||
_stream_retry_count = 0
|
||||
_turn_cancelled = False
|
||||
_llm_turn_failed_waiting_input = False
|
||||
while True:
|
||||
try:
|
||||
(
|
||||
@@ -875,6 +1029,16 @@ class EventLoopNode(NodeProtocol):
|
||||
# can retry or adjust the request.
|
||||
if ctx.node_spec.client_facing:
|
||||
error_msg = f"LLM call failed: {e}"
|
||||
_guardrail_phrase = (
|
||||
"no endpoints available matching your guardrail restrictions "
|
||||
"and data policy"
|
||||
)
|
||||
if _guardrail_phrase in str(e).lower():
|
||||
error_msg += (
|
||||
" OpenRouter blocked this model under current privacy settings. "
|
||||
"Update https://openrouter.ai/settings/privacy or choose another "
|
||||
"OpenRouter model."
|
||||
)
|
||||
logger.error(
|
||||
"[%s] iter=%d: %s — waiting for user input",
|
||||
node_id,
|
||||
@@ -896,6 +1060,7 @@ class EventLoopNode(NodeProtocol):
|
||||
f"[Error: {error_msg}. Please try again.]"
|
||||
)
|
||||
await self._await_user_input(ctx, prompt="")
|
||||
_llm_turn_failed_waiting_input = True
|
||||
break # exit retry loop, continue outer iteration
|
||||
|
||||
# Non-client-facing: crash as before
|
||||
@@ -946,6 +1111,11 @@ class EventLoopNode(NodeProtocol):
|
||||
await self._await_user_input(ctx, prompt="")
|
||||
continue # back to top of for-iteration loop
|
||||
|
||||
# Client-facing non-transient LLM failures wait for user input and then
|
||||
# continue the outer loop without touching per-turn token vars.
|
||||
if _llm_turn_failed_waiting_input:
|
||||
continue
|
||||
|
||||
# 6e'. Feed actual API token count back for accurate estimation
|
||||
turn_input = turn_tokens.get("input", 0)
|
||||
if turn_input > 0:
|
||||
@@ -1800,7 +1970,13 @@ class EventLoopNode(NodeProtocol):
|
||||
conversation=conversation if _is_continuous else None,
|
||||
)
|
||||
|
||||
async def inject_event(self, content: str, *, is_client_input: bool = False) -> None:
|
||||
async def inject_event(
|
||||
self,
|
||||
content: str,
|
||||
*,
|
||||
is_client_input: bool = False,
|
||||
image_content: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
"""Inject an external event or user input into the running loop.
|
||||
|
||||
The content becomes a user message prepended to the next iteration.
|
||||
@@ -1816,8 +1992,10 @@ class EventLoopNode(NodeProtocol):
|
||||
human user (e.g. /chat endpoint), False for external events
|
||||
(e.g. worker question forwarded by the frontend). Controls
|
||||
message formatting in _drain_injection_queue, not wake behavior.
|
||||
image_content: Optional list of image content blocks (OpenAI
|
||||
image_url format) to include alongside the text.
|
||||
"""
|
||||
await self._injection_queue.put((content, is_client_input))
|
||||
await self._injection_queue.put((content, is_client_input, image_content))
|
||||
self._input_ready.set()
|
||||
|
||||
async def inject_trigger(self, trigger: TriggerEvent) -> None:
|
||||
@@ -1991,6 +2169,24 @@ class EventLoopNode(NodeProtocol):
|
||||
|
||||
messages = conversation.to_llm_messages()
|
||||
|
||||
# Debug: log whether the last user message contains image blocks
|
||||
for _m in reversed(messages):
|
||||
if _m.get("role") == "user":
|
||||
_content = _m.get("content")
|
||||
if isinstance(_content, list):
|
||||
_img_count = sum(
|
||||
1
|
||||
for _b in _content
|
||||
if isinstance(_b, dict) and _b.get("type") == "image_url"
|
||||
)
|
||||
if _img_count:
|
||||
logger.info(
|
||||
"[%s] LLM call: last user message has %d image block(s)",
|
||||
node_id,
|
||||
_img_count,
|
||||
)
|
||||
break
|
||||
|
||||
# Defensive guard: ensure messages don't end with an assistant
|
||||
# message. The Anthropic API rejects "assistant message prefill"
|
||||
# (conversations must end with a user or tool message). This can
|
||||
@@ -2197,58 +2393,24 @@ class EventLoopNode(NodeProtocol):
|
||||
pass
|
||||
key = tc.tool_input.get("key", "")
|
||||
|
||||
# Auto-spill: save large values to data files and
|
||||
# replace with a lightweight file reference so shared
|
||||
# memory / adapt.md / transition markers stay small.
|
||||
spill_dir = self._config.spillover_dir
|
||||
max_val = self._config.max_output_value_chars
|
||||
if max_val > 0 and spill_dir:
|
||||
val_str = (
|
||||
json.dumps(value, ensure_ascii=False)
|
||||
if not isinstance(value, str)
|
||||
else value
|
||||
)
|
||||
if len(val_str) > max_val:
|
||||
spill_path = Path(spill_dir)
|
||||
spill_path.mkdir(parents=True, exist_ok=True)
|
||||
ext = ".json" if isinstance(value, (dict, list)) else ".txt"
|
||||
filename = f"output_{key}{ext}"
|
||||
write_content = (
|
||||
json.dumps(value, indent=2, ensure_ascii=False)
|
||||
if isinstance(value, (dict, list))
|
||||
else str(value)
|
||||
)
|
||||
(spill_path / filename).write_text(write_content, encoding="utf-8")
|
||||
file_size = (spill_path / filename).stat().st_size
|
||||
logger.info(
|
||||
"set_output value auto-spilled: key=%s, "
|
||||
"%d chars → %s (%d bytes)",
|
||||
key,
|
||||
len(val_str),
|
||||
filename,
|
||||
file_size,
|
||||
)
|
||||
# Replace value with reference
|
||||
value = (
|
||||
f"[Saved to '{filename}' ({file_size:,} bytes). "
|
||||
f"Use load_data(filename='{filename}') "
|
||||
f"to access full data.]"
|
||||
)
|
||||
# Update tool result to inform the LLM
|
||||
result = ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=(
|
||||
f"Output '{key}' was large "
|
||||
f"({len(val_str):,} chars) — data saved "
|
||||
f"to '{filename}' ({file_size:,} bytes). "
|
||||
f"The next phase will see the file "
|
||||
f"reference and can load full data."
|
||||
),
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
# Auto-spill happens inside accumulator.set()
|
||||
# — it fires on every code path (fresh, resume,
|
||||
# restore) and prevents overwrite regression.
|
||||
await accumulator.set(key, value)
|
||||
self._record_learning(key, value)
|
||||
stored = accumulator.get(key)
|
||||
# If the accumulator spilled, update the tool
|
||||
# result so the LLM knows data was saved to a file.
|
||||
if isinstance(stored, str) and stored.startswith("[Saved to '"):
|
||||
result = ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=(
|
||||
f"Output '{key}' auto-saved to file "
|
||||
f"(value was too large for inline). "
|
||||
f"{stored}"
|
||||
),
|
||||
is_error=False,
|
||||
)
|
||||
self._record_learning(key, stored)
|
||||
outputs_set_this_turn.append(key)
|
||||
await self._publish_output_key_set(stream_id, node_id, key, execution_id)
|
||||
logged_tool_calls.append(
|
||||
@@ -2266,7 +2428,6 @@ class EventLoopNode(NodeProtocol):
|
||||
|
||||
elif tc.tool_name == "ask_user":
|
||||
# --- Framework-level ask_user handling ---
|
||||
user_input_requested = True
|
||||
ask_user_prompt = tc.tool_input.get("question", "")
|
||||
raw_options = tc.tool_input.get("options", None)
|
||||
# Defensive: ensure options is a list of strings.
|
||||
@@ -2303,6 +2464,8 @@ class EventLoopNode(NodeProtocol):
|
||||
user_input_requested = False
|
||||
continue
|
||||
|
||||
user_input_requested = True
|
||||
|
||||
# Free-form ask_user (no options): stream the question
|
||||
# text as a chat message so the user can see it. When
|
||||
# options are present the QuestionWidget shows the
|
||||
@@ -2328,7 +2491,6 @@ class EventLoopNode(NodeProtocol):
|
||||
|
||||
elif tc.tool_name == "ask_user_multiple":
|
||||
# --- Framework-level ask_user_multiple ---
|
||||
user_input_requested = True
|
||||
raw_questions = tc.tool_input.get("questions", [])
|
||||
if not isinstance(raw_questions, list) or len(raw_questions) < 2:
|
||||
result = ToolResult(
|
||||
@@ -2366,6 +2528,8 @@ class EventLoopNode(NodeProtocol):
|
||||
}
|
||||
)
|
||||
|
||||
user_input_requested = True
|
||||
|
||||
# Store as multi-question prompt/options for
|
||||
# the event emission path
|
||||
ask_user_prompt = ""
|
||||
@@ -2426,6 +2590,27 @@ class EventLoopNode(NodeProtocol):
|
||||
results_by_id[tc.tool_use_id] = result
|
||||
|
||||
elif tc.tool_name == "delegate_to_sub_agent":
|
||||
# Guard: in continuous mode the LLM may see delegate
|
||||
# calls from a previous node's conversation history and
|
||||
# attempt to re-use the tool on a node that doesn't own
|
||||
# it. Only accept if the tool was actually offered.
|
||||
if not any(t.name == "delegate_to_sub_agent" for t in tools):
|
||||
logger.warning(
|
||||
"[%s] LLM called delegate_to_sub_agent but tool "
|
||||
"was not offered to this node — rejecting",
|
||||
node_id,
|
||||
)
|
||||
result = ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=(
|
||||
"ERROR: delegate_to_sub_agent is not available "
|
||||
"on this node. This tool belongs to a different "
|
||||
"node in the workflow."
|
||||
),
|
||||
is_error=True,
|
||||
)
|
||||
results_by_id[tc.tool_use_id] = result
|
||||
continue
|
||||
# --- Framework-level subagent delegation ---
|
||||
# Queue for parallel execution in Phase 2
|
||||
logger.info(
|
||||
@@ -2627,6 +2812,11 @@ class EventLoopNode(NodeProtocol):
|
||||
content=raw.content,
|
||||
is_error=raw.is_error,
|
||||
)
|
||||
# Route through _truncate_tool_result so large
|
||||
# subagent results are saved to spillover files
|
||||
# and survive pruning (instead of being "cleared
|
||||
# from context" with no recovery path).
|
||||
result = self._truncate_tool_result(result, "delegate_to_sub_agent")
|
||||
results_by_id[tc.tool_use_id] = result
|
||||
logged_tool_calls.append(
|
||||
{
|
||||
@@ -2666,12 +2856,28 @@ class EventLoopNode(NodeProtocol):
|
||||
real_tool_results.append(tool_entry)
|
||||
logged_tool_calls.append(tool_entry)
|
||||
|
||||
# Strip image content for models that can't handle it
|
||||
image_content = result.image_content
|
||||
if image_content and ctx.llm and not supports_image_tool_results(ctx.llm.model):
|
||||
logger.info(
|
||||
"Stripping image_content from tool result — model '%s' "
|
||||
"does not support images in tool results",
|
||||
ctx.llm.model,
|
||||
)
|
||||
image_content = None
|
||||
|
||||
await conversation.add_tool_result(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=result.content,
|
||||
is_error=result.is_error,
|
||||
image_content=image_content,
|
||||
is_skill_content=result.is_skill_content,
|
||||
)
|
||||
if tc.tool_name in ("ask_user", "ask_user_multiple"):
|
||||
if (
|
||||
tc.tool_name in ("ask_user", "ask_user_multiple")
|
||||
and user_input_requested
|
||||
and not result.is_error
|
||||
):
|
||||
# Defer tool_call_completed until after user responds
|
||||
self._deferred_tool_complete = {
|
||||
"stream_id": stream_id,
|
||||
@@ -2774,6 +2980,8 @@ class EventLoopNode(NodeProtocol):
|
||||
conversation.usage_ratio() * 100,
|
||||
)
|
||||
|
||||
await self._publish_context_usage(ctx, conversation, "post_tool_results")
|
||||
|
||||
# If the turn requested external input (ask_user or queen handoff),
|
||||
# return immediately so the outer loop can block before judge eval.
|
||||
if user_input_requested or queen_input_requested:
|
||||
@@ -3489,6 +3697,33 @@ class EventLoopNode(NodeProtocol):
|
||||
content=f"No tool executor configured for '{tc.tool_name}'",
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
# AS-9: Intercept file-read tools for skill directories — bypass session sandbox
|
||||
_SKILL_READ_TOOLS = {"view_file", "load_data", "read_file"}
|
||||
skill_dirs = getattr(self, "_skill_dirs", [])
|
||||
if tc.tool_name in _SKILL_READ_TOOLS and skill_dirs:
|
||||
_path = tc.tool_input.get("path", "")
|
||||
if _path:
|
||||
import os
|
||||
from pathlib import Path as _Path
|
||||
|
||||
_resolved = os.path.realpath(os.path.abspath(_path))
|
||||
if any(_resolved.startswith(os.path.realpath(d)) for d in skill_dirs):
|
||||
try:
|
||||
_content = _Path(_resolved).read_text(encoding="utf-8")
|
||||
_is_skill_md = _resolved.endswith("SKILL.md")
|
||||
return ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=_content,
|
||||
is_skill_content=_is_skill_md, # AS-10: protect SKILL.md reads
|
||||
)
|
||||
except Exception as _exc:
|
||||
return ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=f"Could not read skill resource '{_path}': {_exc}",
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
tool_use = ToolUse(id=tc.tool_use_id, name=tc.tool_name, input=tc.tool_input)
|
||||
timeout = self._config.tool_call_timeout_seconds
|
||||
|
||||
@@ -3776,6 +4011,7 @@ class EventLoopNode(NodeProtocol):
|
||||
tool_use_id=result.tool_use_id,
|
||||
content=truncated,
|
||||
is_error=False,
|
||||
image_content=result.image_content,
|
||||
)
|
||||
|
||||
spill_dir = self._config.spillover_dir
|
||||
@@ -3848,6 +4084,7 @@ class EventLoopNode(NodeProtocol):
|
||||
tool_use_id=result.tool_use_id,
|
||||
content=content,
|
||||
is_error=False,
|
||||
image_content=result.image_content,
|
||||
)
|
||||
|
||||
# No spillover_dir — truncate in-place if needed
|
||||
@@ -3890,6 +4127,7 @@ class EventLoopNode(NodeProtocol):
|
||||
tool_use_id=result.tool_use_id,
|
||||
content=truncated,
|
||||
is_error=False,
|
||||
image_content=result.image_content,
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -3920,6 +4158,12 @@ class EventLoopNode(NodeProtocol):
|
||||
ratio_before = conversation.usage_ratio()
|
||||
phase_grad = getattr(ctx, "continuous_mode", False)
|
||||
|
||||
# Capture pre-compaction message inventory when over budget,
|
||||
# since compaction mutates the conversation in place.
|
||||
pre_inventory: list[dict[str, Any]] | None = None
|
||||
if ratio_before >= 1.0:
|
||||
pre_inventory = self._build_message_inventory(conversation)
|
||||
|
||||
# --- Step 1: Prune old tool results (free, no LLM) ---
|
||||
protect = max(2000, self._config.max_context_tokens // 12)
|
||||
pruned = await conversation.prune_old_tool_results(
|
||||
@@ -3934,7 +4178,7 @@ class EventLoopNode(NodeProtocol):
|
||||
conversation.usage_ratio() * 100,
|
||||
)
|
||||
if not conversation.needs_compaction():
|
||||
await self._log_compaction(ctx, conversation, ratio_before)
|
||||
await self._log_compaction(ctx, conversation, ratio_before, pre_inventory)
|
||||
return
|
||||
|
||||
# --- Step 2: Standard structure-preserving compaction (free, no LLM) ---
|
||||
@@ -3947,7 +4191,7 @@ class EventLoopNode(NodeProtocol):
|
||||
phase_graduated=phase_grad,
|
||||
)
|
||||
if not conversation.needs_compaction():
|
||||
await self._log_compaction(ctx, conversation, ratio_before)
|
||||
await self._log_compaction(ctx, conversation, ratio_before, pre_inventory)
|
||||
return
|
||||
|
||||
# --- Step 3: LLM summary compaction ---
|
||||
@@ -3974,7 +4218,7 @@ class EventLoopNode(NodeProtocol):
|
||||
logger.warning("LLM compaction failed: %s", e)
|
||||
|
||||
if not conversation.needs_compaction():
|
||||
await self._log_compaction(ctx, conversation, ratio_before)
|
||||
await self._log_compaction(ctx, conversation, ratio_before, pre_inventory)
|
||||
return
|
||||
|
||||
# --- Step 4: Emergency deterministic summary (LLM failed/unavailable) ---
|
||||
@@ -3988,7 +4232,7 @@ class EventLoopNode(NodeProtocol):
|
||||
keep_recent=1,
|
||||
phase_graduated=phase_grad,
|
||||
)
|
||||
await self._log_compaction(ctx, conversation, ratio_before)
|
||||
await self._log_compaction(ctx, conversation, ratio_before, pre_inventory)
|
||||
|
||||
# --- LLM compaction with binary-search splitting ----------------------
|
||||
|
||||
@@ -4150,13 +4394,59 @@ class EventLoopNode(NodeProtocol):
|
||||
"re-doing work.\n"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_message_inventory(
|
||||
conversation: NodeConversation,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build a per-message size inventory for debug logging."""
|
||||
inventory: list[dict[str, Any]] = []
|
||||
for m in conversation.messages:
|
||||
content_chars = len(m.content)
|
||||
tc_chars = 0
|
||||
tool_name = None
|
||||
if m.tool_calls:
|
||||
for tc in m.tool_calls:
|
||||
args = tc.get("function", {}).get("arguments", "")
|
||||
tc_chars += len(args) if isinstance(args, str) else len(json.dumps(args))
|
||||
names = [tc.get("function", {}).get("name", "?") for tc in m.tool_calls]
|
||||
tool_name = ", ".join(names)
|
||||
elif m.role == "tool" and m.tool_use_id:
|
||||
for prev in conversation.messages:
|
||||
if prev.tool_calls:
|
||||
for tc in prev.tool_calls:
|
||||
if tc.get("id") == m.tool_use_id:
|
||||
tool_name = tc.get("function", {}).get("name", "?")
|
||||
break
|
||||
if tool_name:
|
||||
break
|
||||
entry: dict[str, Any] = {
|
||||
"seq": m.seq,
|
||||
"role": m.role,
|
||||
"content_chars": content_chars,
|
||||
}
|
||||
if tc_chars:
|
||||
entry["tool_call_args_chars"] = tc_chars
|
||||
if tool_name:
|
||||
entry["tool"] = tool_name
|
||||
if m.is_error:
|
||||
entry["is_error"] = True
|
||||
if m.phase_id:
|
||||
entry["phase"] = m.phase_id
|
||||
if content_chars > 2000:
|
||||
entry["preview"] = m.content[:200] + "…"
|
||||
inventory.append(entry)
|
||||
return inventory
|
||||
|
||||
async def _log_compaction(
|
||||
self,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
ratio_before: float,
|
||||
pre_inventory: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
"""Log compaction result to runtime logger and event bus."""
|
||||
"""Log compaction result to runtime logger, event bus, and debug file."""
|
||||
import os as _os
|
||||
|
||||
ratio_after = conversation.usage_ratio()
|
||||
before_pct = round(ratio_before * 100)
|
||||
after_pct = round(ratio_after * 100)
|
||||
@@ -4189,19 +4479,103 @@ class EventLoopNode(NodeProtocol):
|
||||
if self._event_bus:
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
|
||||
event_data: dict[str, Any] = {
|
||||
"level": level,
|
||||
"usage_before": before_pct,
|
||||
"usage_after": after_pct,
|
||||
}
|
||||
if pre_inventory is not None:
|
||||
event_data["message_inventory"] = pre_inventory
|
||||
await self._event_bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.CONTEXT_COMPACTED,
|
||||
stream_id=ctx.stream_id or ctx.node_id,
|
||||
node_id=ctx.node_id,
|
||||
data={
|
||||
"level": level,
|
||||
"usage_before": before_pct,
|
||||
"usage_after": after_pct,
|
||||
},
|
||||
data=event_data,
|
||||
)
|
||||
)
|
||||
|
||||
# Emit post-compaction usage update
|
||||
await self._publish_context_usage(ctx, conversation, "post_compaction")
|
||||
|
||||
# Write detailed debug log to ~/.hive/compaction_log/ when enabled
|
||||
if _os.environ.get("HIVE_COMPACTION_DEBUG"):
|
||||
self._write_compaction_debug_log(ctx, before_pct, after_pct, level, pre_inventory)
|
||||
|
||||
@staticmethod
|
||||
def _write_compaction_debug_log(
|
||||
ctx: NodeContext,
|
||||
before_pct: int,
|
||||
after_pct: int,
|
||||
level: str,
|
||||
inventory: list[dict[str, Any]] | None,
|
||||
) -> None:
|
||||
"""Write detailed compaction analysis to ~/.hive/compaction_log/."""
|
||||
log_dir = Path.home() / ".hive" / "compaction_log"
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ts = datetime.now(UTC).strftime("%Y%m%dT%H%M%S_%f")
|
||||
node_label = ctx.node_id.replace("/", "_")
|
||||
log_path = log_dir / f"{ts}_{node_label}.md"
|
||||
|
||||
lines: list[str] = [
|
||||
f"# Compaction Debug — {ctx.node_id}",
|
||||
f"**Time:** {datetime.now(UTC).isoformat()}",
|
||||
f"**Node:** {ctx.node_spec.name} (`{ctx.node_id}`)",
|
||||
]
|
||||
if ctx.stream_id:
|
||||
lines.append(f"**Stream:** {ctx.stream_id}")
|
||||
lines.append(f"**Level:** {level}")
|
||||
lines.append(f"**Usage:** {before_pct}% → {after_pct}%")
|
||||
lines.append("")
|
||||
|
||||
if inventory:
|
||||
total_chars = sum(
|
||||
e.get("content_chars", 0) + e.get("tool_call_args_chars", 0) for e in inventory
|
||||
)
|
||||
lines.append(
|
||||
f"## Pre-Compaction Message Inventory "
|
||||
f"({len(inventory)} messages, {total_chars:,} total chars)"
|
||||
)
|
||||
lines.append("")
|
||||
ranked = sorted(
|
||||
inventory,
|
||||
key=lambda e: e.get("content_chars", 0) + e.get("tool_call_args_chars", 0),
|
||||
reverse=True,
|
||||
)
|
||||
lines.append("| # | seq | role | tool | chars | % of total | flags |")
|
||||
lines.append("|---|-----|------|------|------:|------------|-------|")
|
||||
for i, entry in enumerate(ranked, 1):
|
||||
chars = entry.get("content_chars", 0) + entry.get("tool_call_args_chars", 0)
|
||||
pct = (chars / total_chars * 100) if total_chars else 0
|
||||
tool = entry.get("tool", "")
|
||||
flags = []
|
||||
if entry.get("is_error"):
|
||||
flags.append("error")
|
||||
if entry.get("phase"):
|
||||
flags.append(f"phase={entry['phase']}")
|
||||
lines.append(
|
||||
f"| {i} | {entry['seq']} | {entry['role']} | {tool} "
|
||||
f"| {chars:,} | {pct:.1f}% | {', '.join(flags)} |"
|
||||
)
|
||||
|
||||
large = [e for e in ranked if e.get("preview")]
|
||||
if large:
|
||||
lines.append("")
|
||||
lines.append("### Large message previews")
|
||||
for entry in large:
|
||||
lines.append(
|
||||
f"\n**seq={entry['seq']}** ({entry['role']}, {entry.get('tool', '')}):"
|
||||
)
|
||||
lines.append(f"```\n{entry['preview']}\n```")
|
||||
lines.append("")
|
||||
|
||||
try:
|
||||
log_path.write_text("\n".join(lines), encoding="utf-8")
|
||||
logger.debug("Compaction debug log written to %s", log_path)
|
||||
except OSError:
|
||||
logger.debug("Failed to write compaction debug log to %s", log_path)
|
||||
|
||||
def _build_emergency_summary(
|
||||
self,
|
||||
ctx: NodeContext,
|
||||
@@ -4287,17 +4661,14 @@ class EventLoopNode(NodeProtocol):
|
||||
)
|
||||
parts.append(
|
||||
"CONVERSATION HISTORY (freeform messages saved during compaction — "
|
||||
"use load_data('<filename>'), read_file('<full_path>'), "
|
||||
"or run_command('cat \"<full_path>\"') to review earlier dialogue):\n"
|
||||
+ conv_list
|
||||
"use load_data('<filename>') to review earlier dialogue):\n" + conv_list
|
||||
)
|
||||
if data_files:
|
||||
file_list = "\n".join(
|
||||
f" - {f} (full path: {data_dir / f})" for f in data_files[:30]
|
||||
)
|
||||
parts.append(
|
||||
"DATA FILES (use load_data('<filename>'), read_file('<full_path>'), "
|
||||
"or run_command('cat \"<full_path>\"') to read):\n" + file_list
|
||||
"DATA FILES (use load_data('<filename>') to read):\n" + file_list
|
||||
)
|
||||
if not all_files:
|
||||
parts.append(
|
||||
@@ -4363,6 +4734,8 @@ class EventLoopNode(NodeProtocol):
|
||||
return None
|
||||
|
||||
accumulator = await OutputAccumulator.restore(self._conversation_store)
|
||||
accumulator.spillover_dir = self._config.spillover_dir
|
||||
accumulator.max_value_chars = self._config.max_output_value_chars
|
||||
|
||||
cursor = await self._conversation_store.read_cursor()
|
||||
start_iteration = cursor.get("iteration", 0) + 1 if cursor else 0
|
||||
@@ -4425,20 +4798,37 @@ class EventLoopNode(NodeProtocol):
|
||||
]
|
||||
await self._conversation_store.write_cursor(cursor)
|
||||
|
||||
async def _drain_injection_queue(self, conversation: NodeConversation) -> int:
|
||||
async def _drain_injection_queue(self, conversation: NodeConversation, ctx: NodeContext) -> int:
|
||||
"""Drain all pending injected events as user messages. Returns count."""
|
||||
count = 0
|
||||
while not self._injection_queue.empty():
|
||||
try:
|
||||
content, is_client_input = self._injection_queue.get_nowait()
|
||||
content, is_client_input, image_content = self._injection_queue.get_nowait()
|
||||
logger.info(
|
||||
"[drain] injected message (client_input=%s): %s",
|
||||
"[drain] injected message (client_input=%s, images=%d): %s",
|
||||
is_client_input,
|
||||
len(image_content) if image_content else 0,
|
||||
content[:200] if content else "(empty)",
|
||||
)
|
||||
# For models that don't support images, fall back to a text description
|
||||
if image_content and ctx.llm:
|
||||
if not supports_image_tool_results(ctx.llm.model):
|
||||
logger.info(
|
||||
"Model '%s' does not support images — attempting vision fallback",
|
||||
ctx.llm.model,
|
||||
)
|
||||
description = await _describe_images_as_text(image_content)
|
||||
if description:
|
||||
content = f"{content}\n\n{description}" if content else description
|
||||
logger.info("[drain] image described as text via vision fallback")
|
||||
else:
|
||||
logger.info("[drain] no vision fallback available — images dropped")
|
||||
image_content = None
|
||||
# Real user input is stored as-is; external events get a prefix
|
||||
if is_client_input:
|
||||
await conversation.add_user_message(content, is_client_input=True)
|
||||
await conversation.add_user_message(
|
||||
content, is_client_input=True, image_content=image_content
|
||||
)
|
||||
else:
|
||||
await conversation.add_user_message(f"[External event]: {content}")
|
||||
count += 1
|
||||
@@ -4607,6 +4997,36 @@ class EventLoopNode(NodeProtocol):
|
||||
if result.inject:
|
||||
await conversation.add_user_message(result.inject)
|
||||
|
||||
async def _publish_context_usage(
|
||||
self,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
trigger: str,
|
||||
) -> None:
|
||||
"""Emit a CONTEXT_USAGE_UPDATED event with current context window state."""
|
||||
if not self._event_bus:
|
||||
return
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
|
||||
estimated = conversation.estimate_tokens()
|
||||
max_tokens = conversation._max_context_tokens
|
||||
ratio = estimated / max_tokens if max_tokens > 0 else 0.0
|
||||
await self._event_bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.CONTEXT_USAGE_UPDATED,
|
||||
stream_id=ctx.stream_id or ctx.node_id,
|
||||
node_id=ctx.node_id,
|
||||
data={
|
||||
"usage_ratio": round(ratio, 4),
|
||||
"usage_pct": round(ratio * 100),
|
||||
"message_count": conversation.message_count,
|
||||
"estimated_tokens": estimated,
|
||||
"max_context_tokens": max_tokens,
|
||||
"trigger": trigger,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def _publish_iteration(
|
||||
self,
|
||||
stream_id: str,
|
||||
@@ -4891,7 +5311,20 @@ class EventLoopNode(NodeProtocol):
|
||||
write_keys=[], # Read-only!
|
||||
)
|
||||
|
||||
# 2b. Set up report callback (one-way channel to parent / event bus)
|
||||
# 2b. Compute instance counter early so node_id is available for the
|
||||
# report callback and the NodeContext. Each delegation to the same
|
||||
# agent_id gets a unique suffix (instance 1 has no suffix for backward
|
||||
# compat; instance 2+ appends ":N").
|
||||
self._subagent_instance_counter.setdefault(agent_id, 0)
|
||||
self._subagent_instance_counter[agent_id] += 1
|
||||
_sa_instance = self._subagent_instance_counter[agent_id]
|
||||
if _sa_instance > 1:
|
||||
sa_node_id = f"{ctx.node_id}:subagent:{agent_id}:{_sa_instance}"
|
||||
else:
|
||||
sa_node_id = f"{ctx.node_id}:subagent:{agent_id}"
|
||||
subagent_instance = str(_sa_instance)
|
||||
|
||||
# 2c. Set up report callback (one-way channel to parent / event bus)
|
||||
subagent_reports: list[dict] = []
|
||||
|
||||
async def _report_callback(
|
||||
@@ -4904,7 +5337,7 @@ class EventLoopNode(NodeProtocol):
|
||||
if self._event_bus:
|
||||
await self._event_bus.emit_subagent_report(
|
||||
stream_id=ctx.node_id,
|
||||
node_id=f"{ctx.node_id}:subagent:{agent_id}",
|
||||
node_id=sa_node_id,
|
||||
subagent_id=agent_id,
|
||||
message=message,
|
||||
data=data,
|
||||
@@ -4994,7 +5427,7 @@ class EventLoopNode(NodeProtocol):
|
||||
max_iter = min(self._config.max_iterations, 10)
|
||||
subagent_ctx = NodeContext(
|
||||
runtime=ctx.runtime,
|
||||
node_id=f"{ctx.node_id}:subagent:{agent_id}",
|
||||
node_id=sa_node_id,
|
||||
node_spec=subagent_spec,
|
||||
memory=scoped_memory,
|
||||
input_data={"task": task, **parent_data},
|
||||
@@ -5022,10 +5455,7 @@ class EventLoopNode(NodeProtocol):
|
||||
# Derive a conversation store for the subagent from the parent's store.
|
||||
# Each invocation gets a unique path so that repeated delegate calls
|
||||
# (e.g. one per profile) don't restore a stale completed conversation.
|
||||
self._subagent_instance_counter.setdefault(agent_id, 0)
|
||||
self._subagent_instance_counter[agent_id] += 1
|
||||
subagent_instance = str(self._subagent_instance_counter[agent_id])
|
||||
|
||||
# (Instance counter was computed earlier in step 2b.)
|
||||
subagent_conv_store = None
|
||||
if self._conversation_store is not None:
|
||||
from framework.storage.conversation_store import FileConversationStore
|
||||
|
||||
@@ -154,6 +154,7 @@ class GraphExecutor:
|
||||
iteration_metadata_provider: Callable | None = None,
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the executor.
|
||||
@@ -181,6 +182,7 @@ class GraphExecutor:
|
||||
system prompt (for phase switching)
|
||||
skills_catalog_prompt: Available skills catalog for system prompt
|
||||
protocols_prompt: Default skill operational protocols for system prompt
|
||||
skill_dirs: Skill base directories for Tier 3 resource access
|
||||
"""
|
||||
self.runtime = runtime
|
||||
self.llm = llm
|
||||
@@ -204,6 +206,7 @@ class GraphExecutor:
|
||||
self.iteration_metadata_provider = iteration_metadata_provider
|
||||
self.skills_catalog_prompt = skills_catalog_prompt
|
||||
self.protocols_prompt = protocols_prompt
|
||||
self.skill_dirs: list[str] = skill_dirs or []
|
||||
|
||||
if protocols_prompt:
|
||||
self.logger.info(
|
||||
@@ -1420,6 +1423,7 @@ class GraphExecutor:
|
||||
next_spec = graph.get_node(current_node_id)
|
||||
if next_spec and next_spec.node_type == "event_loop":
|
||||
from framework.graph.prompt_composer import (
|
||||
EXECUTION_SCOPE_PREAMBLE,
|
||||
build_accounts_prompt,
|
||||
build_narrative,
|
||||
build_transition_marker,
|
||||
@@ -1459,9 +1463,14 @@ class GraphExecutor:
|
||||
)
|
||||
|
||||
# Compose new system prompt (Layer 1 + 2 + 3 + accounts)
|
||||
# Prepend scope preamble to focus so the LLM stays
|
||||
# within this node's responsibility.
|
||||
_focus = next_spec.system_prompt
|
||||
if next_spec.output_keys and _focus:
|
||||
_focus = f"{EXECUTION_SCOPE_PREAMBLE}\n\n{_focus}"
|
||||
new_system = compose_system_prompt(
|
||||
identity_prompt=getattr(graph, "identity_prompt", None),
|
||||
focus_prompt=next_spec.system_prompt,
|
||||
focus_prompt=_focus,
|
||||
narrative=narrative,
|
||||
accounts_prompt=_node_accounts,
|
||||
)
|
||||
@@ -1839,6 +1848,9 @@ class GraphExecutor:
|
||||
|
||||
existing_underscore = [k for k in memory._data if k.startswith("_")]
|
||||
extra_keys = set(_skill_keys) | set(existing_underscore)
|
||||
# Only inject into read_keys when it was already non-empty — an empty
|
||||
# read_keys means "allow all reads" and injecting skill keys would
|
||||
# inadvertently restrict reads to skill keys only.
|
||||
for k in extra_keys:
|
||||
if read_keys and k not in read_keys:
|
||||
read_keys.append(k)
|
||||
@@ -1893,6 +1905,7 @@ class GraphExecutor:
|
||||
iteration_metadata_provider=self.iteration_metadata_provider,
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
skill_dirs=self.skill_dirs,
|
||||
)
|
||||
|
||||
VALID_NODE_TYPES = {
|
||||
|
||||
@@ -43,8 +43,11 @@ Follow these rules for reliable, efficient browser interaction.
|
||||
`browser_snapshot` separately after every action.
|
||||
Only call `browser_snapshot` when you need a fresh view without
|
||||
performing an action, or after setting `auto_snapshot=false`.
|
||||
- Do NOT use `browser_screenshot` for reading text content
|
||||
— it produces huge base64 images with no searchable text.
|
||||
- Do NOT use `browser_screenshot` to read text — use
|
||||
`browser_snapshot` for that (compact, searchable, fast).
|
||||
- DO use `browser_screenshot` when you need visual context:
|
||||
charts, images, canvas elements, layout verification, or when
|
||||
the snapshot doesn't capture what you need.
|
||||
- Only fall back to `browser_get_text` for extracting specific
|
||||
small elements by CSS selector.
|
||||
|
||||
|
||||
@@ -167,14 +167,6 @@ class Goal(BaseModel):
|
||||
|
||||
return met_weight >= total_weight * 0.9 # 90% threshold
|
||||
|
||||
def check_constraint(self, constraint_id: str, value: Any) -> bool:
|
||||
"""Check if a specific constraint is satisfied."""
|
||||
for c in self.constraints:
|
||||
if c.id == constraint_id:
|
||||
# This would be expanded with actual evaluation logic
|
||||
return True
|
||||
return True
|
||||
|
||||
def to_prompt_context(self) -> str:
|
||||
"""Generate context string for LLM prompts.
|
||||
|
||||
|
||||
@@ -568,6 +568,7 @@ class NodeContext:
|
||||
# Skill system prompts — injected by the skill discovery pipeline
|
||||
skills_catalog_prompt: str = "" # Available skills XML catalog
|
||||
protocols_prompt: str = "" # Default skill operational protocols
|
||||
skill_dirs: list[str] = field(default_factory=list) # Skill base dirs for resource access
|
||||
|
||||
# Per-iteration metadata provider — when set, EventLoopNode merges
|
||||
# the returned dict into node_loop_iteration event data. Used by
|
||||
|
||||
@@ -26,6 +26,16 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Injected into every worker node's system prompt so the LLM understands
|
||||
# it is one step in a multi-node pipeline and should not overreach.
|
||||
EXECUTION_SCOPE_PREAMBLE = (
|
||||
"EXECUTION SCOPE: You are one node in a multi-step workflow graph. "
|
||||
"Focus ONLY on the task described in your instructions below. "
|
||||
"Call set_output() for each of your declared output keys, then stop. "
|
||||
"Do NOT attempt work that belongs to other nodes — the framework "
|
||||
"routes data between nodes automatically."
|
||||
)
|
||||
|
||||
|
||||
def _with_datetime(prompt: str) -> str:
|
||||
"""Append current datetime with local timezone to a system prompt."""
|
||||
@@ -142,6 +152,8 @@ def compose_system_prompt(
|
||||
accounts_prompt: str | None = None,
|
||||
skills_catalog_prompt: str | None = None,
|
||||
protocols_prompt: str | None = None,
|
||||
execution_preamble: str | None = None,
|
||||
node_type_preamble: str | None = None,
|
||||
) -> str:
|
||||
"""Compose the multi-layer system prompt.
|
||||
|
||||
@@ -152,6 +164,10 @@ def compose_system_prompt(
|
||||
accounts_prompt: Connected accounts block (sits between identity and narrative).
|
||||
skills_catalog_prompt: Available skills catalog XML (Agent Skills standard).
|
||||
protocols_prompt: Default skill operational protocols section.
|
||||
execution_preamble: EXECUTION_SCOPE_PREAMBLE for worker nodes
|
||||
(prepended before focus so the LLM knows its pipeline scope).
|
||||
node_type_preamble: Node-type-specific preamble, e.g. GCU browser
|
||||
best-practices prompt (prepended before focus).
|
||||
|
||||
Returns:
|
||||
Composed system prompt with all layers present, plus current datetime.
|
||||
@@ -178,6 +194,15 @@ def compose_system_prompt(
|
||||
if narrative:
|
||||
parts.append(f"\n--- Context (what has happened so far) ---\n{narrative}")
|
||||
|
||||
# Execution scope preamble (worker nodes — tells the LLM it is one
|
||||
# step in a multi-node pipeline and should not overreach)
|
||||
if execution_preamble:
|
||||
parts.append(f"\n{execution_preamble}")
|
||||
|
||||
# Node-type preamble (e.g. GCU browser best-practices)
|
||||
if node_type_preamble:
|
||||
parts.append(f"\n{node_type_preamble}")
|
||||
|
||||
# Layer 3: Focus (current phase directive)
|
||||
if focus_prompt:
|
||||
parts.append(f"\n--- Current Focus ---\n{focus_prompt}")
|
||||
@@ -267,7 +292,9 @@ def build_transition_marker(
|
||||
sections.append(f"\nCompleted: {previous_node.name}")
|
||||
sections.append(f" {previous_node.description}")
|
||||
|
||||
# Outputs in memory
|
||||
# Outputs in memory — use file references for large values so the
|
||||
# next node loads full data from disk instead of seeing truncated
|
||||
# inline previews that look deceptively complete.
|
||||
all_memory = memory.read_all()
|
||||
if all_memory:
|
||||
memory_lines: list[str] = []
|
||||
@@ -275,7 +302,29 @@ def build_transition_marker(
|
||||
if value is None:
|
||||
continue
|
||||
val_str = str(value)
|
||||
if len(val_str) > 300:
|
||||
if len(val_str) > 300 and data_dir:
|
||||
# Auto-spill large transition values to data files
|
||||
import json as _json
|
||||
|
||||
data_path = Path(data_dir)
|
||||
data_path.mkdir(parents=True, exist_ok=True)
|
||||
ext = ".json" if isinstance(value, (dict, list)) else ".txt"
|
||||
filename = f"output_{key}{ext}"
|
||||
try:
|
||||
write_content = (
|
||||
_json.dumps(value, indent=2, ensure_ascii=False)
|
||||
if isinstance(value, (dict, list))
|
||||
else str(value)
|
||||
)
|
||||
(data_path / filename).write_text(write_content, encoding="utf-8")
|
||||
file_size = (data_path / filename).stat().st_size
|
||||
val_str = (
|
||||
f"[Saved to '{filename}' ({file_size:,} bytes). "
|
||||
f"Use load_data(filename='{filename}') to access.]"
|
||||
)
|
||||
except Exception:
|
||||
val_str = val_str[:300] + "..."
|
||||
elif len(val_str) > 300:
|
||||
val_str = val_str[:300] + "..."
|
||||
memory_lines.append(f" {key}: {val_str}")
|
||||
if memory_lines:
|
||||
@@ -292,7 +341,7 @@ def build_transition_marker(
|
||||
]
|
||||
if file_lines:
|
||||
sections.append(
|
||||
"\nData files (use read_file to access):\n" + "\n".join(file_lines)
|
||||
"\nData files (use load_data to access):\n" + "\n".join(file_lines)
|
||||
)
|
||||
|
||||
# Agent working memory
|
||||
@@ -306,6 +355,12 @@ def build_transition_marker(
|
||||
# Next phase
|
||||
sections.append(f"\nNow entering: {next_node.name}")
|
||||
sections.append(f" {next_node.description}")
|
||||
if next_node.output_keys:
|
||||
sections.append(
|
||||
f"\nYour ONLY job in this phase: complete the task above and call "
|
||||
f"set_output() for {next_node.output_keys}. Do NOT do work that "
|
||||
f"belongs to later phases."
|
||||
)
|
||||
|
||||
# Reflection prompt (engineered metacognition)
|
||||
sections.append(
|
||||
|
||||
@@ -115,11 +115,23 @@ class SafeEvalVisitor(ast.NodeVisitor):
|
||||
return True
|
||||
|
||||
def visit_BoolOp(self, node: ast.BoolOp) -> Any:
|
||||
values = [self.visit(v) for v in node.values]
|
||||
# Short-circuit evaluation to match Python semantics.
|
||||
# Previously all operands were eagerly evaluated, which broke
|
||||
# guard patterns like: ``x is not None and x.get("key")``
|
||||
if isinstance(node.op, ast.And):
|
||||
return all(values)
|
||||
result = True
|
||||
for v in node.values:
|
||||
result = self.visit(v)
|
||||
if not result:
|
||||
return result
|
||||
return result
|
||||
elif isinstance(node.op, ast.Or):
|
||||
return any(values)
|
||||
result = False
|
||||
for v in node.values:
|
||||
result = self.visit(v)
|
||||
if result:
|
||||
return result
|
||||
return result
|
||||
raise ValueError(f"Boolean operator {type(node.op).__name__} is not allowed")
|
||||
|
||||
def visit_IfExp(self, node: ast.IfExp) -> Any:
|
||||
|
||||
@@ -0,0 +1,706 @@
|
||||
"""Antigravity (Google internal Cloud Code Assist) LLM provider.
|
||||
|
||||
Antigravity is Google's unified gateway API that routes requests to Gemini,
|
||||
Claude, and GPT-OSS models through a single Gemini-style interface. It is
|
||||
NOT the public ``generativelanguage.googleapis.com`` API.
|
||||
|
||||
Authentication uses Google OAuth2. Token refresh is done directly with the
|
||||
OAuth client secret — no local proxy required.
|
||||
|
||||
Credential sources (checked in order):
|
||||
1. ``~/.hive/antigravity-accounts.json`` (native OAuth implementation)
|
||||
2. Antigravity IDE SQLite state DB (macOS / Linux)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator, Callable, Iterator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
StreamEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
# Fallback order: daily sandbox → autopush sandbox → production
|
||||
_ENDPOINTS = [
|
||||
"https://daily-cloudcode-pa.sandbox.googleapis.com",
|
||||
"https://autopush-cloudcode-pa.sandbox.googleapis.com",
|
||||
"https://cloudcode-pa.googleapis.com",
|
||||
]
|
||||
_DEFAULT_PROJECT_ID = "rising-fact-p41fc"
|
||||
_TOKEN_REFRESH_BUFFER_SECS = 60
|
||||
|
||||
# Credentials file in ~/.hive/ (native implementation)
|
||||
_ACCOUNTS_FILE = Path.home() / ".hive" / "antigravity-accounts.json"
|
||||
_IDE_STATE_DB_MAC = (
|
||||
Path.home()
|
||||
/ "Library"
|
||||
/ "Application Support"
|
||||
/ "Antigravity"
|
||||
/ "User"
|
||||
/ "globalStorage"
|
||||
/ "state.vscdb"
|
||||
)
|
||||
_IDE_STATE_DB_LINUX = (
|
||||
Path.home() / ".config" / "Antigravity" / "User" / "globalStorage" / "state.vscdb"
|
||||
)
|
||||
_IDE_STATE_DB_KEY = "antigravityUnifiedStateSync.oauthToken"
|
||||
|
||||
_BASE_HEADERS: dict[str, str] = {
|
||||
# Mimic the Antigravity Electron app so the API accepts the request.
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 "
|
||||
"(KHTML, like Gecko) Antigravity/1.18.3 Chrome/138.0.7204.235 "
|
||||
"Electron/37.3.1 Safari/537.36"
|
||||
),
|
||||
"X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
|
||||
"Client-Metadata": '{"ideType":"ANTIGRAVITY","platform":"MACOS","pluginType":"GEMINI"}',
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Credential loading helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _load_from_json_file() -> tuple[str | None, str | None, str, float]:
|
||||
"""Read credentials from JSON accounts file.
|
||||
|
||||
Reads from ~/.hive/antigravity-accounts.json.
|
||||
|
||||
Returns ``(access_token | None, refresh_token | None, project_id, expires_at)``.
|
||||
``expires_at`` is a Unix timestamp (seconds); 0.0 means unknown.
|
||||
"""
|
||||
if not _ACCOUNTS_FILE.exists():
|
||||
return None, None, _DEFAULT_PROJECT_ID, 0.0
|
||||
try:
|
||||
with open(_ACCOUNTS_FILE, encoding="utf-8") as fh:
|
||||
data = json.load(fh)
|
||||
except (OSError, json.JSONDecodeError) as exc:
|
||||
logger.debug("Failed to read Antigravity accounts file: %s", exc)
|
||||
return None, None, _DEFAULT_PROJECT_ID, 0.0
|
||||
|
||||
accounts = data.get("accounts", [])
|
||||
if not accounts:
|
||||
return None, None, _DEFAULT_PROJECT_ID, 0.0
|
||||
|
||||
account = next((a for a in accounts if a.get("enabled", True) is not False), accounts[0])
|
||||
schema_version = data.get("schemaVersion", 1)
|
||||
|
||||
if schema_version >= 4:
|
||||
# V4 schema: refresh = "refreshToken|projectId[|managedProjectId]"
|
||||
refresh_str = account.get("refresh", "")
|
||||
parts = refresh_str.split("|") if refresh_str else []
|
||||
refresh_token: str | None = parts[0] if parts else None
|
||||
project_id = parts[1] if len(parts) >= 2 and parts[1] else _DEFAULT_PROJECT_ID
|
||||
|
||||
access_token: str | None = account.get("access")
|
||||
expires_ms: int = account.get("expires", 0)
|
||||
expires_at = float(expires_ms) / 1000.0 if expires_ms else 0.0
|
||||
|
||||
# Treat near-expiry tokens as absent so _ensure_token() triggers a refresh.
|
||||
if access_token and expires_at and time.time() >= expires_at - _TOKEN_REFRESH_BUFFER_SECS:
|
||||
access_token = None
|
||||
expires_at = 0.0
|
||||
|
||||
return access_token, refresh_token, project_id, expires_at
|
||||
else:
|
||||
# V1–V3 schema: plain accessToken / refreshToken fields
|
||||
access_token = account.get("accessToken")
|
||||
refresh_token = account.get("refreshToken")
|
||||
# Estimate expiry from last_refresh + 1 h
|
||||
last_refresh_str: str | None = data.get("last_refresh")
|
||||
expires_at = 0.0
|
||||
if last_refresh_str:
|
||||
try:
|
||||
from datetime import datetime # noqa: PLC0415
|
||||
|
||||
ts = datetime.fromisoformat(last_refresh_str.replace("Z", "+00:00")).timestamp()
|
||||
expires_at = ts + 3600.0
|
||||
if time.time() >= expires_at - _TOKEN_REFRESH_BUFFER_SECS:
|
||||
access_token = None
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
return access_token, refresh_token, _DEFAULT_PROJECT_ID, expires_at
|
||||
|
||||
|
||||
def _load_from_ide_db() -> tuple[str | None, str | None, float]:
|
||||
"""Extract ``(access_token, refresh_token, expires_at)`` from the IDE SQLite DB."""
|
||||
import base64 # noqa: PLC0415
|
||||
import sqlite3 # noqa: PLC0415
|
||||
|
||||
for db_path in (_IDE_STATE_DB_MAC, _IDE_STATE_DB_LINUX):
|
||||
if not db_path.exists():
|
||||
continue
|
||||
try:
|
||||
con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
|
||||
try:
|
||||
row = con.execute(
|
||||
"SELECT value FROM ItemTable WHERE key = ?",
|
||||
(_IDE_STATE_DB_KEY,),
|
||||
).fetchone()
|
||||
finally:
|
||||
con.close()
|
||||
if not row:
|
||||
continue
|
||||
|
||||
blob = base64.b64decode(row[0])
|
||||
candidates = re.findall(rb"[A-Za-z0-9+/=_\-]{40,}", blob)
|
||||
access_token: str | None = None
|
||||
refresh_token: str | None = None
|
||||
for candidate in candidates:
|
||||
try:
|
||||
padded = candidate + b"=" * (-len(candidate) % 4)
|
||||
inner = base64.urlsafe_b64decode(padded)
|
||||
except Exception:
|
||||
continue
|
||||
if not access_token:
|
||||
m = re.search(rb"ya29\.[A-Za-z0-9_\-\.]+", inner)
|
||||
if m:
|
||||
access_token = m.group(0).decode("ascii")
|
||||
if not refresh_token:
|
||||
m = re.search(rb"1//[A-Za-z0-9_\-\.]+", inner)
|
||||
if m:
|
||||
refresh_token = m.group(0).decode("ascii")
|
||||
if access_token and refresh_token:
|
||||
break
|
||||
|
||||
if access_token:
|
||||
# Estimate expiry from DB mtime (IDE refreshes while running)
|
||||
mtime = db_path.stat().st_mtime
|
||||
expires_at = mtime + 3600.0
|
||||
return access_token, refresh_token, expires_at
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to read Antigravity IDE state DB: %s", exc)
|
||||
continue
|
||||
return None, None, 0.0
|
||||
|
||||
|
||||
def _do_token_refresh(refresh_token: str) -> tuple[str, float] | None:
|
||||
"""POST to Google OAuth endpoint and return ``(new_access_token, expires_at)``.
|
||||
|
||||
The client secret is sourced via ``get_antigravity_client_secret()`` (env var,
|
||||
config file, or npm package fallback). When unavailable the refresh is attempted
|
||||
without it — Google will reject it for web-app clients, but the npm fallback in
|
||||
``get_antigravity_client_secret()`` should ensure the secret is found at runtime.
|
||||
|
||||
Returns None when the HTTP request fails.
|
||||
"""
|
||||
from framework.config import get_antigravity_client_secret # noqa: PLC0415
|
||||
|
||||
client_secret = get_antigravity_client_secret()
|
||||
if not client_secret:
|
||||
logger.debug(
|
||||
"Antigravity client secret not configured — attempting refresh without it. "
|
||||
"Set ANTIGRAVITY_CLIENT_SECRET or run quickstart to configure."
|
||||
)
|
||||
|
||||
import urllib.error # noqa: PLC0415
|
||||
import urllib.parse # noqa: PLC0415
|
||||
import urllib.request # noqa: PLC0415
|
||||
|
||||
from framework.config import get_antigravity_client_id # noqa: PLC0415
|
||||
|
||||
params: dict[str, str] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": get_antigravity_client_id(),
|
||||
}
|
||||
if client_secret:
|
||||
params["client_secret"] = client_secret
|
||||
body = urllib.parse.urlencode(params).encode("utf-8")
|
||||
|
||||
req = urllib.request.Request(
|
||||
_TOKEN_URL,
|
||||
data=body,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
method="POST",
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=15) as resp: # noqa: S310
|
||||
payload = json.loads(resp.read())
|
||||
access_token: str = payload["access_token"]
|
||||
expires_in: int = payload.get("expires_in", 3600)
|
||||
logger.debug("Antigravity token refreshed successfully")
|
||||
return access_token, time.time() + expires_in
|
||||
except Exception as exc:
|
||||
logger.debug("Antigravity token refresh failed: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message conversion helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _clean_tool_name(name: str) -> str:
|
||||
"""Sanitize a tool name for the Antigravity function-calling schema."""
|
||||
name = re.sub(r"[/\s]", "_", name)
|
||||
if name and not (name[0].isalpha() or name[0] == "_"):
|
||||
name = "_" + name
|
||||
return name[:64]
|
||||
|
||||
|
||||
def _to_gemini_contents(
|
||||
messages: list[dict[str, Any]],
|
||||
thought_sigs: dict[str, str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert OpenAI-format messages to Gemini-style ``contents`` array."""
|
||||
# Pre-build a map tool_call_id → function_name from assistant messages.
|
||||
# Tool result messages (role="tool") only carry tool_call_id, not the name,
|
||||
# but Gemini requires functionResponse.name to match the functionCall.name.
|
||||
tc_id_to_name: dict[str, str] = {}
|
||||
for msg in messages:
|
||||
if msg.get("role") == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
tc_id = tc.get("id")
|
||||
fn_name = tc.get("function", {}).get("name", "")
|
||||
if tc_id and fn_name:
|
||||
tc_id_to_name[tc_id] = fn_name
|
||||
|
||||
contents: list[dict[str, Any]] = []
|
||||
# Consecutive tool-result messages must be batched into one user turn.
|
||||
pending_tool_parts: list[dict[str, Any]] = []
|
||||
|
||||
def _flush_tool_results() -> None:
|
||||
if pending_tool_parts:
|
||||
contents.append({"role": "user", "parts": list(pending_tool_parts)})
|
||||
pending_tool_parts.clear()
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content")
|
||||
|
||||
if role == "system":
|
||||
continue # Handled via systemInstruction, not in contents.
|
||||
|
||||
if role == "tool":
|
||||
# OpenAI tool result → Gemini functionResponse part.
|
||||
result_str = content if isinstance(content, str) else str(content or "")
|
||||
tc_id = msg.get("tool_call_id", "")
|
||||
# Look up function name from the pre-built map; fall back to msg.name.
|
||||
fn_name = tc_id_to_name.get(tc_id) or msg.get("name", "")
|
||||
pending_tool_parts.append(
|
||||
{
|
||||
"functionResponse": {
|
||||
"name": fn_name,
|
||||
"id": tc_id,
|
||||
"response": {"content": result_str},
|
||||
}
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
_flush_tool_results()
|
||||
|
||||
gemini_role = "model" if role == "assistant" else "user"
|
||||
parts: list[dict[str, Any]] = []
|
||||
|
||||
if isinstance(content, str) and content:
|
||||
parts.append({"text": content})
|
||||
elif isinstance(content, list):
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "text":
|
||||
text = block.get("text", "")
|
||||
if text:
|
||||
parts.append({"text": text})
|
||||
# Other block types (image_url etc.) skipped.
|
||||
|
||||
# Assistant messages may carry OpenAI-style tool_calls.
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
fn = tc.get("function", {})
|
||||
try:
|
||||
args = json.loads(fn.get("arguments", "{}") or "{}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
tc_id = tc.get("id", str(uuid.uuid4()))
|
||||
fc_part: dict[str, Any] = {
|
||||
"functionCall": {
|
||||
"name": fn.get("name", ""),
|
||||
"args": args,
|
||||
"id": tc_id,
|
||||
}
|
||||
}
|
||||
if thought_sigs:
|
||||
sig = thought_sigs.get(tc_id, "")
|
||||
if sig:
|
||||
fc_part["thoughtSignature"] = sig # part-level, not inside functionCall
|
||||
parts.append(fc_part)
|
||||
|
||||
if parts:
|
||||
contents.append({"role": gemini_role, "parts": parts})
|
||||
|
||||
_flush_tool_results()
|
||||
|
||||
# Gemini requires the first turn to be a user turn. Drop any leading
|
||||
# model messages so the API doesn't reject with a 400.
|
||||
while contents and contents[0].get("role") == "model":
|
||||
contents.pop(0)
|
||||
|
||||
return contents
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response parsing helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _map_finish_reason(reason: str) -> str:
|
||||
return {"STOP": "stop", "MAX_TOKENS": "max_tokens", "OTHER": "tool_use"}.get(
|
||||
(reason or "").upper(), "stop"
|
||||
)
|
||||
|
||||
|
||||
def _parse_complete_response(raw: dict[str, Any], model: str) -> LLMResponse:
|
||||
"""Parse a non-streaming Antigravity response dict → LLMResponse."""
|
||||
payload: dict[str, Any] = raw.get("response", raw)
|
||||
candidates: list[dict[str, Any]] = payload.get("candidates", [])
|
||||
usage: dict[str, Any] = payload.get("usageMetadata", {})
|
||||
|
||||
text_parts: list[str] = []
|
||||
if candidates:
|
||||
for part in candidates[0].get("content", {}).get("parts", []):
|
||||
if "text" in part and not part.get("thought"):
|
||||
text_parts.append(part["text"])
|
||||
|
||||
return LLMResponse(
|
||||
content="".join(text_parts),
|
||||
model=payload.get("modelVersion", model),
|
||||
input_tokens=usage.get("promptTokenCount", 0),
|
||||
output_tokens=usage.get("candidatesTokenCount", 0),
|
||||
stop_reason=_map_finish_reason(candidates[0].get("finishReason", "") if candidates else ""),
|
||||
raw_response=raw,
|
||||
)
|
||||
|
||||
|
||||
def _parse_sse_stream(
|
||||
response: Any,
|
||||
model: str,
|
||||
on_thought_signature: Callable[[str, str], None] | None = None,
|
||||
) -> Iterator[StreamEvent]:
|
||||
"""Parse Antigravity SSE response line-by-line → StreamEvents.
|
||||
|
||||
Each SSE line looks like::
|
||||
|
||||
data: {"response": {"candidates": [...], "usageMetadata": {...}}, "traceId": "..."}
|
||||
"""
|
||||
accumulated = ""
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
finish_reason = ""
|
||||
|
||||
for raw_line in response:
|
||||
line: str = raw_line.decode("utf-8", errors="replace").rstrip("\r\n")
|
||||
if not line.startswith("data:"):
|
||||
continue
|
||||
data_str = line[5:].strip()
|
||||
if not data_str or data_str == "[DONE]":
|
||||
continue
|
||||
try:
|
||||
data: dict[str, Any] = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# The outer envelope is {"response": {...}, "traceId": "..."}.
|
||||
payload: dict[str, Any] = data.get("response", data)
|
||||
|
||||
usage = payload.get("usageMetadata", {})
|
||||
if usage:
|
||||
input_tokens = usage.get("promptTokenCount", input_tokens)
|
||||
output_tokens = usage.get("candidatesTokenCount", output_tokens)
|
||||
|
||||
for candidate in payload.get("candidates", []):
|
||||
fr = candidate.get("finishReason", "")
|
||||
if fr:
|
||||
finish_reason = fr
|
||||
|
||||
for part in candidate.get("content", {}).get("parts", []):
|
||||
if "text" in part and not part.get("thought"):
|
||||
delta: str = part["text"]
|
||||
accumulated += delta
|
||||
yield TextDeltaEvent(content=delta, snapshot=accumulated)
|
||||
elif "functionCall" in part:
|
||||
fc: dict[str, Any] = part["functionCall"]
|
||||
tool_use_id = fc.get("id") or str(uuid.uuid4())
|
||||
thought_sig = part.get("thoughtSignature", "") # sibling of functionCall
|
||||
if thought_sig and on_thought_signature:
|
||||
on_thought_signature(tool_use_id, thought_sig)
|
||||
args = fc.get("args", {})
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
yield ToolCallEvent(
|
||||
tool_use_id=tool_use_id,
|
||||
tool_name=fc.get("name", ""),
|
||||
tool_input=args,
|
||||
)
|
||||
|
||||
if accumulated:
|
||||
yield TextEndEvent(full_text=accumulated)
|
||||
yield FinishEvent(
|
||||
stop_reason=_map_finish_reason(finish_reason),
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AntigravityProvider(LLMProvider):
|
||||
"""LLM provider for Google's internal Antigravity Code Assist gateway.
|
||||
|
||||
No local proxy required. Handles OAuth token refresh, Gemini-format
|
||||
request/response conversion, and SSE streaming directly.
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = "gemini-3-flash") -> None:
|
||||
# Strip any provider prefix ("openai/gemini-3-flash" → "gemini-3-flash").
|
||||
if "/" in model:
|
||||
model = model.split("/", 1)[1]
|
||||
self.model = model
|
||||
|
||||
self._access_token: str | None = None
|
||||
self._refresh_token: str | None = None
|
||||
self._project_id: str = _DEFAULT_PROJECT_ID
|
||||
self._token_expires_at: float = 0.0
|
||||
self._thought_sigs: dict[str, str] = {} # tool_use_id → thoughtSignature
|
||||
|
||||
self._init_credentials()
|
||||
|
||||
# --- Credential management -------------------------------------------- #
|
||||
|
||||
def _init_credentials(self) -> None:
|
||||
"""Load credentials from the best available source."""
|
||||
access, refresh, project_id, expires_at = _load_from_json_file()
|
||||
if refresh:
|
||||
self._refresh_token = refresh
|
||||
self._project_id = project_id
|
||||
self._access_token = access
|
||||
self._token_expires_at = expires_at
|
||||
return
|
||||
|
||||
# Fall back to IDE state DB.
|
||||
access, refresh, expires_at = _load_from_ide_db()
|
||||
if access:
|
||||
self._access_token = access
|
||||
self._refresh_token = refresh
|
||||
self._token_expires_at = expires_at
|
||||
|
||||
def has_credentials(self) -> bool:
|
||||
"""Return True if any credential is available."""
|
||||
return bool(self._access_token or self._refresh_token)
|
||||
|
||||
def _ensure_token(self) -> str:
|
||||
"""Return a valid access token, refreshing via OAuth if needed."""
|
||||
if (
|
||||
self._access_token
|
||||
and self._token_expires_at
|
||||
and time.time() < self._token_expires_at - _TOKEN_REFRESH_BUFFER_SECS
|
||||
):
|
||||
return self._access_token
|
||||
|
||||
if self._refresh_token:
|
||||
result = _do_token_refresh(self._refresh_token)
|
||||
if result:
|
||||
self._access_token, self._token_expires_at = result
|
||||
return self._access_token
|
||||
|
||||
if self._access_token:
|
||||
logger.warning("Using potentially stale Antigravity access token")
|
||||
return self._access_token
|
||||
|
||||
raise RuntimeError(
|
||||
"No valid Antigravity credentials. "
|
||||
"Run: uv run python core/antigravity_auth.py auth account add"
|
||||
)
|
||||
|
||||
# --- Request building -------------------------------------------------- #
|
||||
|
||||
def _build_body(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool] | None,
|
||||
max_tokens: int,
|
||||
) -> dict[str, Any]:
|
||||
contents = _to_gemini_contents(messages, self._thought_sigs)
|
||||
inner: dict[str, Any] = {
|
||||
"contents": contents,
|
||||
"generationConfig": {"maxOutputTokens": max_tokens},
|
||||
}
|
||||
if system:
|
||||
inner["systemInstruction"] = {"parts": [{"text": system}]}
|
||||
if tools:
|
||||
inner["tools"] = [
|
||||
{
|
||||
"functionDeclarations": [
|
||||
{
|
||||
"name": _clean_tool_name(t.name),
|
||||
"description": t.description,
|
||||
"parameters": t.parameters
|
||||
or {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
}
|
||||
for t in tools
|
||||
]
|
||||
}
|
||||
]
|
||||
return {
|
||||
"project": self._project_id,
|
||||
"model": self.model,
|
||||
"request": inner,
|
||||
"requestType": "agent",
|
||||
"userAgent": "antigravity",
|
||||
"requestId": f"agent-{uuid.uuid4()}",
|
||||
}
|
||||
|
||||
# --- HTTP transport ---------------------------------------------------- #
|
||||
|
||||
def _post(self, body: dict[str, Any], *, streaming: bool) -> Any:
|
||||
"""POST to the Antigravity endpoint, falling back through the endpoint list."""
|
||||
import urllib.error # noqa: PLC0415
|
||||
import urllib.request # noqa: PLC0415
|
||||
|
||||
token = self._ensure_token()
|
||||
body_bytes = json.dumps(body).encode("utf-8")
|
||||
path = (
|
||||
"/v1internal:streamGenerateContent?alt=sse"
|
||||
if streaming
|
||||
else "/v1internal:generateContent"
|
||||
)
|
||||
headers = {
|
||||
**_BASE_HEADERS,
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if streaming:
|
||||
headers["Accept"] = "text/event-stream"
|
||||
|
||||
last_exc: Exception | None = None
|
||||
for base_url in _ENDPOINTS:
|
||||
url = f"{base_url}{path}"
|
||||
req = urllib.request.Request(url, data=body_bytes, headers=headers, method="POST")
|
||||
try:
|
||||
return urllib.request.urlopen(req, timeout=120) # noqa: S310
|
||||
except urllib.error.HTTPError as exc:
|
||||
if exc.code in (401, 403) and self._refresh_token:
|
||||
# Token rejected — refresh once and retry this endpoint.
|
||||
result = _do_token_refresh(self._refresh_token)
|
||||
if result:
|
||||
self._access_token, self._token_expires_at = result
|
||||
headers["Authorization"] = f"Bearer {self._access_token}"
|
||||
req2 = urllib.request.Request(
|
||||
url, data=body_bytes, headers=headers, method="POST"
|
||||
)
|
||||
try:
|
||||
return urllib.request.urlopen(req2, timeout=120) # noqa: S310
|
||||
except urllib.error.HTTPError as exc2:
|
||||
last_exc = exc2
|
||||
continue
|
||||
last_exc = exc
|
||||
continue
|
||||
elif exc.code >= 500:
|
||||
last_exc = exc
|
||||
continue
|
||||
# Include the API response body in the exception for easier debugging.
|
||||
try:
|
||||
err_body = exc.read().decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
err_body = "(unreadable)"
|
||||
raise RuntimeError(f"Antigravity HTTP {exc.code} from {url}: {err_body}") from exc
|
||||
except (urllib.error.URLError, OSError) as exc:
|
||||
last_exc = exc
|
||||
continue
|
||||
|
||||
raise RuntimeError(
|
||||
f"All Antigravity endpoints failed. Last error: {last_exc}"
|
||||
) from last_exc
|
||||
|
||||
# --- LLMProvider interface --------------------------------------------- #
|
||||
|
||||
def complete(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
max_retries: int | None = None,
|
||||
) -> LLMResponse:
|
||||
if json_mode:
|
||||
suffix = "\n\nPlease respond with a valid JSON object."
|
||||
system = (system + suffix) if system else suffix.strip()
|
||||
|
||||
body = self._build_body(messages, system, tools, max_tokens)
|
||||
resp = self._post(body, streaming=False)
|
||||
return _parse_complete_response(json.loads(resp.read()), self.model)
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
import asyncio # noqa: PLC0415
|
||||
import concurrent.futures # noqa: PLC0415
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
queue: asyncio.Queue[StreamEvent | None] = asyncio.Queue()
|
||||
|
||||
def _blocking_work() -> None:
|
||||
try:
|
||||
body = self._build_body(messages, system, tools, max_tokens)
|
||||
http_resp = self._post(body, streaming=True)
|
||||
for event in _parse_sse_stream(
|
||||
http_resp, self.model, self._thought_sigs.__setitem__
|
||||
):
|
||||
loop.call_soon_threadsafe(queue.put_nowait, event)
|
||||
except Exception as exc:
|
||||
logger.error("Antigravity stream error: %s", exc)
|
||||
loop.call_soon_threadsafe(queue.put_nowait, StreamErrorEvent(error=str(exc)))
|
||||
finally:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, None) # sentinel
|
||||
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
fut = loop.run_in_executor(executor, _blocking_work)
|
||||
try:
|
||||
while True:
|
||||
event = await queue.get()
|
||||
if event is None:
|
||||
break
|
||||
yield event
|
||||
finally:
|
||||
await fut
|
||||
executor.shutdown(wait=False)
|
||||
@@ -0,0 +1,106 @@
|
||||
"""Model capability checks for LLM providers.
|
||||
|
||||
Vision support rules are derived from official vendor documentation:
|
||||
- ZAI (z.ai): docs.z.ai/guides/vlm — GLM-4.6V variants are vision; GLM-5/4.6/4.7 are text-only
|
||||
- MiniMax: platform.minimax.io/docs — minimax-vl-01 is vision; M2.x are text-only
|
||||
- DeepSeek: api-docs.deepseek.com — deepseek-vl2 is vision; chat/reasoner are text-only
|
||||
- Cerebras: inference-docs.cerebras.ai — no vision models at all
|
||||
- Groq: console.groq.com/docs/vision — vision capable; treat as supported by default
|
||||
- Ollama/LM Studio/vLLM/llama.cpp: local runners denied by default; model names
|
||||
don't reliably indicate vision support, so users must configure explicitly
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def _model_name(model: str) -> str:
|
||||
"""Return the bare model name after stripping any 'provider/' prefix."""
|
||||
if "/" in model:
|
||||
return model.split("/", 1)[1]
|
||||
return model
|
||||
|
||||
|
||||
# Step 1: explicit vision allow-list — these always support images regardless
|
||||
# of what the provider-level rules say. Checked first so that e.g. glm-4.6v
|
||||
# is allowed even though glm-4.6 is denied.
|
||||
_VISION_ALLOW_BARE_PREFIXES: tuple[str, ...] = (
|
||||
# ZAI/GLM vision models (docs.z.ai/guides/vlm)
|
||||
"glm-4v", # GLM-4V series (legacy)
|
||||
"glm-4.6v", # GLM-4.6V, GLM-4.6V-flash, GLM-4.6V-flashx
|
||||
# DeepSeek vision models
|
||||
"deepseek-vl", # deepseek-vl2, deepseek-vl2-small, deepseek-vl2-tiny
|
||||
# MiniMax vision model
|
||||
"minimax-vl", # minimax-vl-01
|
||||
)
|
||||
|
||||
# Step 2: provider-level deny — every model from this provider is text-only.
|
||||
_TEXT_ONLY_PROVIDER_PREFIXES: tuple[str, ...] = (
|
||||
# Cerebras: inference-docs.cerebras.ai lists only text models
|
||||
"cerebras/",
|
||||
# Local runners: model names don't reliably indicate vision support
|
||||
"ollama/",
|
||||
"ollama_chat/",
|
||||
"lm_studio/",
|
||||
"vllm/",
|
||||
"llamacpp/",
|
||||
)
|
||||
|
||||
# Step 3: per-model deny — text-only models within otherwise mixed providers.
|
||||
# Matched against the bare model name (provider prefix stripped, lower-cased).
|
||||
# The vision allow-list above is checked first, so vision variants of the same
|
||||
# family are already handled before these deny patterns are reached.
|
||||
_TEXT_ONLY_MODEL_BARE_PREFIXES: tuple[str, ...] = (
|
||||
# --- ZAI / GLM family ---
|
||||
# text-only: glm-5, glm-4.6, glm-4.7, glm-4.5, zai-glm-*
|
||||
# vision: glm-4v, glm-4.6v (caught by allow-list above)
|
||||
"glm-5",
|
||||
"glm-4.6", # bare glm-4.6 is text-only; glm-4.6v is caught by allow-list
|
||||
"glm-4.7",
|
||||
"glm-4.5",
|
||||
"zai-glm",
|
||||
# --- DeepSeek ---
|
||||
# text-only: deepseek-chat, deepseek-coder, deepseek-reasoner
|
||||
# vision: deepseek-vl2 (caught by allow-list above)
|
||||
# Note: LiteLLM's deepseek handler may flatten content lists for some models;
|
||||
# VL models are allowed through and rely on LiteLLM's native VL support.
|
||||
"deepseek-chat",
|
||||
"deepseek-coder",
|
||||
"deepseek-reasoner",
|
||||
# --- MiniMax ---
|
||||
# text-only: minimax-m2.*, minimax-text-*, abab* (legacy)
|
||||
# vision: minimax-vl-01 (caught by allow-list above)
|
||||
"minimax-m2",
|
||||
"minimax-text",
|
||||
"abab",
|
||||
)
|
||||
|
||||
|
||||
def supports_image_tool_results(model: str) -> bool:
|
||||
"""Return whether *model* can receive image content in messages.
|
||||
|
||||
Used to gate both user-message images and tool-result image blocks.
|
||||
|
||||
Logic (checked in order):
|
||||
1. Vision allow-list → True (known vision model, skip all denies)
|
||||
2. Provider deny → False (entire provider is text-only)
|
||||
3. Model deny → False (specific text-only model within a mixed provider)
|
||||
4. Default → True (assume capable; unknown providers and models)
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
bare = _model_name(model_lower)
|
||||
|
||||
# 1. Explicit vision allow — takes priority over all denies
|
||||
if any(bare.startswith(p) for p in _VISION_ALLOW_BARE_PREFIXES):
|
||||
return True
|
||||
|
||||
# 2. Provider-level deny (all models from this provider are text-only)
|
||||
if any(model_lower.startswith(p) for p in _TEXT_ONLY_PROVIDER_PREFIXES):
|
||||
return False
|
||||
|
||||
# 3. Per-model deny (text-only variants within mixed-capability families)
|
||||
if any(bare.startswith(p) for p in _TEXT_ONLY_MODEL_BARE_PREFIXES):
|
||||
return False
|
||||
|
||||
# 5. Default: assume vision capable
|
||||
# Covers: OpenAI, Anthropic, Google, Mistral, Kimi, and other hosted providers
|
||||
return True
|
||||
+649
-12
@@ -7,9 +7,13 @@ Groq, and local models.
|
||||
See: https://docs.litellm.ai/docs/providers
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime
|
||||
@@ -44,7 +48,10 @@ def _patch_litellm_anthropic_oauth() -> None:
|
||||
"""
|
||||
try:
|
||||
from litellm.llms.anthropic.common_utils import AnthropicModelInfo
|
||||
from litellm.types.llms.anthropic import ANTHROPIC_OAUTH_TOKEN_PREFIX
|
||||
from litellm.types.llms.anthropic import (
|
||||
ANTHROPIC_OAUTH_BETA_HEADER,
|
||||
ANTHROPIC_OAUTH_TOKEN_PREFIX,
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"Could not apply litellm Anthropic OAuth patch — litellm internals may have "
|
||||
@@ -69,9 +76,27 @@ def _patch_litellm_anthropic_oauth() -> None:
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
)
|
||||
# Check both authorization header and x-api-key for OAuth tokens.
|
||||
# litellm's optionally_handle_anthropic_oauth only checks headers["authorization"],
|
||||
# but hive passes OAuth tokens via api_key — so litellm puts them into x-api-key.
|
||||
# Anthropic rejects OAuth tokens in x-api-key; they must go in Authorization: Bearer.
|
||||
auth = result.get("authorization", "")
|
||||
if auth.startswith(f"Bearer {ANTHROPIC_OAUTH_TOKEN_PREFIX}"):
|
||||
x_api_key = result.get("x-api-key", "")
|
||||
oauth_prefix = f"Bearer {ANTHROPIC_OAUTH_TOKEN_PREFIX}"
|
||||
auth_is_oauth = auth.startswith(oauth_prefix)
|
||||
key_is_oauth = x_api_key.startswith(ANTHROPIC_OAUTH_TOKEN_PREFIX)
|
||||
if auth_is_oauth or key_is_oauth:
|
||||
token = x_api_key if key_is_oauth else auth.removeprefix("Bearer ").strip()
|
||||
result.pop("x-api-key", None)
|
||||
result["authorization"] = f"Bearer {token}"
|
||||
# Merge the OAuth beta header with any existing beta headers.
|
||||
existing_beta = result.get("anthropic-beta", "")
|
||||
beta_parts = (
|
||||
[b.strip() for b in existing_beta.split(",") if b.strip()] if existing_beta else []
|
||||
)
|
||||
if ANTHROPIC_OAUTH_BETA_HEADER not in beta_parts:
|
||||
beta_parts.append(ANTHROPIC_OAUTH_BETA_HEADER)
|
||||
result["anthropic-beta"] = ",".join(beta_parts)
|
||||
return result
|
||||
|
||||
AnthropicModelInfo.validate_environment = _patched_validate_environment
|
||||
@@ -130,11 +155,15 @@ def _patch_litellm_metadata_nonetype() -> None:
|
||||
if litellm is not None:
|
||||
_patch_litellm_anthropic_oauth()
|
||||
_patch_litellm_metadata_nonetype()
|
||||
# Let litellm silently drop params unsupported by the target provider
|
||||
# (e.g. stream_options for Anthropic) instead of forwarding them verbatim.
|
||||
litellm.drop_params = True
|
||||
|
||||
RATE_LIMIT_MAX_RETRIES = 10
|
||||
RATE_LIMIT_BACKOFF_BASE = 2 # seconds
|
||||
RATE_LIMIT_MAX_DELAY = 120 # seconds - cap to prevent absurd waits
|
||||
MINIMAX_API_BASE = "https://api.minimax.io/v1"
|
||||
OPENROUTER_API_BASE = "https://openrouter.ai/api/v1"
|
||||
|
||||
# Providers that accept cache_control on message content blocks.
|
||||
# Anthropic: native ephemeral caching. MiniMax & Z-AI/GLM: pass-through to their APIs.
|
||||
@@ -159,10 +188,69 @@ def _model_supports_cache_control(model: str) -> bool:
|
||||
# enforces a coding-agent whitelist that blocks unknown User-Agents.
|
||||
KIMI_API_BASE = "https://api.kimi.com/coding"
|
||||
|
||||
# Claude Code OAuth subscription: the Anthropic API requires a specific
|
||||
# User-Agent and a billing integrity header for OAuth-authenticated requests.
|
||||
CLAUDE_CODE_VERSION = "2.1.76"
|
||||
CLAUDE_CODE_USER_AGENT = f"claude-code/{CLAUDE_CODE_VERSION}"
|
||||
_CLAUDE_CODE_BILLING_SALT = "59cf53e54c78"
|
||||
|
||||
|
||||
def _sample_js_code_unit(text: str, idx: int) -> str:
|
||||
"""Return the character at UTF-16 code unit index *idx*, matching JS semantics."""
|
||||
encoded = text.encode("utf-16-le")
|
||||
unit_offset = idx * 2
|
||||
if unit_offset + 2 > len(encoded):
|
||||
return "0"
|
||||
code_unit = int.from_bytes(encoded[unit_offset : unit_offset + 2], "little")
|
||||
return chr(code_unit)
|
||||
|
||||
|
||||
def _claude_code_billing_header(messages: list[dict[str, Any]]) -> str:
|
||||
"""Build the billing integrity system block required by Anthropic's OAuth path."""
|
||||
# Find the first user message text
|
||||
first_text = ""
|
||||
for msg in messages:
|
||||
if msg.get("role") != "user":
|
||||
continue
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
first_text = content
|
||||
break
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text" and block.get("text"):
|
||||
first_text = block["text"]
|
||||
break
|
||||
if first_text:
|
||||
break
|
||||
|
||||
sampled = "".join(_sample_js_code_unit(first_text, i) for i in (4, 7, 20))
|
||||
version_hash = hashlib.sha256(
|
||||
f"{_CLAUDE_CODE_BILLING_SALT}{sampled}{CLAUDE_CODE_VERSION}".encode()
|
||||
).hexdigest()
|
||||
entrypoint = os.environ.get("CLAUDE_CODE_ENTRYPOINT", "").strip() or "cli"
|
||||
return (
|
||||
f"x-anthropic-billing-header: cc_version={CLAUDE_CODE_VERSION}.{version_hash[:3]}; "
|
||||
f"cc_entrypoint={entrypoint}; cch=00000;"
|
||||
)
|
||||
|
||||
|
||||
# Empty-stream retries use a short fixed delay, not the rate-limit backoff.
|
||||
# Conversation-structure issues are deterministic — long waits don't help.
|
||||
EMPTY_STREAM_MAX_RETRIES = 3
|
||||
EMPTY_STREAM_RETRY_DELAY = 1.0 # seconds
|
||||
OPENROUTER_TOOL_COMPAT_ERROR_SNIPPETS = (
|
||||
"no endpoints found that support tool use",
|
||||
"no endpoints available that support tool use",
|
||||
"provider routing",
|
||||
)
|
||||
OPENROUTER_TOOL_CALL_RE = re.compile(
|
||||
r"<\|tool_call_start\|>\s*(.*?)\s*<\|tool_call_end\|>",
|
||||
re.DOTALL,
|
||||
)
|
||||
OPENROUTER_TOOL_COMPAT_CACHE_TTL_SECONDS = 3600
|
||||
# OpenRouter routing can change over time, so tool-compat caching must expire.
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE: dict[str, float] = {}
|
||||
|
||||
# Directory for dumping failed requests
|
||||
FAILED_REQUESTS_DIR = Path.home() / ".hive" / "failed_requests"
|
||||
@@ -205,6 +293,24 @@ def _prune_failed_request_dumps(max_files: int = MAX_FAILED_REQUEST_DUMPS) -> No
|
||||
pass # Best-effort — never block the caller
|
||||
|
||||
|
||||
def _remember_openrouter_tool_compat_model(model: str) -> None:
|
||||
"""Cache OpenRouter tool-compat fallback for a bounded time window."""
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE[model] = (
|
||||
time.monotonic() + OPENROUTER_TOOL_COMPAT_CACHE_TTL_SECONDS
|
||||
)
|
||||
|
||||
|
||||
def _is_openrouter_tool_compat_cached(model: str) -> bool:
|
||||
"""Return True when the cached OpenRouter compat entry is still fresh."""
|
||||
expires_at = OPENROUTER_TOOL_COMPAT_MODEL_CACHE.get(model)
|
||||
if expires_at is None:
|
||||
return False
|
||||
if expires_at <= time.monotonic():
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE.pop(model, None)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _dump_failed_request(
|
||||
model: str,
|
||||
kwargs: dict[str, Any],
|
||||
@@ -408,11 +514,19 @@ class LiteLLMProvider(LLMProvider):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or self._default_api_base_for_model(_original_model)
|
||||
self.extra_kwargs = kwargs
|
||||
# Detect Claude Code OAuth subscription by checking the api_key prefix.
|
||||
self._claude_code_oauth = bool(api_key and api_key.startswith("sk-ant-oat"))
|
||||
if self._claude_code_oauth:
|
||||
# Anthropic requires a specific User-Agent for OAuth requests.
|
||||
eh = self.extra_kwargs.setdefault("extra_headers", {})
|
||||
eh.setdefault("user-agent", CLAUDE_CODE_USER_AGENT)
|
||||
# The Codex ChatGPT backend (chatgpt.com/backend-api/codex) rejects
|
||||
# several standard OpenAI params: max_output_tokens, stream_options.
|
||||
self._codex_backend = bool(
|
||||
self.api_base and "chatgpt.com/backend-api/codex" in self.api_base
|
||||
)
|
||||
# Antigravity routes through a local OpenAI-compatible proxy — no patches needed.
|
||||
self._antigravity = bool(self.api_base and "localhost:8069" in self.api_base)
|
||||
|
||||
if litellm is None:
|
||||
raise ImportError(
|
||||
@@ -431,6 +545,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
model_lower = model.lower()
|
||||
if model_lower.startswith("minimax/") or model_lower.startswith("minimax-"):
|
||||
return MINIMAX_API_BASE
|
||||
if model_lower.startswith("openrouter/"):
|
||||
return OPENROUTER_API_BASE
|
||||
if model_lower.startswith("kimi/"):
|
||||
return KIMI_API_BASE
|
||||
if model_lower.startswith("hive/"):
|
||||
@@ -773,6 +889,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
return await self._collect_stream_to_response(stream_iter)
|
||||
|
||||
full_messages: list[dict[str, Any]] = []
|
||||
if self._claude_code_oauth:
|
||||
billing = _claude_code_billing_header(messages)
|
||||
full_messages.append({"role": "system", "content": billing})
|
||||
if system:
|
||||
sys_msg: dict[str, Any] = {"role": "system", "content": system}
|
||||
if _model_supports_cache_control(self.model):
|
||||
@@ -834,11 +953,504 @@ class LiteLLMProvider(LLMProvider):
|
||||
},
|
||||
}
|
||||
|
||||
def _is_anthropic_model(self) -> bool:
|
||||
"""Return True when the configured model targets Anthropic."""
|
||||
model = (self.model or "").lower()
|
||||
return model.startswith("anthropic/") or model.startswith("claude-")
|
||||
|
||||
def _is_minimax_model(self) -> bool:
|
||||
"""Return True when the configured model targets MiniMax."""
|
||||
model = (self.model or "").lower()
|
||||
return model.startswith("minimax/") or model.startswith("minimax-")
|
||||
|
||||
def _is_openrouter_model(self) -> bool:
|
||||
"""Return True when the configured model targets OpenRouter."""
|
||||
model = (self.model or "").lower()
|
||||
if model.startswith("openrouter/"):
|
||||
return True
|
||||
api_base = (self.api_base or "").lower()
|
||||
return "openrouter.ai/api/v1" in api_base
|
||||
|
||||
def _should_use_openrouter_tool_compat(
|
||||
self,
|
||||
error: BaseException,
|
||||
tools: list[Tool] | None,
|
||||
) -> bool:
|
||||
"""Return True when OpenRouter rejects native tool use for the model."""
|
||||
if not tools or not self._is_openrouter_model():
|
||||
return False
|
||||
error_text = str(error).lower()
|
||||
return "openrouter" in error_text and any(
|
||||
snippet in error_text for snippet in OPENROUTER_TOOL_COMPAT_ERROR_SNIPPETS
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_object(text: str) -> dict[str, Any] | None:
|
||||
"""Extract the first JSON object from a model response."""
|
||||
candidates = [text.strip()]
|
||||
|
||||
stripped = text.strip()
|
||||
if stripped.startswith("```"):
|
||||
fence_lines = stripped.splitlines()
|
||||
if len(fence_lines) >= 3:
|
||||
candidates.append("\n".join(fence_lines[1:-1]).strip())
|
||||
|
||||
decoder = json.JSONDecoder()
|
||||
for candidate in candidates:
|
||||
if not candidate:
|
||||
continue
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
except json.JSONDecodeError:
|
||||
parsed = None
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
|
||||
for start_idx, char in enumerate(candidate):
|
||||
if char != "{":
|
||||
continue
|
||||
try:
|
||||
parsed, _ = decoder.raw_decode(candidate[start_idx:])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return None
|
||||
|
||||
def _parse_openrouter_tool_compat_response(
|
||||
self,
|
||||
content: str,
|
||||
tools: list[Tool],
|
||||
) -> tuple[str, list[dict[str, Any]]]:
|
||||
"""Parse JSON tool-compat output into assistant text and tool calls."""
|
||||
payload = self._extract_json_object(content)
|
||||
if payload is None:
|
||||
text_tool_content, text_tool_calls = self._parse_openrouter_text_tool_calls(
|
||||
content,
|
||||
tools,
|
||||
)
|
||||
if text_tool_calls:
|
||||
logger.info(
|
||||
"[openrouter-tool-compat] Parsed textual tool-call markers for %s",
|
||||
self.model,
|
||||
)
|
||||
return text_tool_content, text_tool_calls
|
||||
logger.info(
|
||||
"[openrouter-tool-compat] %s returned non-JSON fallback content; "
|
||||
"treating it as plain text.",
|
||||
self.model,
|
||||
)
|
||||
return content.strip(), []
|
||||
|
||||
assistant_text = payload.get("assistant_response")
|
||||
if not isinstance(assistant_text, str):
|
||||
assistant_text = payload.get("content")
|
||||
if not isinstance(assistant_text, str):
|
||||
assistant_text = payload.get("response")
|
||||
if not isinstance(assistant_text, str):
|
||||
assistant_text = ""
|
||||
|
||||
tool_calls_raw = payload.get("tool_calls")
|
||||
if not tool_calls_raw and {"name", "arguments"} <= payload.keys():
|
||||
tool_calls_raw = [payload]
|
||||
elif isinstance(payload.get("tool_call"), dict):
|
||||
tool_calls_raw = [payload["tool_call"]]
|
||||
|
||||
if not isinstance(tool_calls_raw, list):
|
||||
tool_calls_raw = []
|
||||
|
||||
allowed_tool_names = {tool.name for tool in tools}
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
compat_prefix = f"openrouter_compat_{time.time_ns()}"
|
||||
|
||||
for idx, raw_call in enumerate(tool_calls_raw):
|
||||
if not isinstance(raw_call, dict):
|
||||
continue
|
||||
|
||||
function_block = raw_call.get("function")
|
||||
function_name = (
|
||||
raw_call.get("name")
|
||||
or raw_call.get("tool_name")
|
||||
or (function_block.get("name") if isinstance(function_block, dict) else None)
|
||||
)
|
||||
if not isinstance(function_name, str) or function_name not in allowed_tool_names:
|
||||
if function_name:
|
||||
logger.warning(
|
||||
"[openrouter-tool-compat] Ignoring unknown tool '%s' for model %s",
|
||||
function_name,
|
||||
self.model,
|
||||
)
|
||||
continue
|
||||
|
||||
arguments = raw_call.get("arguments")
|
||||
if arguments is None:
|
||||
arguments = raw_call.get("tool_input")
|
||||
if arguments is None:
|
||||
arguments = raw_call.get("input")
|
||||
if arguments is None and isinstance(function_block, dict):
|
||||
arguments = function_block.get("arguments")
|
||||
if arguments is None:
|
||||
arguments = {}
|
||||
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {"_raw": arguments}
|
||||
elif not isinstance(arguments, dict):
|
||||
arguments = {"value": arguments}
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"{compat_prefix}_{idx}",
|
||||
"name": function_name,
|
||||
"input": arguments,
|
||||
}
|
||||
)
|
||||
|
||||
return assistant_text.strip(), tool_calls
|
||||
|
||||
@staticmethod
|
||||
def _close_truncated_json_fragment(fragment: str) -> str:
|
||||
"""Close a truncated JSON fragment by balancing quotes/brackets."""
|
||||
stack: list[str] = []
|
||||
in_string = False
|
||||
escaped = False
|
||||
normalized = fragment.rstrip()
|
||||
|
||||
while normalized and normalized[-1] in ",:{[":
|
||||
normalized = normalized[:-1].rstrip()
|
||||
|
||||
for char in normalized:
|
||||
if in_string:
|
||||
if escaped:
|
||||
escaped = False
|
||||
elif char == "\\":
|
||||
escaped = True
|
||||
elif char == '"':
|
||||
in_string = False
|
||||
continue
|
||||
|
||||
if char == '"':
|
||||
in_string = True
|
||||
elif char in "{[":
|
||||
stack.append(char)
|
||||
elif char == "}" and stack and stack[-1] == "{":
|
||||
stack.pop()
|
||||
elif char == "]" and stack and stack[-1] == "[":
|
||||
stack.pop()
|
||||
|
||||
if in_string:
|
||||
if escaped:
|
||||
normalized = normalized[:-1]
|
||||
normalized += '"'
|
||||
|
||||
for opener in reversed(stack):
|
||||
normalized += "}" if opener == "{" else "]"
|
||||
|
||||
return normalized
|
||||
|
||||
def _repair_truncated_tool_arguments(self, raw_arguments: str) -> dict[str, Any] | None:
|
||||
"""Try to recover a truncated JSON object from tool-call arguments."""
|
||||
stripped = raw_arguments.strip()
|
||||
if not stripped or stripped[0] != "{":
|
||||
return None
|
||||
|
||||
max_trim = min(len(stripped), 256)
|
||||
for trim in range(max_trim + 1):
|
||||
candidate = stripped[: len(stripped) - trim].rstrip()
|
||||
if not candidate:
|
||||
break
|
||||
candidate = self._close_truncated_json_fragment(candidate)
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return None
|
||||
|
||||
def _parse_tool_call_arguments(self, raw_arguments: str, tool_name: str) -> dict[str, Any]:
|
||||
"""Parse streamed tool arguments, repairing truncation when possible."""
|
||||
try:
|
||||
parsed = json.loads(raw_arguments) if raw_arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
parsed = None
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
|
||||
repaired = self._repair_truncated_tool_arguments(raw_arguments)
|
||||
if repaired is not None:
|
||||
logger.warning(
|
||||
"[tool-args] Recovered truncated arguments for %s on %s",
|
||||
tool_name,
|
||||
self.model,
|
||||
)
|
||||
return repaired
|
||||
|
||||
raise ValueError(
|
||||
f"Failed to parse tool call arguments for '{tool_name}' (likely truncated JSON)."
|
||||
)
|
||||
|
||||
def _parse_openrouter_text_tool_calls(
|
||||
self,
|
||||
content: str,
|
||||
tools: list[Tool],
|
||||
) -> tuple[str, list[dict[str, Any]]]:
|
||||
"""Parse textual OpenRouter tool calls into synthetic tool calls.
|
||||
|
||||
Supports both:
|
||||
- Marker wrapped payloads: <|tool_call_start|>...<|tool_call_end|>
|
||||
- Plain one-line tool calls: ask_user("...", ["..."])
|
||||
"""
|
||||
tools_by_name = {tool.name: tool for tool in tools}
|
||||
compat_prefix = f"openrouter_compat_{time.time_ns()}"
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
segment_index = 0
|
||||
|
||||
for match in OPENROUTER_TOOL_CALL_RE.finditer(content):
|
||||
parsed_calls = self._parse_openrouter_text_tool_call_block(
|
||||
block=match.group(1),
|
||||
tools_by_name=tools_by_name,
|
||||
compat_prefix=f"{compat_prefix}_{segment_index}",
|
||||
)
|
||||
if parsed_calls:
|
||||
segment_index += 1
|
||||
tool_calls.extend(parsed_calls)
|
||||
|
||||
stripped_content = OPENROUTER_TOOL_CALL_RE.sub("", content)
|
||||
retained_lines: list[str] = []
|
||||
for line in stripped_content.splitlines():
|
||||
stripped_line = line.strip()
|
||||
if not stripped_line:
|
||||
retained_lines.append(line)
|
||||
continue
|
||||
|
||||
candidate = stripped_line
|
||||
if candidate.startswith("`") and candidate.endswith("`") and len(candidate) > 1:
|
||||
candidate = candidate[1:-1].strip()
|
||||
|
||||
parsed_calls = self._parse_openrouter_text_tool_call_block(
|
||||
block=candidate,
|
||||
tools_by_name=tools_by_name,
|
||||
compat_prefix=f"{compat_prefix}_{segment_index}",
|
||||
)
|
||||
if parsed_calls:
|
||||
segment_index += 1
|
||||
tool_calls.extend(parsed_calls)
|
||||
continue
|
||||
|
||||
retained_lines.append(line)
|
||||
|
||||
stripped_text = "\n".join(retained_lines).strip()
|
||||
return stripped_text, tool_calls
|
||||
|
||||
def _parse_openrouter_text_tool_call_block(
|
||||
self,
|
||||
block: str,
|
||||
tools_by_name: dict[str, Tool],
|
||||
compat_prefix: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Parse a single textual tool-call block like [tool(arg='x')]."""
|
||||
try:
|
||||
parsed = ast.parse(block.strip(), mode="eval").body
|
||||
except SyntaxError:
|
||||
return []
|
||||
|
||||
call_nodes = parsed.elts if isinstance(parsed, ast.List) else [parsed]
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
for call_index, call_node in enumerate(call_nodes):
|
||||
if not isinstance(call_node, ast.Call) or not isinstance(call_node.func, ast.Name):
|
||||
continue
|
||||
|
||||
tool_name = call_node.func.id
|
||||
tool = tools_by_name.get(tool_name)
|
||||
if tool is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
tool_input = self._parse_openrouter_text_tool_call_arguments(
|
||||
call_node=call_node,
|
||||
tool=tool,
|
||||
)
|
||||
except (ValueError, SyntaxError):
|
||||
continue
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"{compat_prefix}_{call_index}",
|
||||
"name": tool_name,
|
||||
"input": tool_input,
|
||||
}
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
@staticmethod
|
||||
def _parse_openrouter_text_tool_call_arguments(
|
||||
call_node: ast.Call,
|
||||
tool: Tool,
|
||||
) -> dict[str, Any]:
|
||||
"""Parse positional/keyword args from a textual tool call."""
|
||||
properties = tool.parameters.get("properties", {})
|
||||
positional_keys = list(properties.keys())
|
||||
tool_input: dict[str, Any] = {}
|
||||
|
||||
if len(call_node.args) > len(positional_keys):
|
||||
raise ValueError("Too many positional args for textual tool call")
|
||||
|
||||
for idx, arg_node in enumerate(call_node.args):
|
||||
tool_input[positional_keys[idx]] = ast.literal_eval(arg_node)
|
||||
|
||||
for kwarg in call_node.keywords:
|
||||
if kwarg.arg is None:
|
||||
raise ValueError("Star args are not supported in textual tool calls")
|
||||
tool_input[kwarg.arg] = ast.literal_eval(kwarg.value)
|
||||
|
||||
return tool_input
|
||||
|
||||
def _build_openrouter_tool_compat_messages(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build a JSON-only prompt for models without native tool support."""
|
||||
tool_specs = [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
compat_instruction = (
|
||||
"Tool compatibility mode is active because this OpenRouter model does not support "
|
||||
"native function calling on the routed provider.\n"
|
||||
"Return exactly one JSON object and nothing else.\n"
|
||||
'Schema: {"assistant_response": string, '
|
||||
'"tool_calls": [{"name": string, "arguments": object}]}\n'
|
||||
"Rules:\n"
|
||||
"- If a tool is required, put one or more entries in tool_calls "
|
||||
"and do not invent tool results.\n"
|
||||
"- If no tool is required, set tool_calls to [] and put the full "
|
||||
"answer in assistant_response.\n"
|
||||
"- Only use tool names from the allowed tool list.\n"
|
||||
"- arguments must always be valid JSON objects.\n"
|
||||
f"Allowed tools:\n{json.dumps(tool_specs, ensure_ascii=True)}"
|
||||
)
|
||||
compat_system = compat_instruction if not system else f"{system}\n\n{compat_instruction}"
|
||||
|
||||
full_messages: list[dict[str, Any]] = [{"role": "system", "content": compat_system}]
|
||||
full_messages.extend(messages)
|
||||
return [
|
||||
message
|
||||
for message in full_messages
|
||||
if not (
|
||||
message.get("role") == "assistant"
|
||||
and not message.get("content")
|
||||
and not message.get("tool_calls")
|
||||
)
|
||||
]
|
||||
|
||||
async def _acomplete_via_openrouter_tool_compat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
max_tokens: int,
|
||||
) -> LLMResponse:
|
||||
"""Emulate tool calling via JSON when OpenRouter rejects native tools."""
|
||||
full_messages = self._build_openrouter_tool_compat_messages(messages, system, tools)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": full_messages,
|
||||
"max_tokens": max_tokens,
|
||||
**self.extra_kwargs,
|
||||
}
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
|
||||
response = await self._acompletion_with_rate_limit_retry(**kwargs)
|
||||
raw_content = response.choices[0].message.content or ""
|
||||
assistant_text, tool_calls = self._parse_openrouter_tool_compat_response(
|
||||
raw_content,
|
||||
tools,
|
||||
)
|
||||
usage = response.usage
|
||||
input_tokens = usage.prompt_tokens if usage else 0
|
||||
output_tokens = usage.completion_tokens if usage else 0
|
||||
stop_reason = "tool_calls" if tool_calls else (response.choices[0].finish_reason or "stop")
|
||||
|
||||
return LLMResponse(
|
||||
content=assistant_text,
|
||||
model=response.model or self.model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
stop_reason=stop_reason,
|
||||
raw_response={
|
||||
"compat_mode": "openrouter_tool_emulation",
|
||||
"tool_calls": tool_calls,
|
||||
"response": response,
|
||||
},
|
||||
)
|
||||
|
||||
async def _stream_via_openrouter_tool_compat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
max_tokens: int,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Fallback stream for OpenRouter models without native tool support."""
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[openrouter-tool-compat] Using compatibility mode for %s",
|
||||
self.model,
|
||||
)
|
||||
try:
|
||||
response = await self._acomplete_via_openrouter_tool_compat(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamErrorEvent(error=str(e), recoverable=False)
|
||||
return
|
||||
|
||||
raw_response = response.raw_response if isinstance(response.raw_response, dict) else {}
|
||||
tool_calls = raw_response.get("tool_calls", [])
|
||||
|
||||
if response.content:
|
||||
yield TextDeltaEvent(content=response.content, snapshot=response.content)
|
||||
yield TextEndEvent(full_text=response.content)
|
||||
|
||||
for tool_call in tool_calls:
|
||||
yield ToolCallEvent(
|
||||
tool_use_id=tool_call["id"],
|
||||
tool_name=tool_call["name"],
|
||||
tool_input=tool_call["input"],
|
||||
)
|
||||
|
||||
yield FinishEvent(
|
||||
stop_reason=response.stop_reason,
|
||||
input_tokens=response.input_tokens,
|
||||
output_tokens=response.output_tokens,
|
||||
model=response.model,
|
||||
)
|
||||
|
||||
async def _stream_via_nonstream_completion(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
@@ -882,12 +1494,11 @@ class LiteLLMProvider(LLMProvider):
|
||||
tool_calls = msg.tool_calls or []
|
||||
|
||||
for tc in tool_calls:
|
||||
parsed_args: Any
|
||||
args = tc.function.arguments if tc.function else ""
|
||||
try:
|
||||
parsed_args = json.loads(args) if args else {}
|
||||
except json.JSONDecodeError:
|
||||
parsed_args = {"_raw": args}
|
||||
parsed_args = self._parse_tool_call_arguments(
|
||||
args,
|
||||
tc.function.name if tc.function else "",
|
||||
)
|
||||
yield ToolCallEvent(
|
||||
tool_use_id=getattr(tc, "id", ""),
|
||||
tool_name=tc.function.name if tc.function else "",
|
||||
@@ -946,7 +1557,20 @@ class LiteLLMProvider(LLMProvider):
|
||||
yield event
|
||||
return
|
||||
|
||||
if tools and self._is_openrouter_model() and _is_openrouter_tool_compat_cached(self.model):
|
||||
async for event in self._stream_via_openrouter_tool_compat(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
|
||||
full_messages: list[dict[str, Any]] = []
|
||||
if self._claude_code_oauth:
|
||||
billing = _claude_code_billing_header(messages)
|
||||
full_messages.append({"role": "system", "content": billing})
|
||||
if system:
|
||||
sys_msg: dict[str, Any] = {"role": "system", "content": system}
|
||||
if _model_supports_cache_control(self.model):
|
||||
@@ -984,9 +1608,12 @@ class LiteLLMProvider(LLMProvider):
|
||||
"messages": full_messages,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
**self.extra_kwargs,
|
||||
}
|
||||
# stream_options is OpenAI-specific; Anthropic rejects it with 400.
|
||||
# Only include it for providers that support it.
|
||||
if not self._is_anthropic_model():
|
||||
kwargs["stream_options"] = {"include_usage": True}
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
@@ -1092,10 +1719,10 @@ class LiteLLMProvider(LLMProvider):
|
||||
if choice.finish_reason:
|
||||
stream_finish_reason = choice.finish_reason
|
||||
for _idx, tc_data in sorted(tool_calls_acc.items()):
|
||||
try:
|
||||
parsed_args = json.loads(tc_data["arguments"])
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
parsed_args = {"_raw": tc_data.get("arguments", "")}
|
||||
parsed_args = self._parse_tool_call_arguments(
|
||||
tc_data.get("arguments", ""),
|
||||
tc_data.get("name", ""),
|
||||
)
|
||||
tail_events.append(
|
||||
ToolCallEvent(
|
||||
tool_use_id=tc_data["id"],
|
||||
@@ -1276,6 +1903,16 @@ class LiteLLMProvider(LLMProvider):
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
if self._should_use_openrouter_tool_compat(e, tools):
|
||||
_remember_openrouter_tool_compat_model(self.model)
|
||||
async for event in self._stream_via_openrouter_tool_compat(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools or [],
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
if _is_stream_transient_error(e) and attempt < RATE_LIMIT_MAX_RETRIES:
|
||||
wait = _compute_retry_delay(attempt, exception=e)
|
||||
logger.warning(
|
||||
|
||||
@@ -45,6 +45,8 @@ class ToolResult:
|
||||
tool_use_id: str
|
||||
content: str
|
||||
is_error: bool = False
|
||||
image_content: list[dict[str, Any]] | None = None
|
||||
is_skill_content: bool = False # AS-10: marks activated skill body, protected from pruning
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
|
||||
@@ -208,7 +208,12 @@ def configure_logging(
|
||||
|
||||
# Suppress noisy LiteLLM INFO logs (model/provider line + Provider List URL
|
||||
# printed on every single completion call). Warnings and errors still show.
|
||||
logging.getLogger("LiteLLM").setLevel(logging.WARNING)
|
||||
# Honour LITELLM_LOG env var so users can opt-in to debug output.
|
||||
_litellm_level = os.getenv("LITELLM_LOG", "").upper()
|
||||
if _litellm_level and hasattr(logging, _litellm_level):
|
||||
logging.getLogger("LiteLLM").setLevel(getattr(logging, _litellm_level))
|
||||
else:
|
||||
logging.getLogger("LiteLLM").setLevel(logging.WARNING)
|
||||
|
||||
# When in JSON mode, configure known third-party loggers to use JSON formatter
|
||||
# This ensures libraries like LiteLLM, httpcore also output clean JSON
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""MCP Client for connecting to Model Context Protocol servers.
|
||||
|
||||
This module provides a client for connecting to MCP servers and invoking their tools.
|
||||
Supports both STDIO and HTTP transports using the official MCP Python SDK.
|
||||
Supports STDIO, HTTP, UNIX socket, and SSE transports using the official MCP Python SDK.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -22,7 +22,7 @@ class MCPServerConfig:
|
||||
"""Configuration for an MCP server connection."""
|
||||
|
||||
name: str
|
||||
transport: Literal["stdio", "http"]
|
||||
transport: Literal["stdio", "http", "unix", "sse"]
|
||||
|
||||
# For STDIO transport
|
||||
command: str | None = None
|
||||
@@ -33,6 +33,7 @@ class MCPServerConfig:
|
||||
# For HTTP transport
|
||||
url: str | None = None
|
||||
headers: dict[str, str] = field(default_factory=dict)
|
||||
socket_path: str | None = None
|
||||
|
||||
# Optional metadata
|
||||
description: str = ""
|
||||
@@ -52,7 +53,7 @@ class MCPClient:
|
||||
"""
|
||||
Client for communicating with MCP servers.
|
||||
|
||||
Supports both STDIO and HTTP transports using the official MCP SDK.
|
||||
Supports STDIO, HTTP, UNIX socket, and SSE transports using the official MCP SDK.
|
||||
Manages the connection lifecycle and provides methods to list and invoke tools.
|
||||
"""
|
||||
|
||||
@@ -68,6 +69,7 @@ class MCPClient:
|
||||
self._read_stream = None
|
||||
self._write_stream = None
|
||||
self._stdio_context = None # Context manager for stdio_client
|
||||
self._sse_context = None # Context manager for sse_client
|
||||
self._errlog_handle = None # Track errlog file handle for cleanup
|
||||
self._http_client: httpx.Client | None = None
|
||||
self._tools: dict[str, MCPTool] = {}
|
||||
@@ -141,6 +143,10 @@ class MCPClient:
|
||||
self._connect_stdio()
|
||||
elif self.config.transport == "http":
|
||||
self._connect_http()
|
||||
elif self.config.transport == "unix":
|
||||
self._connect_unix()
|
||||
elif self.config.transport == "sse":
|
||||
self._connect_sse()
|
||||
else:
|
||||
raise ValueError(f"Unsupported transport: {self.config.transport}")
|
||||
|
||||
@@ -266,10 +272,94 @@ class MCPClient:
|
||||
logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}")
|
||||
# Continue anyway, server might not have health endpoint
|
||||
|
||||
def _connect_unix(self) -> None:
|
||||
"""Connect to MCP server via UNIX domain socket transport."""
|
||||
if not self.config.url:
|
||||
raise ValueError("url is required for UNIX transport")
|
||||
if not self.config.socket_path:
|
||||
raise ValueError("socket_path is required for UNIX transport")
|
||||
|
||||
self._http_client = httpx.Client(
|
||||
base_url=self.config.url,
|
||||
headers=self.config.headers,
|
||||
timeout=30.0,
|
||||
transport=httpx.HTTPTransport(uds=self.config.socket_path),
|
||||
)
|
||||
|
||||
try:
|
||||
response = self._http_client.get("/health")
|
||||
response.raise_for_status()
|
||||
logger.info(
|
||||
"Connected to MCP server '%s' via UNIX socket at %s",
|
||||
self.config.name,
|
||||
self.config.socket_path,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}")
|
||||
# Continue anyway, server might not have health endpoint
|
||||
|
||||
def _connect_sse(self) -> None:
|
||||
"""Connect to MCP server via SSE transport using MCP SDK with persistent session."""
|
||||
if not self.config.url:
|
||||
raise ValueError("url is required for SSE transport")
|
||||
|
||||
try:
|
||||
loop_started = threading.Event()
|
||||
connection_ready = threading.Event()
|
||||
connection_error = []
|
||||
|
||||
def run_event_loop():
|
||||
"""Run event loop in background thread."""
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
loop_started.set()
|
||||
|
||||
async def init_connection():
|
||||
try:
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
self._sse_context = sse_client(
|
||||
self.config.url,
|
||||
headers=self.config.headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
(
|
||||
self._read_stream,
|
||||
self._write_stream,
|
||||
) = await self._sse_context.__aenter__()
|
||||
|
||||
self._session = ClientSession(self._read_stream, self._write_stream)
|
||||
await self._session.__aenter__()
|
||||
await self._session.initialize()
|
||||
|
||||
connection_ready.set()
|
||||
except Exception as e:
|
||||
connection_error.append(e)
|
||||
connection_ready.set()
|
||||
|
||||
self._loop.create_task(init_connection())
|
||||
self._loop.run_forever()
|
||||
|
||||
self._loop_thread = threading.Thread(target=run_event_loop, daemon=True)
|
||||
self._loop_thread.start()
|
||||
|
||||
loop_started.wait(timeout=5)
|
||||
if not loop_started.is_set():
|
||||
raise RuntimeError("Event loop failed to start")
|
||||
|
||||
connection_ready.wait(timeout=10)
|
||||
if connection_error:
|
||||
raise connection_error[0]
|
||||
|
||||
logger.info(f"Connected to MCP server '{self.config.name}' via SSE")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to connect to MCP server: {e}") from e
|
||||
|
||||
def _discover_tools(self) -> None:
|
||||
"""Discover available tools from the MCP server."""
|
||||
try:
|
||||
if self.config.transport == "stdio":
|
||||
if self.config.transport in {"stdio", "sse"}:
|
||||
tools_list = self._run_async(self._list_tools_stdio_async())
|
||||
else:
|
||||
tools_list = self._list_tools_http()
|
||||
@@ -371,9 +461,37 @@ class MCPClient:
|
||||
if self.config.transport == "stdio":
|
||||
with self._stdio_call_lock:
|
||||
return self._run_async(self._call_tool_stdio_async(tool_name, arguments))
|
||||
elif self.config.transport == "sse":
|
||||
return self._call_tool_with_retry(
|
||||
lambda: self._run_async(self._call_tool_stdio_async(tool_name, arguments))
|
||||
)
|
||||
elif self.config.transport == "unix":
|
||||
return self._call_tool_with_retry(lambda: self._call_tool_http(tool_name, arguments))
|
||||
else:
|
||||
return self._call_tool_http(tool_name, arguments)
|
||||
|
||||
def _call_tool_with_retry(self, call: Any) -> Any:
|
||||
"""Retry transient MCP transport failures once after reconnecting."""
|
||||
if self.config.transport == "stdio":
|
||||
return call()
|
||||
|
||||
if self.config.transport not in {"unix", "sse"}:
|
||||
return call()
|
||||
|
||||
try:
|
||||
return call()
|
||||
except (httpx.ConnectError, httpx.ReadTimeout) as original_error:
|
||||
logger.warning(
|
||||
"Retrying MCP tool call after transport error from '%s': %s",
|
||||
self.config.name,
|
||||
original_error,
|
||||
)
|
||||
self._reconnect()
|
||||
try:
|
||||
return call()
|
||||
except (httpx.ConnectError, httpx.ReadTimeout) as retry_error:
|
||||
raise original_error from retry_error
|
||||
|
||||
async def _call_tool_stdio_async(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
"""Call tool via STDIO protocol using persistent session."""
|
||||
if not self._session:
|
||||
@@ -391,17 +509,30 @@ class MCPClient:
|
||||
error_text = content_item.text
|
||||
raise RuntimeError(f"MCP tool '{tool_name}' failed: {error_text}")
|
||||
|
||||
# Extract content
|
||||
# Extract content — preserve image blocks alongside text
|
||||
if result.content:
|
||||
# MCP returns content as a list of content items
|
||||
if len(result.content) > 0:
|
||||
content_item = result.content[0]
|
||||
# Check if it's a text content item
|
||||
if hasattr(content_item, "text"):
|
||||
return content_item.text
|
||||
elif hasattr(content_item, "data"):
|
||||
return content_item.data
|
||||
return result.content
|
||||
text_parts: list[str] = []
|
||||
image_parts: list[dict[str, Any]] = []
|
||||
for item in result.content:
|
||||
if hasattr(item, "text"):
|
||||
text_parts.append(item.text)
|
||||
elif hasattr(item, "data") and hasattr(item, "mimeType"):
|
||||
# MCP ImageContent — preserve as structured image block
|
||||
image_parts.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{item.mimeType};base64,{item.data}",
|
||||
},
|
||||
}
|
||||
)
|
||||
elif hasattr(item, "data"):
|
||||
text_parts.append(str(item.data))
|
||||
|
||||
text = "\n".join(text_parts) if text_parts else ""
|
||||
if image_parts:
|
||||
return {"_text": text, "_images": image_parts}
|
||||
return text if text else None
|
||||
|
||||
return None
|
||||
|
||||
@@ -433,18 +564,24 @@ class MCPClient:
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to call tool via HTTP: {e}") from e
|
||||
|
||||
def _reconnect(self) -> None:
|
||||
"""Reconnect to the configured MCP server."""
|
||||
logger.info(f"Reconnecting to MCP server '{self.config.name}'...")
|
||||
self.disconnect()
|
||||
self.connect()
|
||||
|
||||
_CLEANUP_TIMEOUT = 10
|
||||
_THREAD_JOIN_TIMEOUT = 12
|
||||
|
||||
async def _cleanup_stdio_async(self) -> None:
|
||||
"""Async cleanup for STDIO session and context managers.
|
||||
"""Async cleanup for persistent MCP session and context managers.
|
||||
|
||||
Cleanup order is critical:
|
||||
- The session must be closed BEFORE the stdio_context because the session
|
||||
depends on the streams provided by stdio_context.
|
||||
- This mirrors the initialization order in _connect_stdio(), where
|
||||
stdio_context is entered first (providing streams), then the session is
|
||||
created with those streams and entered.
|
||||
- The session must be closed BEFORE the transport context manager because the
|
||||
session depends on the streams provided by that context.
|
||||
- This mirrors the initialization order in _connect_stdio() / _connect_sse(),
|
||||
where the transport context is entered first (providing streams), then the
|
||||
session is created with those streams and entered.
|
||||
- Do not change this ordering without carefully considering these dependencies.
|
||||
"""
|
||||
# First: close session (depends on stdio_context streams)
|
||||
@@ -477,6 +614,16 @@ class MCPClient:
|
||||
finally:
|
||||
self._stdio_context = None
|
||||
|
||||
try:
|
||||
if self._sse_context:
|
||||
await self._sse_context.__aexit__(None, None, None)
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("SSE context cleanup was cancelled; proceeding with best-effort shutdown")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing SSE context: {e}")
|
||||
finally:
|
||||
self._sse_context = None
|
||||
|
||||
# Third: close errlog file handle if we opened one
|
||||
if self._errlog_handle is not None:
|
||||
try:
|
||||
@@ -552,6 +699,7 @@ class MCPClient:
|
||||
# Setting None to None is safe and ensures clean state.
|
||||
self._session = None
|
||||
self._stdio_context = None
|
||||
self._sse_context = None
|
||||
self._read_stream = None
|
||||
self._write_stream = None
|
||||
self._loop = None
|
||||
|
||||
@@ -0,0 +1,255 @@
|
||||
"""Shared MCP client connection management."""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from framework.runner.mcp_client import MCPClient, MCPServerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPConnectionManager:
|
||||
"""Process-wide MCP client pool keyed by server name."""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pool: dict[str, MCPClient] = {}
|
||||
self._refcounts: dict[str, int] = {}
|
||||
self._configs: dict[str, MCPServerConfig] = {}
|
||||
self._pool_lock = threading.Lock()
|
||||
# Transition events keep callers from racing a connect/reconnect/disconnect.
|
||||
self._transitions: dict[str, threading.Event] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "MCPConnectionManager":
|
||||
"""Return the process-level singleton instance."""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
@staticmethod
|
||||
def _is_connected(client: MCPClient | None) -> bool:
|
||||
return bool(client and getattr(client, "_connected", False))
|
||||
|
||||
def acquire(self, config: MCPServerConfig) -> MCPClient:
|
||||
"""Get or create a shared connection and increment its refcount."""
|
||||
server_name = config.name
|
||||
|
||||
while True:
|
||||
should_connect = False
|
||||
transition_event: threading.Event | None = None
|
||||
|
||||
with self._pool_lock:
|
||||
client = self._pool.get(server_name)
|
||||
if self._is_connected(client) and server_name not in self._transitions:
|
||||
new_refcount = self._refcounts.get(server_name, 0) + 1
|
||||
self._refcounts[server_name] = new_refcount
|
||||
self._configs[server_name] = config
|
||||
logger.debug(
|
||||
"Reusing pooled connection for MCP server '%s' (refcount=%d)",
|
||||
server_name,
|
||||
new_refcount,
|
||||
)
|
||||
return client
|
||||
|
||||
transition_event = self._transitions.get(server_name)
|
||||
if transition_event is None:
|
||||
transition_event = threading.Event()
|
||||
self._transitions[server_name] = transition_event
|
||||
self._configs[server_name] = config
|
||||
should_connect = True
|
||||
|
||||
if not should_connect:
|
||||
transition_event.wait()
|
||||
continue
|
||||
|
||||
client = MCPClient(config)
|
||||
try:
|
||||
client.connect()
|
||||
except Exception:
|
||||
with self._pool_lock:
|
||||
current = self._transitions.get(server_name)
|
||||
if current is transition_event:
|
||||
self._transitions.pop(server_name, None)
|
||||
if (
|
||||
server_name not in self._pool
|
||||
and self._refcounts.get(server_name, 0) <= 0
|
||||
):
|
||||
self._configs.pop(server_name, None)
|
||||
transition_event.set()
|
||||
raise
|
||||
|
||||
with self._pool_lock:
|
||||
current = self._transitions.get(server_name)
|
||||
if current is transition_event:
|
||||
self._pool[server_name] = client
|
||||
self._refcounts[server_name] = self._refcounts.get(server_name, 0) + 1
|
||||
self._configs[server_name] = config
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
return client
|
||||
|
||||
client.disconnect()
|
||||
|
||||
def release(self, server_name: str) -> None:
|
||||
"""Decrement refcount and disconnect when the last user releases."""
|
||||
while True:
|
||||
disconnect_client: MCPClient | None = None
|
||||
transition_event: threading.Event | None = None
|
||||
should_disconnect = False
|
||||
|
||||
with self._pool_lock:
|
||||
transition_event = self._transitions.get(server_name)
|
||||
if transition_event is None:
|
||||
refcount = self._refcounts.get(server_name, 0)
|
||||
if refcount <= 0:
|
||||
return
|
||||
if refcount > 1:
|
||||
self._refcounts[server_name] = refcount - 1
|
||||
return
|
||||
|
||||
disconnect_client = self._pool.pop(server_name, None)
|
||||
self._refcounts.pop(server_name, None)
|
||||
transition_event = threading.Event()
|
||||
self._transitions[server_name] = transition_event
|
||||
should_disconnect = True
|
||||
|
||||
if not should_disconnect:
|
||||
transition_event.wait()
|
||||
continue
|
||||
|
||||
try:
|
||||
if disconnect_client is not None:
|
||||
disconnect_client.disconnect()
|
||||
finally:
|
||||
with self._pool_lock:
|
||||
current = self._transitions.get(server_name)
|
||||
if current is transition_event:
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
return
|
||||
|
||||
def health_check(self, server_name: str) -> bool:
|
||||
"""Return True when the pooled connection appears healthy."""
|
||||
while True:
|
||||
with self._pool_lock:
|
||||
transition_event = self._transitions.get(server_name)
|
||||
if transition_event is None:
|
||||
client = self._pool.get(server_name)
|
||||
config = self._configs.get(server_name)
|
||||
break
|
||||
|
||||
transition_event.wait()
|
||||
|
||||
if client is None or config is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
if config.transport == "stdio":
|
||||
client.list_tools()
|
||||
return True
|
||||
|
||||
if not config.url:
|
||||
return False
|
||||
|
||||
client_kwargs: dict[str, Any] = {
|
||||
"base_url": config.url,
|
||||
"headers": config.headers,
|
||||
"timeout": 5.0,
|
||||
}
|
||||
if config.transport == "unix":
|
||||
if not config.socket_path:
|
||||
return False
|
||||
client_kwargs["transport"] = httpx.HTTPTransport(uds=config.socket_path)
|
||||
|
||||
with httpx.Client(**client_kwargs) as http_client:
|
||||
response = http_client.get("/health")
|
||||
response.raise_for_status()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def reconnect(self, server_name: str) -> MCPClient:
|
||||
"""Force a disconnect and replace the pooled client with a fresh one."""
|
||||
while True:
|
||||
transition_event: threading.Event | None = None
|
||||
old_client: MCPClient | None = None
|
||||
|
||||
with self._pool_lock:
|
||||
transition_event = self._transitions.get(server_name)
|
||||
if transition_event is None:
|
||||
config = self._configs.get(server_name)
|
||||
if config is None:
|
||||
raise KeyError(f"Unknown MCP server: {server_name}")
|
||||
old_client = self._pool.get(server_name)
|
||||
refcount = self._refcounts.get(server_name, 0)
|
||||
transition_event = threading.Event()
|
||||
self._transitions[server_name] = transition_event
|
||||
break
|
||||
|
||||
transition_event.wait()
|
||||
|
||||
if old_client is not None:
|
||||
old_client.disconnect()
|
||||
|
||||
new_client = MCPClient(config)
|
||||
try:
|
||||
new_client.connect()
|
||||
except Exception:
|
||||
with self._pool_lock:
|
||||
current = self._transitions.get(server_name)
|
||||
if current is transition_event:
|
||||
self._pool.pop(server_name, None)
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
raise
|
||||
|
||||
with self._pool_lock:
|
||||
current = self._transitions.get(server_name)
|
||||
if current is transition_event:
|
||||
self._pool[server_name] = new_client
|
||||
self._refcounts[server_name] = max(refcount, 1)
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
return new_client
|
||||
|
||||
new_client.disconnect()
|
||||
return self.acquire(config)
|
||||
|
||||
def cleanup_all(self) -> None:
|
||||
"""Disconnect all pooled clients and clear manager state."""
|
||||
while True:
|
||||
with self._pool_lock:
|
||||
if self._transitions:
|
||||
pending = list(self._transitions.values())
|
||||
else:
|
||||
cleanup_events = {name: threading.Event() for name in self._pool}
|
||||
clients = list(self._pool.items())
|
||||
self._transitions.update(cleanup_events)
|
||||
self._pool.clear()
|
||||
self._refcounts.clear()
|
||||
self._configs.clear()
|
||||
break
|
||||
|
||||
for event in pending:
|
||||
event.wait()
|
||||
|
||||
for _server_name, client in clients:
|
||||
try:
|
||||
client.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
with self._pool_lock:
|
||||
for server_name, event in cleanup_events.items():
|
||||
current = self._transitions.get(server_name)
|
||||
if current is event:
|
||||
self._transitions.pop(server_name, None)
|
||||
event.set()
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Pre-load validation for agent graphs.
|
||||
|
||||
Runs structural and credential checks before MCP servers are spawned.
|
||||
Runs structural, credential, and skill-trust checks before MCP servers are spawned.
|
||||
Fails fast with actionable error messages.
|
||||
"""
|
||||
|
||||
@@ -169,6 +169,9 @@ def run_preload_validation(
|
||||
1. Graph structure (includes GCU subagent-only checks) — non-recoverable
|
||||
2. Credentials — potentially recoverable via interactive setup
|
||||
|
||||
Skill discovery and trust gating (AS-13) happen later in runner._setup()
|
||||
so they have access to agent-level skill configuration.
|
||||
|
||||
Raises PreloadValidationError for structural issues.
|
||||
Raises CredentialError for credential issues.
|
||||
"""
|
||||
|
||||
@@ -552,6 +552,319 @@ def get_kimi_code_token() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Antigravity subscription token helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Antigravity IDE (native macOS/Linux app) stores OAuth tokens in its
|
||||
# VSCode-style SQLite state database under the key
|
||||
# "antigravityUnifiedStateSync.oauthToken" as a base64-encoded protobuf blob.
|
||||
ANTIGRAVITY_IDE_STATE_DB = (
|
||||
Path.home()
|
||||
/ "Library"
|
||||
/ "Application Support"
|
||||
/ "Antigravity"
|
||||
/ "User"
|
||||
/ "globalStorage"
|
||||
/ "state.vscdb"
|
||||
)
|
||||
# Linux fallback for the IDE state DB
|
||||
ANTIGRAVITY_IDE_STATE_DB_LINUX = (
|
||||
Path.home() / ".config" / "Antigravity" / "User" / "globalStorage" / "state.vscdb"
|
||||
)
|
||||
# Antigravity credentials stored by native OAuth implementation
|
||||
ANTIGRAVITY_AUTH_FILE = Path.home() / ".hive" / "antigravity-accounts.json"
|
||||
|
||||
ANTIGRAVITY_OAUTH_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
_ANTIGRAVITY_TOKEN_LIFETIME_SECS = 3600 # Google access tokens expire in 1 hour
|
||||
_ANTIGRAVITY_IDE_STATE_DB_KEY = "antigravityUnifiedStateSync.oauthToken"
|
||||
|
||||
|
||||
def _read_antigravity_ide_credentials() -> dict | None:
|
||||
"""Read credentials from the Antigravity IDE's SQLite state database.
|
||||
|
||||
The Antigravity desktop IDE (VSCode-based) stores its OAuth token as a
|
||||
base64-encoded protobuf blob in a SQLite database. The access token is
|
||||
a standard Google OAuth ``ya29.*`` bearer token.
|
||||
|
||||
Returns:
|
||||
Dict with ``accessToken`` and optionally ``refreshToken`` keys,
|
||||
plus ``_source: "ide"`` to skip file-based save on refresh.
|
||||
Returns None if the database is absent or the key is not found.
|
||||
"""
|
||||
import re
|
||||
import sqlite3
|
||||
|
||||
for db_path in (ANTIGRAVITY_IDE_STATE_DB, ANTIGRAVITY_IDE_STATE_DB_LINUX):
|
||||
if not db_path.exists():
|
||||
continue
|
||||
try:
|
||||
con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
|
||||
try:
|
||||
row = con.execute(
|
||||
"SELECT value FROM ItemTable WHERE key = ?",
|
||||
(_ANTIGRAVITY_IDE_STATE_DB_KEY,),
|
||||
).fetchone()
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
if not row:
|
||||
continue
|
||||
|
||||
import base64
|
||||
|
||||
blob = base64.b64decode(row[0])
|
||||
|
||||
# The protobuf blob contains the access token (ya29.*) and
|
||||
# refresh token (1//*) as length-prefixed UTF-8 strings.
|
||||
# Decode the inner base64 layer and extract with regex.
|
||||
inner_b64_candidates = re.findall(rb"[A-Za-z0-9+/=_\-]{40,}", blob)
|
||||
access_token: str | None = None
|
||||
refresh_token: str | None = None
|
||||
for candidate in inner_b64_candidates:
|
||||
try:
|
||||
padded = candidate + b"=" * (-len(candidate) % 4)
|
||||
inner = base64.urlsafe_b64decode(padded)
|
||||
except Exception:
|
||||
continue
|
||||
if not access_token:
|
||||
m = re.search(rb"ya29\.[A-Za-z0-9_\-\.]+", inner)
|
||||
if m:
|
||||
access_token = m.group(0).decode("ascii")
|
||||
if not refresh_token:
|
||||
m = re.search(rb"1//[A-Za-z0-9_\-\.]+", inner)
|
||||
if m:
|
||||
refresh_token = m.group(0).decode("ascii")
|
||||
if access_token and refresh_token:
|
||||
break
|
||||
|
||||
if access_token:
|
||||
return {
|
||||
"accounts": [
|
||||
{
|
||||
"accessToken": access_token,
|
||||
"refreshToken": refresh_token or "",
|
||||
}
|
||||
],
|
||||
"_source": "ide",
|
||||
"_db_path": str(db_path),
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to read Antigravity IDE state DB: %s", exc)
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _read_antigravity_credentials() -> dict | None:
|
||||
"""Read Antigravity auth data from all supported credential sources.
|
||||
|
||||
Checks in order:
|
||||
1. Antigravity IDE SQLite state database (native macOS/Linux app)
|
||||
2. Native OAuth credentials file (~/.hive/antigravity-accounts.json)
|
||||
|
||||
Returns:
|
||||
Auth data dict with an ``accounts`` list on success, None otherwise.
|
||||
"""
|
||||
# 1. Native Antigravity IDE (primary on macOS)
|
||||
ide_creds = _read_antigravity_ide_credentials()
|
||||
if ide_creds:
|
||||
return ide_creds
|
||||
|
||||
# 2. Native OAuth credentials file
|
||||
if ANTIGRAVITY_AUTH_FILE.exists():
|
||||
try:
|
||||
with open(ANTIGRAVITY_AUTH_FILE, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
accounts = data.get("accounts", [])
|
||||
if accounts and isinstance(accounts[0], dict):
|
||||
return data
|
||||
except (json.JSONDecodeError, OSError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _is_antigravity_token_expired(auth_data: dict) -> bool:
|
||||
"""Check whether the Antigravity access token is expired or near expiry.
|
||||
|
||||
For IDE-sourced credentials: uses the state DB's mtime as last_refresh
|
||||
since the IDE keeps the DB fresh while it's running.
|
||||
For JSON-sourced credentials: uses the ``last_refresh`` field or file mtime.
|
||||
"""
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
now = time.time()
|
||||
|
||||
if auth_data.get("_source") == "ide":
|
||||
# The IDE refreshes tokens automatically while running.
|
||||
# Use the DB file's mtime as a proxy for when the token was last updated.
|
||||
try:
|
||||
db_path = Path(auth_data.get("_db_path", str(ANTIGRAVITY_IDE_STATE_DB)))
|
||||
last_refresh: float = db_path.stat().st_mtime
|
||||
except OSError:
|
||||
return True
|
||||
expires_at = last_refresh + _ANTIGRAVITY_TOKEN_LIFETIME_SECS
|
||||
return now >= (expires_at - _TOKEN_REFRESH_BUFFER_SECS)
|
||||
|
||||
last_refresh_val: float | str | None = auth_data.get("last_refresh")
|
||||
if last_refresh_val is None:
|
||||
try:
|
||||
last_refresh_val = ANTIGRAVITY_AUTH_FILE.stat().st_mtime
|
||||
except OSError:
|
||||
return True
|
||||
elif isinstance(last_refresh_val, str):
|
||||
try:
|
||||
last_refresh_val = datetime.fromisoformat(
|
||||
last_refresh_val.replace("Z", "+00:00")
|
||||
).timestamp()
|
||||
except (ValueError, TypeError):
|
||||
return True
|
||||
|
||||
expires_at = float(last_refresh_val) + _ANTIGRAVITY_TOKEN_LIFETIME_SECS
|
||||
return now >= (expires_at - _TOKEN_REFRESH_BUFFER_SECS)
|
||||
|
||||
|
||||
def _refresh_antigravity_token(refresh_token: str) -> dict | None:
|
||||
"""Refresh the Antigravity access token via Google OAuth.
|
||||
|
||||
POSTs form-encoded ``grant_type=refresh_token`` to the Google token
|
||||
endpoint using Antigravity's public OAuth client ID.
|
||||
|
||||
Returns:
|
||||
Parsed response dict (containing ``access_token``) on success,
|
||||
None on any error.
|
||||
"""
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
|
||||
from framework.config import get_antigravity_client_id, get_antigravity_client_secret
|
||||
|
||||
client_id = get_antigravity_client_id()
|
||||
client_secret = get_antigravity_client_secret()
|
||||
params: dict = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": client_id,
|
||||
}
|
||||
if client_secret:
|
||||
params["client_secret"] = client_secret
|
||||
|
||||
data = urllib.parse.urlencode(params).encode("utf-8")
|
||||
|
||||
req = urllib.request.Request(
|
||||
ANTIGRAVITY_OAUTH_TOKEN_URL,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
method="POST",
|
||||
)
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=15) as resp: # noqa: S310
|
||||
return json.loads(resp.read())
|
||||
except (urllib.error.URLError, json.JSONDecodeError, TimeoutError, OSError) as exc:
|
||||
logger.debug("Antigravity token refresh failed: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _save_refreshed_antigravity_credentials(auth_data: dict, token_data: dict) -> None:
|
||||
"""Write refreshed tokens back to the Antigravity JSON credentials file.
|
||||
|
||||
Skipped for IDE-sourced credentials (the IDE manages its own DB).
|
||||
Updates ``accounts[0].accessToken`` (and ``refreshToken`` if present),
|
||||
then persists ``last_refresh`` as an ISO-8601 UTC string.
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
# IDE manages its own state — we do not write back to its SQLite DB
|
||||
if auth_data.get("_source") == "ide":
|
||||
return
|
||||
|
||||
try:
|
||||
accounts = auth_data.get("accounts", [])
|
||||
if not accounts:
|
||||
return
|
||||
account = accounts[0]
|
||||
account["accessToken"] = token_data["access_token"]
|
||||
if "refresh_token" in token_data:
|
||||
account["refreshToken"] = token_data["refresh_token"]
|
||||
auth_data["accounts"] = accounts
|
||||
auth_data["last_refresh"] = datetime.now(UTC).isoformat()
|
||||
|
||||
ANTIGRAVITY_AUTH_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd = os.open(ANTIGRAVITY_AUTH_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
||||
json.dump(auth_data, f, indent=2)
|
||||
logger.debug("Antigravity credentials refreshed and saved")
|
||||
except (OSError, KeyError) as exc:
|
||||
logger.debug("Failed to save refreshed Antigravity credentials: %s", exc)
|
||||
|
||||
|
||||
def get_antigravity_token() -> str | None:
|
||||
"""Get the OAuth access token from an Antigravity subscription.
|
||||
|
||||
Credential sources checked in order:
|
||||
1. Antigravity IDE SQLite state DB (native app, macOS/Linux)
|
||||
2. antigravity-auth CLI JSON file
|
||||
|
||||
For IDE credentials the token is read directly (the IDE refreshes it
|
||||
automatically while running). For JSON credentials an automatic OAuth
|
||||
refresh is attempted when the token is near expiry.
|
||||
|
||||
Returns:
|
||||
The ``ya29.*`` Google OAuth access token, or None if unavailable.
|
||||
"""
|
||||
auth_data = _read_antigravity_credentials()
|
||||
if not auth_data:
|
||||
return None
|
||||
|
||||
accounts = auth_data.get("accounts", [])
|
||||
if not accounts:
|
||||
return None
|
||||
account = accounts[0]
|
||||
|
||||
access_token = account.get("accessToken")
|
||||
if not access_token:
|
||||
return None
|
||||
|
||||
if not _is_antigravity_token_expired(auth_data):
|
||||
return access_token
|
||||
|
||||
# Token is expired or near expiry — attempt a refresh
|
||||
refresh_token = account.get("refreshToken")
|
||||
if not refresh_token:
|
||||
logger.warning(
|
||||
"Antigravity token expired and no refresh token available. "
|
||||
"Re-open the Antigravity IDE to refresh, or run 'antigravity-auth accounts add'."
|
||||
)
|
||||
return access_token # return stale token; proxy may still accept it briefly
|
||||
|
||||
logger.info("Antigravity token expired or near expiry, refreshing...")
|
||||
token_data = _refresh_antigravity_token(refresh_token)
|
||||
|
||||
if token_data and "access_token" in token_data:
|
||||
_save_refreshed_antigravity_credentials(auth_data, token_data)
|
||||
return token_data["access_token"]
|
||||
|
||||
logger.warning(
|
||||
"Antigravity token refresh failed. "
|
||||
"Re-open the Antigravity IDE or run 'antigravity-auth accounts add'."
|
||||
)
|
||||
return access_token
|
||||
|
||||
|
||||
def _is_antigravity_proxy_available() -> bool:
|
||||
"""Return True if antigravity-auth serve is running on localhost:8069."""
|
||||
import socket
|
||||
|
||||
try:
|
||||
with socket.create_connection(("localhost", 8069), timeout=0.5):
|
||||
return True
|
||||
except (OSError, TimeoutError):
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentInfo:
|
||||
"""Information about an exported agent."""
|
||||
@@ -1141,7 +1454,10 @@ class AgentRunner:
|
||||
|
||||
# Create LLM provider
|
||||
# Uses LiteLLM which auto-detects the provider from model name
|
||||
if self.mock_mode:
|
||||
# Skip if already injected (e.g. worker agents with a pre-built LLM)
|
||||
if self._llm is not None:
|
||||
pass # LLM already configured externally
|
||||
elif self.mock_mode:
|
||||
# Use mock LLM for testing without real API calls
|
||||
from framework.llm.mock import MockLLMProvider
|
||||
|
||||
@@ -1155,6 +1471,7 @@ class AgentRunner:
|
||||
use_claude_code = llm_config.get("use_claude_code_subscription", False)
|
||||
use_codex = llm_config.get("use_codex_subscription", False)
|
||||
use_kimi_code = llm_config.get("use_kimi_code_subscription", False)
|
||||
use_antigravity = llm_config.get("use_antigravity_subscription", False)
|
||||
api_base = llm_config.get("api_base")
|
||||
|
||||
api_key = None
|
||||
@@ -1176,6 +1493,8 @@ class AgentRunner:
|
||||
if not api_key:
|
||||
print("Warning: Kimi Code subscription configured but no key found.")
|
||||
print("Run 'kimi /login' to authenticate, then try again.")
|
||||
elif use_antigravity:
|
||||
pass # AntigravityProvider handles credentials internally
|
||||
|
||||
if api_key and use_claude_code:
|
||||
# Use litellm's built-in Anthropic OAuth support.
|
||||
@@ -1214,6 +1533,19 @@ class AgentRunner:
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
)
|
||||
elif use_antigravity:
|
||||
# Direct OAuth to Google's internal Cloud Code Assist gateway.
|
||||
# No local proxy required — AntigravityProvider handles token
|
||||
# refresh and Gemini-format request/response conversion natively.
|
||||
from framework.llm.antigravity import AntigravityProvider # noqa: PLC0415
|
||||
|
||||
provider = AntigravityProvider(model=self.model)
|
||||
if not provider.has_credentials():
|
||||
print(
|
||||
"Warning: Antigravity credentials not found. "
|
||||
"Run: uv run python core/antigravity_auth.py auth account add"
|
||||
)
|
||||
self._llm = provider
|
||||
else:
|
||||
# Local models (e.g. Ollama) don't need an API key
|
||||
if self._is_local_model(self.model):
|
||||
@@ -1340,7 +1672,7 @@ class AgentRunner:
|
||||
except Exception:
|
||||
pass # Best-effort — agent works without account info
|
||||
|
||||
# Skill configuration — the runtime handles discovery, loading, and
|
||||
# Skill configuration — the runtime handles discovery, loading, trust-gating and
|
||||
# prompt rasterization. The runner just builds the config.
|
||||
from framework.skills.config import SkillsConfig
|
||||
from framework.skills.manager import SkillsManagerConfig
|
||||
@@ -1351,6 +1683,7 @@ class AgentRunner:
|
||||
skills=getattr(self, "_agent_skills", None),
|
||||
),
|
||||
project_root=self.agent_path,
|
||||
interactive=self._interactive,
|
||||
)
|
||||
|
||||
self._setup_agent_runtime(
|
||||
@@ -1381,6 +1714,8 @@ class AgentRunner:
|
||||
return "MISTRAL_API_KEY"
|
||||
elif model_lower.startswith("groq/"):
|
||||
return "GROQ_API_KEY"
|
||||
elif model_lower.startswith("openrouter/"):
|
||||
return "OPENROUTER_API_KEY"
|
||||
elif self._is_local_model(model_lower):
|
||||
return None # Local models don't need an API key
|
||||
elif model_lower.startswith("azure/"):
|
||||
@@ -1460,6 +1795,9 @@ class AgentRunner:
|
||||
accounts_data: list[dict] | None = None,
|
||||
tool_provider_map: dict[str, str] | None = None,
|
||||
event_bus=None,
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
skills_manager_config=None,
|
||||
) -> None:
|
||||
"""Set up multi-entry-point execution using AgentRuntime."""
|
||||
|
||||
@@ -54,6 +54,8 @@ class ToolRegistry:
|
||||
def __init__(self):
|
||||
self._tools: dict[str, RegisteredTool] = {}
|
||||
self._mcp_clients: list[Any] = [] # List of MCPClient instances
|
||||
self._mcp_client_servers: dict[int, str] = {} # client id -> server name
|
||||
self._mcp_managed_clients: set[int] = set() # client ids acquired from the manager
|
||||
self._session_context: dict[str, Any] = {} # Auto-injected context for tools
|
||||
self._provider_index: dict[str, set[str]] = {} # provider -> tool names
|
||||
# MCP resync tracking
|
||||
@@ -243,6 +245,13 @@ class ToolRegistry:
|
||||
def _wrap_result(tool_use_id: str, result: Any) -> ToolResult:
|
||||
if isinstance(result, ToolResult):
|
||||
return result
|
||||
# MCP client returns dict with _images when image content is present
|
||||
if isinstance(result, dict) and "_images" in result:
|
||||
return ToolResult(
|
||||
tool_use_id=tool_use_id,
|
||||
content=result.get("_text", ""),
|
||||
image_content=result["_images"],
|
||||
)
|
||||
return ToolResult(
|
||||
tool_use_id=tool_use_id,
|
||||
content=json.dumps(result) if not isinstance(result, str) else result,
|
||||
@@ -480,6 +489,7 @@ class ToolRegistry:
|
||||
def register_mcp_server(
|
||||
self,
|
||||
server_config: dict[str, Any],
|
||||
use_connection_manager: bool = True,
|
||||
) -> int:
|
||||
"""
|
||||
Register an MCP server and discover its tools.
|
||||
@@ -495,12 +505,14 @@ class ToolRegistry:
|
||||
- url: Server URL (for http)
|
||||
- headers: HTTP headers (for http)
|
||||
- description: Server description (optional)
|
||||
use_connection_manager: When True, reuse a shared client keyed by server name
|
||||
|
||||
Returns:
|
||||
Number of tools registered from this server
|
||||
"""
|
||||
try:
|
||||
from framework.runner.mcp_client import MCPClient, MCPServerConfig
|
||||
from framework.runner.mcp_connection_manager import MCPConnectionManager
|
||||
|
||||
# Build config object
|
||||
config = MCPServerConfig(
|
||||
@@ -516,11 +528,18 @@ class ToolRegistry:
|
||||
)
|
||||
|
||||
# Create and connect client
|
||||
client = MCPClient(config)
|
||||
client.connect()
|
||||
if use_connection_manager:
|
||||
client = MCPConnectionManager.get_instance().acquire(config)
|
||||
else:
|
||||
client = MCPClient(config)
|
||||
client.connect()
|
||||
|
||||
# Store client for cleanup
|
||||
self._mcp_clients.append(client)
|
||||
client_id = id(client)
|
||||
self._mcp_client_servers[client_id] = config.name
|
||||
if use_connection_manager:
|
||||
self._mcp_managed_clients.add(client_id)
|
||||
|
||||
# Register each tool
|
||||
server_name = server_config["name"]
|
||||
@@ -560,7 +579,9 @@ class ToolRegistry:
|
||||
}
|
||||
merged_inputs = {**clean_inputs, **filtered_context}
|
||||
result = client_ref.call_tool(tool_name, merged_inputs)
|
||||
# MCP tools return content array, extract the result
|
||||
# MCP client already extracts content (returns str
|
||||
# or {"_text": ..., "_images": ...} for image results).
|
||||
# Handle legacy list format from HTTP transport.
|
||||
if isinstance(result, list) and len(result) > 0:
|
||||
if isinstance(result[0], dict) and "text" in result[0]:
|
||||
return result[0]["text"]
|
||||
@@ -720,12 +741,7 @@ class ToolRegistry:
|
||||
logger.info("%s — resyncing MCP servers", reason)
|
||||
|
||||
# 1. Disconnect existing MCP clients
|
||||
for client in self._mcp_clients:
|
||||
try:
|
||||
client.disconnect()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error disconnecting MCP client during resync: {e}")
|
||||
self._mcp_clients.clear()
|
||||
self._cleanup_mcp_clients("during resync")
|
||||
|
||||
# 2. Remove MCP-registered tools
|
||||
for name in self._mcp_tool_names:
|
||||
@@ -740,12 +756,28 @@ class ToolRegistry:
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Clean up all MCP client connections."""
|
||||
self._cleanup_mcp_clients()
|
||||
|
||||
def _cleanup_mcp_clients(self, context: str = "") -> None:
|
||||
"""Disconnect or release all tracked MCP clients for this registry."""
|
||||
if context:
|
||||
context = f" {context}"
|
||||
|
||||
for client in self._mcp_clients:
|
||||
client_id = id(client)
|
||||
server_name = self._mcp_client_servers.get(client_id, client.config.name)
|
||||
try:
|
||||
client.disconnect()
|
||||
if client_id in self._mcp_managed_clients:
|
||||
from framework.runner.mcp_connection_manager import MCPConnectionManager
|
||||
|
||||
MCPConnectionManager.get_instance().release(server_name)
|
||||
else:
|
||||
client.disconnect()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error disconnecting MCP client: {e}")
|
||||
logger.warning(f"Error disconnecting MCP client{context}: {e}")
|
||||
self._mcp_clients.clear()
|
||||
self._mcp_client_servers.clear()
|
||||
self._mcp_managed_clients.clear()
|
||||
|
||||
def __del__(self):
|
||||
"""Destructor to ensure cleanup."""
|
||||
|
||||
@@ -137,6 +137,7 @@ class AgentRuntime:
|
||||
# Deprecated — pass skills_manager_config instead.
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize agent runtime.
|
||||
@@ -158,6 +159,9 @@ class AgentRuntime:
|
||||
event_bus: Optional external EventBus. If provided, the runtime shares
|
||||
this bus instead of creating its own. Used by SessionManager to
|
||||
share a single bus between queen, worker, and judge.
|
||||
skills_catalog_prompt: Available skills catalog for system prompt
|
||||
protocols_prompt: Default skill operational protocols for system prompt
|
||||
skill_dirs: Skill base directories for Tier 3 resource access
|
||||
skills_manager_config: Skill configuration — the runtime owns
|
||||
discovery, loading, and prompt renderation internally.
|
||||
skills_catalog_prompt: Deprecated. Pre-rendered skills catalog.
|
||||
@@ -195,6 +199,8 @@ class AgentRuntime:
|
||||
self._skills_manager = SkillsManager()
|
||||
self._skills_manager.load()
|
||||
|
||||
self.skill_dirs: list[str] = self._skills_manager.allowlisted_dirs
|
||||
|
||||
# Primary graph identity
|
||||
self._graph_id: str = graph_id or "primary"
|
||||
|
||||
@@ -341,6 +347,7 @@ class AgentRuntime:
|
||||
tool_provider_map=self._tool_provider_map,
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
skill_dirs=self.skill_dirs,
|
||||
)
|
||||
await stream.start()
|
||||
self._streams[ep_id] = stream
|
||||
@@ -977,6 +984,7 @@ class AgentRuntime:
|
||||
tool_provider_map=self._tool_provider_map,
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
skill_dirs=self.skill_dirs,
|
||||
)
|
||||
if self._running:
|
||||
await stream.start()
|
||||
@@ -1466,6 +1474,7 @@ class AgentRuntime:
|
||||
graph_id: str | None = None,
|
||||
*,
|
||||
is_client_input: bool = False,
|
||||
image_content: list[dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""Inject user input into a running client-facing node.
|
||||
|
||||
@@ -1478,6 +1487,8 @@ class AgentRuntime:
|
||||
graph_id: Optional graph to search first (defaults to active graph)
|
||||
is_client_input: True when the message originates from a real
|
||||
human user (e.g. /chat endpoint), False for external events.
|
||||
image_content: Optional list of image content blocks (OpenAI
|
||||
image_url format) to include alongside the text.
|
||||
|
||||
Returns:
|
||||
True if input was delivered, False if no matching node found
|
||||
@@ -1489,7 +1500,9 @@ class AgentRuntime:
|
||||
target = graph_id or self._active_graph_id
|
||||
if target in self._graphs:
|
||||
for stream in self._graphs[target].streams.values():
|
||||
if await stream.inject_input(node_id, content, is_client_input=is_client_input):
|
||||
if await stream.inject_input(
|
||||
node_id, content, is_client_input=is_client_input, image_content=image_content
|
||||
):
|
||||
return True
|
||||
|
||||
# Then search all other graphs
|
||||
@@ -1497,7 +1510,9 @@ class AgentRuntime:
|
||||
if gid == target:
|
||||
continue
|
||||
for stream in reg.streams.values():
|
||||
if await stream.inject_input(node_id, content, is_client_input=is_client_input):
|
||||
if await stream.inject_input(
|
||||
node_id, content, is_client_input=is_client_input, image_content=image_content
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -1760,6 +1775,7 @@ def create_agent_runtime(
|
||||
# Deprecated — pass skills_manager_config instead.
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
) -> AgentRuntime:
|
||||
"""
|
||||
Create and configure an AgentRuntime with entry points.
|
||||
@@ -1786,6 +1802,9 @@ def create_agent_runtime(
|
||||
accounts_data: Raw account data for per-node prompt generation.
|
||||
tool_provider_map: Tool name to provider name mapping for account routing.
|
||||
event_bus: Optional external EventBus to share with other components.
|
||||
skills_catalog_prompt: Available skills catalog for system prompt.
|
||||
protocols_prompt: Default skill operational protocols for system prompt.
|
||||
skill_dirs: Skill base directories for Tier 3 resource access.
|
||||
skills_manager_config: Skill configuration — the runtime owns
|
||||
discovery, loading, and prompt renderation internally.
|
||||
skills_catalog_prompt: Deprecated. Pre-rendered skills catalog.
|
||||
@@ -1819,6 +1838,7 @@ def create_agent_runtime(
|
||||
skills_manager_config=skills_manager_config,
|
||||
skills_catalog_prompt=skills_catalog_prompt,
|
||||
protocols_prompt=protocols_prompt,
|
||||
skill_dirs=skill_dirs,
|
||||
)
|
||||
|
||||
for spec in entry_points:
|
||||
|
||||
@@ -117,6 +117,7 @@ class EventType(StrEnum):
|
||||
|
||||
# Context management
|
||||
CONTEXT_COMPACTED = "context_compacted"
|
||||
CONTEXT_USAGE_UPDATED = "context_usage_updated"
|
||||
|
||||
# External triggers
|
||||
WEBHOOK_RECEIVED = "webhook_received"
|
||||
@@ -159,6 +160,7 @@ class EventType(StrEnum):
|
||||
TRIGGER_DEACTIVATED = "trigger_deactivated"
|
||||
TRIGGER_FIRED = "trigger_fired"
|
||||
TRIGGER_REMOVED = "trigger_removed"
|
||||
TRIGGER_UPDATED = "trigger_updated"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -188,6 +188,7 @@ class ExecutionStream:
|
||||
tool_provider_map: dict[str, str] | None = None,
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize execution stream.
|
||||
@@ -213,6 +214,7 @@ class ExecutionStream:
|
||||
tool_provider_map: Tool name to provider name mapping for account routing
|
||||
skills_catalog_prompt: Available skills catalog for system prompt
|
||||
protocols_prompt: Default skill operational protocols for system prompt
|
||||
skill_dirs: Skill base directories for Tier 3 resource access
|
||||
"""
|
||||
self.stream_id = stream_id
|
||||
self.entry_spec = entry_spec
|
||||
@@ -236,6 +238,7 @@ class ExecutionStream:
|
||||
self._tool_provider_map = tool_provider_map
|
||||
self._skills_catalog_prompt = skills_catalog_prompt
|
||||
self._protocols_prompt = protocols_prompt
|
||||
self._skill_dirs: list[str] = skill_dirs or []
|
||||
|
||||
_es_logger = logging.getLogger(__name__)
|
||||
if protocols_prompt:
|
||||
@@ -430,6 +433,7 @@ class ExecutionStream:
|
||||
content: str,
|
||||
*,
|
||||
is_client_input: bool = False,
|
||||
image_content: list[dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""Inject user input into a running client-facing EventLoopNode.
|
||||
|
||||
@@ -441,7 +445,9 @@ class ExecutionStream:
|
||||
for executor in self._active_executors.values():
|
||||
node = executor.node_registry.get(node_id)
|
||||
if node is not None and hasattr(node, "inject_event"):
|
||||
await node.inject_event(content, is_client_input=is_client_input)
|
||||
await node.inject_event(
|
||||
content, is_client_input=is_client_input, image_content=image_content
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -696,6 +702,7 @@ class ExecutionStream:
|
||||
tool_provider_map=self._tool_provider_map,
|
||||
skills_catalog_prompt=self._skills_catalog_prompt,
|
||||
protocols_prompt=self._protocols_prompt,
|
||||
skill_dirs=self._skill_dirs,
|
||||
)
|
||||
# Track executor so inject_input() can reach EventLoopNode instances
|
||||
self._active_executors[execution_id] = executor
|
||||
|
||||
@@ -8,6 +8,7 @@ write. Errors are silently swallowed — this must never break the agent.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import IO, Any
|
||||
@@ -47,6 +48,9 @@ def log_llm_turn(
|
||||
Never raises.
|
||||
"""
|
||||
try:
|
||||
# Skip logging during test runs to avoid polluting real logs.
|
||||
if os.environ.get("PYTEST_CURRENT_TEST") or os.environ.get("HIVE_DISABLE_LLM_LOGS"):
|
||||
return
|
||||
global _log_file, _log_ready # noqa: PLW0603
|
||||
if not _log_ready:
|
||||
_log_file = _open_log()
|
||||
|
||||
@@ -69,6 +69,7 @@ async def create_queen(
|
||||
QueenPhaseState,
|
||||
register_queen_lifecycle_tools,
|
||||
)
|
||||
from framework.tools.queen_memory_tools import register_queen_memory_tools
|
||||
|
||||
hive_home = Path.home() / ".hive"
|
||||
|
||||
@@ -122,6 +123,9 @@ async def create_queen(
|
||||
phase_state=phase_state,
|
||||
)
|
||||
|
||||
# ---- Episodic memory tools (always registered) ---------------------
|
||||
register_queen_memory_tools(queen_registry)
|
||||
|
||||
# ---- Monitoring tools (only when worker is loaded) ----------------
|
||||
if session.worker_runtime:
|
||||
from framework.tools.worker_monitoring_tools import register_worker_monitoring_tools
|
||||
|
||||
@@ -37,6 +37,7 @@ DEFAULT_EVENT_TYPES = [
|
||||
EventType.NODE_RETRY,
|
||||
EventType.NODE_TOOL_DOOM_LOOP,
|
||||
EventType.CONTEXT_COMPACTED,
|
||||
EventType.CONTEXT_USAGE_UPDATED,
|
||||
EventType.WORKER_LOADED,
|
||||
EventType.CREDENTIALS_REQUIRED,
|
||||
EventType.SUBAGENT_REPORT,
|
||||
@@ -46,6 +47,7 @@ DEFAULT_EVENT_TYPES = [
|
||||
EventType.TRIGGER_DEACTIVATED,
|
||||
EventType.TRIGGER_FIRED,
|
||||
EventType.TRIGGER_REMOVED,
|
||||
EventType.TRIGGER_UPDATED,
|
||||
EventType.DRAFT_GRAPH_UPDATED,
|
||||
]
|
||||
|
||||
|
||||
@@ -108,7 +108,10 @@ async def handle_chat(request: web.Request) -> web.Response:
|
||||
The input box is permanently connected to the queen agent.
|
||||
Worker input is handled separately via /worker-input.
|
||||
|
||||
Body: {"message": "hello"}
|
||||
Body: {"message": "hello", "images": [{"type": "image_url", "image_url": {"url": "data:..."}}]}
|
||||
|
||||
The optional ``images`` field accepts a list of OpenAI-format image_url
|
||||
content blocks. The frontend encodes images as base64 data URIs.
|
||||
"""
|
||||
session, err = resolve_session(request)
|
||||
if err:
|
||||
@@ -116,15 +119,16 @@ async def handle_chat(request: web.Request) -> web.Response:
|
||||
|
||||
body = await request.json()
|
||||
message = body.get("message", "")
|
||||
image_content = body.get("images") or None # list[dict] | None
|
||||
|
||||
if not message:
|
||||
if not message and not image_content:
|
||||
return web.json_response({"error": "message is required"}, status=400)
|
||||
|
||||
queen_executor = session.queen_executor
|
||||
if queen_executor is not None:
|
||||
node = queen_executor.node_registry.get("queen")
|
||||
if node is not None and hasattr(node, "inject_event"):
|
||||
await node.inject_event(message, is_client_input=True)
|
||||
await node.inject_event(message, is_client_input=True, image_content=image_content)
|
||||
# Publish to EventBus so the session event log captures user messages
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
|
||||
@@ -134,7 +138,10 @@ async def handle_chat(request: web.Request) -> web.Response:
|
||||
stream_id="queen",
|
||||
node_id="queen",
|
||||
execution_id=session.id,
|
||||
data={"content": message},
|
||||
data={
|
||||
"content": message,
|
||||
"image_count": len(image_content) if image_content else 0,
|
||||
},
|
||||
)
|
||||
)
|
||||
return web.json_response(
|
||||
|
||||
@@ -11,7 +11,6 @@ Session-primary routes:
|
||||
- GET /api/sessions/{session_id}/entry-points — list entry points
|
||||
- PATCH /api/sessions/{session_id}/triggers/{id} — update trigger task
|
||||
- GET /api/sessions/{session_id}/graphs — list graph IDs
|
||||
- GET /api/sessions/{session_id}/queen-messages — queen conversation history
|
||||
- GET /api/sessions/{session_id}/events/history — persisted eventbus log (for replay)
|
||||
|
||||
Worker session browsing (persisted execution runs on disk):
|
||||
@@ -24,9 +23,13 @@ Worker session browsing (persisted execution runs on disk):
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
@@ -50,8 +53,11 @@ def _get_manager(request: web.Request) -> SessionManager:
|
||||
|
||||
def _session_to_live_dict(session) -> dict:
|
||||
"""Serialize a live Session to the session-primary JSON shape."""
|
||||
from framework.llm.capabilities import supports_image_tool_results
|
||||
|
||||
info = session.worker_info
|
||||
phase_state = getattr(session, "phase_state", None)
|
||||
queen_model: str = getattr(getattr(session, "runner", None), "model", "") or ""
|
||||
return {
|
||||
"session_id": session.id,
|
||||
"worker_id": session.worker_id,
|
||||
@@ -67,6 +73,7 @@ def _session_to_live_dict(session) -> dict:
|
||||
"queen_phase": phase_state.phase
|
||||
if phase_state
|
||||
else ("staging" if session.worker_runtime else "planning"),
|
||||
"queen_supports_images": supports_image_tool_results(queen_model) if queen_model else True,
|
||||
}
|
||||
|
||||
|
||||
@@ -408,7 +415,7 @@ async def handle_session_entry_points(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
async def handle_update_trigger_task(request: web.Request) -> web.Response:
|
||||
"""PATCH /api/sessions/{session_id}/triggers/{trigger_id} — update trigger task."""
|
||||
"""PATCH /api/sessions/{session_id}/triggers/{trigger_id} — update trigger fields."""
|
||||
session, err = resolve_session(request)
|
||||
if err:
|
||||
return err
|
||||
@@ -427,30 +434,136 @@ async def handle_update_trigger_task(request: web.Request) -> web.Response:
|
||||
except Exception:
|
||||
return web.json_response({"error": "Invalid JSON body"}, status=400)
|
||||
|
||||
task = body.get("task")
|
||||
if task is None:
|
||||
return web.json_response({"error": "Missing 'task' field"}, status=400)
|
||||
if not isinstance(task, str):
|
||||
return web.json_response({"error": "'task' must be a string"}, status=400)
|
||||
updates: dict[str, object] = {}
|
||||
|
||||
tdef.task = task
|
||||
if "task" in body:
|
||||
task = body.get("task")
|
||||
if not isinstance(task, str):
|
||||
return web.json_response({"error": "'task' must be a string"}, status=400)
|
||||
tdef.task = task
|
||||
updates["task"] = tdef.task
|
||||
|
||||
trigger_config_update = body.get("trigger_config")
|
||||
if trigger_config_update is not None:
|
||||
if not isinstance(trigger_config_update, dict):
|
||||
return web.json_response(
|
||||
{"error": "'trigger_config' must be an object"},
|
||||
status=400,
|
||||
)
|
||||
merged_trigger_config = dict(tdef.trigger_config)
|
||||
merged_trigger_config.update(trigger_config_update)
|
||||
|
||||
if tdef.trigger_type == "timer":
|
||||
cron_expr = merged_trigger_config.get("cron")
|
||||
interval = merged_trigger_config.get("interval_minutes")
|
||||
if cron_expr is not None and not isinstance(cron_expr, str):
|
||||
return web.json_response(
|
||||
{"error": "'trigger_config.cron' must be a string"},
|
||||
status=400,
|
||||
)
|
||||
if cron_expr:
|
||||
try:
|
||||
from croniter import croniter
|
||||
|
||||
if not croniter.is_valid(cron_expr):
|
||||
return web.json_response(
|
||||
{"error": f"Invalid cron expression: {cron_expr}"},
|
||||
status=400,
|
||||
)
|
||||
except ImportError:
|
||||
return web.json_response(
|
||||
{
|
||||
"error": (
|
||||
"croniter package not installed — cannot validate cron expression."
|
||||
)
|
||||
},
|
||||
status=500,
|
||||
)
|
||||
merged_trigger_config.pop("interval_minutes", None)
|
||||
elif interval is None:
|
||||
return web.json_response(
|
||||
{
|
||||
"error": (
|
||||
"Timer trigger needs 'cron' or 'interval_minutes' in trigger_config."
|
||||
)
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
elif not isinstance(interval, (int, float)) or interval <= 0:
|
||||
return web.json_response(
|
||||
{"error": "'trigger_config.interval_minutes' must be > 0"},
|
||||
status=400,
|
||||
)
|
||||
tdef.trigger_config = merged_trigger_config
|
||||
updates["trigger_config"] = tdef.trigger_config
|
||||
|
||||
if not updates:
|
||||
return web.json_response(
|
||||
{"error": "Provide at least one of 'task' or 'trigger_config'"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
# Persist to session state and agent definition
|
||||
from framework.tools.queen_lifecycle_tools import (
|
||||
_persist_active_triggers,
|
||||
_save_trigger_to_agent,
|
||||
_start_trigger_timer,
|
||||
_start_trigger_webhook,
|
||||
)
|
||||
|
||||
if "trigger_config" in updates and trigger_id in getattr(session, "active_trigger_ids", set()):
|
||||
task = session.active_timer_tasks.pop(trigger_id, None)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
getattr(session, "trigger_next_fire", {}).pop(trigger_id, None)
|
||||
|
||||
webhook_subs = getattr(session, "active_webhook_subs", {})
|
||||
if sub_id := webhook_subs.pop(trigger_id, None):
|
||||
with contextlib.suppress(Exception):
|
||||
session.event_bus.unsubscribe(sub_id)
|
||||
|
||||
if tdef.trigger_type == "timer":
|
||||
await _start_trigger_timer(session, trigger_id, tdef)
|
||||
elif tdef.trigger_type == "webhook":
|
||||
await _start_trigger_webhook(session, trigger_id, tdef)
|
||||
|
||||
if trigger_id in getattr(session, "active_trigger_ids", set()):
|
||||
session_id = request.match_info["session_id"]
|
||||
await _persist_active_triggers(session, session_id)
|
||||
|
||||
_save_trigger_to_agent(session, trigger_id, tdef)
|
||||
|
||||
# Emit SSE event so the frontend updates the graph and detail panel
|
||||
bus = getattr(session, "event_bus", None)
|
||||
if bus:
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.TRIGGER_UPDATED,
|
||||
stream_id="queen",
|
||||
data={
|
||||
"trigger_id": trigger_id,
|
||||
"task": tdef.task,
|
||||
"trigger_config": tdef.trigger_config,
|
||||
"trigger_type": tdef.trigger_type,
|
||||
"name": tdef.description or trigger_id,
|
||||
"entry_node": getattr(
|
||||
getattr(getattr(session, "runner", None), "graph", None),
|
||||
"entry_node",
|
||||
None,
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"trigger_id": trigger_id,
|
||||
"task": tdef.task,
|
||||
"trigger_config": tdef.trigger_config,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -754,60 +867,6 @@ async def handle_messages(request: web.Request) -> web.Response:
|
||||
return web.json_response({"messages": all_messages})
|
||||
|
||||
|
||||
async def handle_queen_messages(request: web.Request) -> web.Response:
|
||||
"""GET /api/sessions/{session_id}/queen-messages — get queen conversation.
|
||||
|
||||
Reads directly from disk so it works for both live sessions and cold
|
||||
(post-server-restart) sessions — no live session required.
|
||||
"""
|
||||
session_id = request.match_info["session_id"]
|
||||
|
||||
queen_dir = Path.home() / ".hive" / "queen" / "session" / session_id
|
||||
convs_dir = queen_dir / "conversations"
|
||||
if not convs_dir.exists():
|
||||
return web.json_response({"messages": [], "session_id": session_id})
|
||||
|
||||
all_messages: list[dict] = []
|
||||
|
||||
def _read_parts(parts_dir: Path, node_id: str) -> None:
|
||||
if not parts_dir.exists():
|
||||
return
|
||||
for part_file in sorted(parts_dir.iterdir()):
|
||||
if part_file.suffix != ".json":
|
||||
continue
|
||||
try:
|
||||
part = json.loads(part_file.read_text(encoding="utf-8"))
|
||||
part["_node_id"] = node_id
|
||||
# Use file mtime as created_at so frontend can order
|
||||
# queen and worker messages chronologically.
|
||||
part.setdefault("created_at", part_file.stat().st_mtime)
|
||||
all_messages.append(part)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
continue
|
||||
|
||||
# Flat layout: conversations/parts/*.json
|
||||
_read_parts(convs_dir / "parts", "queen")
|
||||
|
||||
# Node-based layout: conversations/<node_id>/parts/*.json
|
||||
for node_dir in convs_dir.iterdir():
|
||||
if not node_dir.is_dir() or node_dir.name == "parts":
|
||||
continue
|
||||
_read_parts(node_dir / "parts", node_dir.name)
|
||||
|
||||
all_messages.sort(key=lambda m: m.get("created_at", m.get("seq", 0)))
|
||||
|
||||
# Filter to client-facing messages only
|
||||
all_messages = [
|
||||
m
|
||||
for m in all_messages
|
||||
if not m.get("is_transition_marker")
|
||||
and m["role"] != "tool"
|
||||
and not (m["role"] == "assistant" and m.get("tool_calls"))
|
||||
]
|
||||
|
||||
return web.json_response({"messages": all_messages, "session_id": session_id})
|
||||
|
||||
|
||||
async def handle_session_events_history(request: web.Request) -> web.Response:
|
||||
"""GET /api/sessions/{session_id}/events/history — persisted eventbus log.
|
||||
|
||||
@@ -925,6 +984,29 @@ async def handle_discover(request: web.Request) -> web.Response:
|
||||
return web.json_response(result)
|
||||
|
||||
|
||||
async def handle_reveal_session_folder(request: web.Request) -> web.Response:
|
||||
"""POST /api/sessions/{session_id}/reveal — open session data folder in the OS file manager."""
|
||||
manager: SessionManager = request.app["manager"]
|
||||
session_id = request.match_info["session_id"]
|
||||
|
||||
session = manager.get_session(session_id)
|
||||
storage_session_id = (session.queen_resume_from or session.id) if session else session_id
|
||||
folder = Path.home() / ".hive" / "queen" / "session" / storage_session_id
|
||||
folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
if sys.platform == "darwin":
|
||||
subprocess.Popen(["open", str(folder)])
|
||||
elif sys.platform == "win32":
|
||||
subprocess.Popen(["explorer", str(folder)])
|
||||
else:
|
||||
subprocess.Popen(["xdg-open", str(folder)])
|
||||
except Exception as exc:
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
return web.json_response({"path": str(folder)})
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Route registration
|
||||
# ------------------------------------------------------------------
|
||||
@@ -949,13 +1031,14 @@ def register_routes(app: web.Application) -> None:
|
||||
app.router.add_delete("/api/sessions/{session_id}/worker", handle_unload_worker)
|
||||
|
||||
# Session info
|
||||
app.router.add_post("/api/sessions/{session_id}/reveal", handle_reveal_session_folder)
|
||||
app.router.add_get("/api/sessions/{session_id}/stats", handle_session_stats)
|
||||
app.router.add_get("/api/sessions/{session_id}/entry-points", handle_session_entry_points)
|
||||
app.router.add_patch(
|
||||
"/api/sessions/{session_id}/triggers/{trigger_id}", handle_update_trigger_task
|
||||
)
|
||||
app.router.add_get("/api/sessions/{session_id}/graphs", handle_session_graphs)
|
||||
app.router.add_get("/api/sessions/{session_id}/queen-messages", handle_queen_messages)
|
||||
|
||||
app.router.add_get("/api/sessions/{session_id}/events/history", handle_session_events_history)
|
||||
|
||||
# Worker session browsing (session-primary)
|
||||
|
||||
@@ -47,6 +47,8 @@ class Session:
|
||||
worker_handoff_sub: str | None = None
|
||||
# Memory consolidation subscription (fires on CONTEXT_COMPACTED)
|
||||
memory_consolidation_sub: str | None = None
|
||||
# Worker run digest subscription (fires on EXECUTION_COMPLETED / EXECUTION_FAILED)
|
||||
worker_digest_sub: str | None = None
|
||||
# Trigger definitions loaded from agent's triggers.json (available but inactive)
|
||||
available_triggers: dict[str, TriggerDefinition] = field(default_factory=dict)
|
||||
# Active trigger tracking (IDs currently firing + their asyncio tasks)
|
||||
@@ -94,8 +96,7 @@ class SessionManager:
|
||||
|
||||
Internal helper — use create_session() or create_session_with_worker().
|
||||
"""
|
||||
from framework.config import RuntimeConfig
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
from framework.config import RuntimeConfig, get_hive_config
|
||||
from framework.runtime.event_bus import EventBus
|
||||
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
@@ -109,12 +110,20 @@ class SessionManager:
|
||||
rc = RuntimeConfig(model=model or self._model or RuntimeConfig().model)
|
||||
|
||||
# Session owns these — shared with queen and worker
|
||||
llm = LiteLLMProvider(
|
||||
model=rc.model,
|
||||
api_key=rc.api_key,
|
||||
api_base=rc.api_base,
|
||||
**rc.extra_kwargs,
|
||||
)
|
||||
llm_config = get_hive_config().get("llm", {})
|
||||
if llm_config.get("use_antigravity_subscription"):
|
||||
from framework.llm.antigravity import AntigravityProvider
|
||||
|
||||
llm = AntigravityProvider(model=rc.model)
|
||||
else:
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
|
||||
llm = LiteLLMProvider(
|
||||
model=rc.model,
|
||||
api_key=rc.api_key,
|
||||
api_base=rc.api_base,
|
||||
**rc.extra_kwargs,
|
||||
)
|
||||
event_bus = EventBus()
|
||||
|
||||
session = Session(
|
||||
@@ -177,6 +186,31 @@ class SessionManager:
|
||||
agent_path = Path(agent_path)
|
||||
resolved_worker_id = agent_id or agent_path.name
|
||||
|
||||
# When cold-restoring, check meta.json for the phase — if the agent
|
||||
# was still being built we must NOT try to load the worker (the code
|
||||
# is incomplete and will fail to import).
|
||||
if queen_resume_from:
|
||||
_resume_phase = None
|
||||
_meta_path = (
|
||||
Path.home() / ".hive" / "queen" / "session" / queen_resume_from / "meta.json"
|
||||
)
|
||||
if _meta_path.exists():
|
||||
try:
|
||||
_meta = json.loads(_meta_path.read_text(encoding="utf-8"))
|
||||
_resume_phase = _meta.get("phase")
|
||||
except (json.JSONDecodeError, OSError):
|
||||
pass
|
||||
if _resume_phase in ("building", "planning"):
|
||||
# Fall back to queen-only session — cold resume handler in
|
||||
# _start_queen will set phase_state.agent_path and switch to
|
||||
# the correct phase.
|
||||
return await self.create_session(
|
||||
session_id=session_id,
|
||||
model=model,
|
||||
initial_prompt=initial_prompt,
|
||||
queen_resume_from=queen_resume_from,
|
||||
)
|
||||
|
||||
# Reuse the original session ID when cold-restoring so the frontend
|
||||
# sees one continuous session instead of a new one each time.
|
||||
session = await self._create_session_core(
|
||||
@@ -193,6 +227,9 @@ class SessionManager:
|
||||
model=model,
|
||||
)
|
||||
|
||||
# Restore active triggers from persisted state (cold restore)
|
||||
await self._restore_active_triggers(session, session.id)
|
||||
|
||||
# Start queen with worker profile + lifecycle + monitoring tools
|
||||
worker_identity = (
|
||||
build_worker_profile(session.worker_runtime, agent_path=agent_path)
|
||||
@@ -204,7 +241,23 @@ class SessionManager:
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# If anything fails, tear down the session
|
||||
if queen_resume_from:
|
||||
# Cold restore: worker load failed (e.g. incomplete code from a
|
||||
# building session). Fall back to queen-only so the user can
|
||||
# continue the conversation and fix / rebuild the agent.
|
||||
logger.warning(
|
||||
"Cold restore: worker load failed for '%s', falling back to queen-only",
|
||||
agent_path,
|
||||
exc_info=True,
|
||||
)
|
||||
await self.stop_session(session.id)
|
||||
return await self.create_session(
|
||||
session_id=session_id,
|
||||
model=model,
|
||||
initial_prompt=initial_prompt,
|
||||
queen_resume_from=queen_resume_from,
|
||||
)
|
||||
# If anything fails (non-cold-restore), tear down the session
|
||||
await self.stop_session(session.id)
|
||||
raise
|
||||
return session
|
||||
@@ -241,7 +294,17 @@ class SessionManager:
|
||||
try:
|
||||
# Blocking I/O — load in executor
|
||||
loop = asyncio.get_running_loop()
|
||||
resolved_model = model or self._model
|
||||
|
||||
# Prioritize: explicit model arg > worker-specific model > session default
|
||||
from framework.config import (
|
||||
get_preferred_worker_model,
|
||||
get_worker_api_base,
|
||||
get_worker_api_key,
|
||||
get_worker_llm_extra_kwargs,
|
||||
)
|
||||
|
||||
worker_model = get_preferred_worker_model()
|
||||
resolved_model = model or worker_model or self._model
|
||||
runner = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: AgentRunner.load(
|
||||
@@ -253,6 +316,30 @@ class SessionManager:
|
||||
),
|
||||
)
|
||||
|
||||
# If a worker-specific model is configured, build an LLM provider
|
||||
# with the correct worker credentials so _setup() doesn't fall back
|
||||
# to the queen's llm config (which may be a different provider).
|
||||
if worker_model and not model:
|
||||
from framework.config import get_hive_config
|
||||
|
||||
worker_llm_cfg = get_hive_config().get("worker_llm", {})
|
||||
if worker_llm_cfg.get("use_antigravity_subscription"):
|
||||
from framework.llm.antigravity import AntigravityProvider
|
||||
|
||||
runner._llm = AntigravityProvider(model=resolved_model)
|
||||
else:
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
|
||||
worker_api_key = get_worker_api_key()
|
||||
worker_api_base = get_worker_api_base()
|
||||
worker_extra = get_worker_llm_extra_kwargs()
|
||||
runner._llm = LiteLLMProvider(
|
||||
model=resolved_model,
|
||||
api_key=worker_api_key,
|
||||
api_base=worker_api_base,
|
||||
**worker_extra,
|
||||
)
|
||||
|
||||
# Setup with session's event bus
|
||||
if runner._agent_runtime is None:
|
||||
await loop.run_in_executor(
|
||||
@@ -297,6 +384,9 @@ class SessionManager:
|
||||
session.worker_runtime = runtime
|
||||
session.worker_info = info
|
||||
|
||||
# Subscribe to execution completion for per-run digest generation
|
||||
self._subscribe_worker_digest(session)
|
||||
|
||||
async with self._lock:
|
||||
self._loading.discard(session.id)
|
||||
|
||||
@@ -399,6 +489,51 @@ class SessionManager:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _restore_active_triggers(self, session: "Session", session_id: str) -> None:
|
||||
"""Restore previously active triggers from persisted session state.
|
||||
|
||||
Called after worker loading to restart any timer/webhook triggers
|
||||
that were active before a server restart.
|
||||
"""
|
||||
if not session.available_triggers or not session.worker_runtime:
|
||||
return
|
||||
try:
|
||||
store = session.worker_runtime._session_store
|
||||
state = await store.read_state(session_id)
|
||||
if state and state.active_triggers:
|
||||
from framework.tools.queen_lifecycle_tools import (
|
||||
_start_trigger_timer,
|
||||
_start_trigger_webhook,
|
||||
)
|
||||
|
||||
saved_tasks = getattr(state, "trigger_tasks", {}) or {}
|
||||
for tid in state.active_triggers:
|
||||
tdef = session.available_triggers.get(tid)
|
||||
if tdef:
|
||||
# Restore user-configured task override
|
||||
saved_task = saved_tasks.get(tid, "")
|
||||
if saved_task:
|
||||
tdef.task = saved_task
|
||||
tdef.active = True
|
||||
session.active_trigger_ids.add(tid)
|
||||
if tdef.trigger_type == "timer":
|
||||
await _start_trigger_timer(session, tid, tdef)
|
||||
logger.info("Restored trigger timer '%s'", tid)
|
||||
elif tdef.trigger_type == "webhook":
|
||||
await _start_trigger_webhook(session, tid, tdef)
|
||||
logger.info("Restored webhook trigger '%s'", tid)
|
||||
else:
|
||||
logger.warning(
|
||||
"Saved trigger '%s' not found in worker entry points, skipping",
|
||||
tid,
|
||||
)
|
||||
|
||||
# Restore worker_configured flag
|
||||
if state and getattr(state, "worker_configured", False):
|
||||
session.worker_configured = True
|
||||
except Exception as e:
|
||||
logger.warning("Failed to restore active triggers: %s", e)
|
||||
|
||||
async def load_worker(
|
||||
self,
|
||||
session_id: str,
|
||||
@@ -447,44 +582,7 @@ class SessionManager:
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Restore previously active triggers from persisted session state
|
||||
if session.available_triggers and session.worker_runtime:
|
||||
try:
|
||||
store = session.worker_runtime._session_store
|
||||
state = await store.read_state(session_id)
|
||||
if state and state.active_triggers:
|
||||
from framework.tools.queen_lifecycle_tools import (
|
||||
_start_trigger_timer,
|
||||
_start_trigger_webhook,
|
||||
)
|
||||
|
||||
saved_tasks = getattr(state, "trigger_tasks", {}) or {}
|
||||
for tid in state.active_triggers:
|
||||
tdef = session.available_triggers.get(tid)
|
||||
if tdef:
|
||||
# Restore user-configured task override
|
||||
saved_task = saved_tasks.get(tid, "")
|
||||
if saved_task:
|
||||
tdef.task = saved_task
|
||||
tdef.active = True
|
||||
session.active_trigger_ids.add(tid)
|
||||
if tdef.trigger_type == "timer":
|
||||
await _start_trigger_timer(session, tid, tdef)
|
||||
logger.info("Restored trigger timer '%s'", tid)
|
||||
elif tdef.trigger_type == "webhook":
|
||||
await _start_trigger_webhook(session, tid, tdef)
|
||||
logger.info("Restored webhook trigger '%s'", tid)
|
||||
else:
|
||||
logger.warning(
|
||||
"Saved trigger '%s' not found in worker entry points, skipping",
|
||||
tid,
|
||||
)
|
||||
|
||||
# Restore worker_configured flag
|
||||
if state and getattr(state, "worker_configured", False):
|
||||
session.worker_configured = True
|
||||
except Exception as e:
|
||||
logger.warning("Failed to restore active triggers: %s", e)
|
||||
await self._restore_active_triggers(session, session_id)
|
||||
|
||||
# Emit SSE event so the frontend can update UI
|
||||
await self._emit_worker_loaded(session)
|
||||
@@ -526,6 +624,13 @@ class SessionManager:
|
||||
await self._emit_trigger_events(session, "removed", session.available_triggers)
|
||||
session.available_triggers.clear()
|
||||
|
||||
if session.worker_digest_sub is not None:
|
||||
try:
|
||||
session.event_bus.unsubscribe(session.worker_digest_sub)
|
||||
except Exception:
|
||||
pass
|
||||
session.worker_digest_sub = None
|
||||
|
||||
worker_id = session.worker_id
|
||||
session.worker_id = None
|
||||
session.worker_path = None
|
||||
@@ -563,6 +668,13 @@ class SessionManager:
|
||||
pass
|
||||
session.worker_handoff_sub = None
|
||||
|
||||
if session.worker_digest_sub is not None:
|
||||
try:
|
||||
session.event_bus.unsubscribe(session.worker_digest_sub)
|
||||
except Exception:
|
||||
pass
|
||||
session.worker_digest_sub = None
|
||||
|
||||
# Stop queen and memory consolidation subscription
|
||||
if session.memory_consolidation_sub is not None:
|
||||
try:
|
||||
@@ -647,6 +759,135 @@ class SessionManager:
|
||||
else:
|
||||
logger.warning("Worker handoff received but queen node not ready")
|
||||
|
||||
def _subscribe_worker_digest(self, session: Session) -> None:
|
||||
"""Subscribe to worker events to write per-run digests.
|
||||
|
||||
Three triggers:
|
||||
- NODE_LOOP_ITERATION: write a mid-run snapshot, throttled to at most
|
||||
once every _DIGEST_COOLDOWN seconds per execution.
|
||||
- TOOL_CALL_COMPLETED for delegate_to_sub_agent: same throttled snapshot.
|
||||
Orchestrator nodes often run all subagent calls in a single LLM turn,
|
||||
so NODE_LOOP_ITERATION only fires once at the end. Subagent
|
||||
completions provide intermediate checkpoints.
|
||||
- EXECUTION_COMPLETED / EXECUTION_FAILED: always write the final digest,
|
||||
bypassing the cooldown.
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
from framework.runtime.event_bus import EventType as _ET
|
||||
|
||||
_DIGEST_COOLDOWN = 300.0 # seconds between mid-run snapshots
|
||||
|
||||
if session.worker_digest_sub is not None:
|
||||
try:
|
||||
session.event_bus.unsubscribe(session.worker_digest_sub)
|
||||
except Exception:
|
||||
pass
|
||||
session.worker_digest_sub = None
|
||||
|
||||
agent_name = session.worker_path.name if session.worker_path else None
|
||||
if not agent_name:
|
||||
return
|
||||
|
||||
_agent_name = agent_name
|
||||
_llm = session.llm
|
||||
_bus = session.event_bus
|
||||
# per-execution_id monotonic timestamp of last mid-run digest
|
||||
_last_digest: dict[str, float] = {}
|
||||
|
||||
def _resolve_run_id(exec_id: str) -> str | None:
|
||||
"""Look up the run_id for a given execution_id via EXECUTION_STARTED history."""
|
||||
for e in _bus.get_history(event_type=_ET.EXECUTION_STARTED, limit=200):
|
||||
if e.execution_id == exec_id and getattr(e, "run_id", None):
|
||||
return e.run_id
|
||||
return None
|
||||
|
||||
async def _inject_digest_to_queen(run_id: str) -> None:
|
||||
"""Read the written digest and push it into the queen's conversation."""
|
||||
from framework.agents.worker_memory import digest_path
|
||||
|
||||
try:
|
||||
content = digest_path(_agent_name, run_id).read_text(encoding="utf-8").strip()
|
||||
except OSError:
|
||||
return
|
||||
if not content:
|
||||
return
|
||||
executor = session.queen_executor
|
||||
if executor is None:
|
||||
return
|
||||
node = executor.node_registry.get("queen")
|
||||
if node is None or not hasattr(node, "inject_event"):
|
||||
return
|
||||
await node.inject_event(f"[WORKER_DIGEST]\n{content}")
|
||||
|
||||
async def _consolidate_and_notify(run_id: str, outcome_event: Any) -> None:
|
||||
"""Write the digest then push it to the queen."""
|
||||
from framework.agents.worker_memory import consolidate_worker_run
|
||||
|
||||
await consolidate_worker_run(_agent_name, run_id, outcome_event, _bus, _llm)
|
||||
await _inject_digest_to_queen(run_id)
|
||||
|
||||
async def _on_worker_event(event: Any) -> None:
|
||||
if event.stream_id == "queen":
|
||||
return
|
||||
|
||||
exec_id = event.execution_id
|
||||
|
||||
if event.type == _ET.EXECUTION_STARTED:
|
||||
# New run on this execution_id — start the cooldown timer so
|
||||
# mid-run snapshots don't fire immediately at session start.
|
||||
# The first snapshot will happen after _DIGEST_COOLDOWN seconds.
|
||||
if exec_id:
|
||||
_last_digest[exec_id] = _time.monotonic()
|
||||
|
||||
elif event.type in (
|
||||
_ET.EXECUTION_COMPLETED,
|
||||
_ET.EXECUTION_FAILED,
|
||||
_ET.EXECUTION_PAUSED,
|
||||
):
|
||||
# Final digest — always fire, ignore cooldown.
|
||||
# EXECUTION_PAUSED covers cancellation (queen re-triggering the
|
||||
# worker cancels the previous execution, emitting paused).
|
||||
run_id = getattr(event, "run_id", None) or _resolve_run_id(exec_id)
|
||||
if run_id:
|
||||
asyncio.create_task(
|
||||
_consolidate_and_notify(run_id, event),
|
||||
name=f"worker-digest-final-{run_id}",
|
||||
)
|
||||
|
||||
elif event.type in (_ET.NODE_LOOP_ITERATION, _ET.TOOL_CALL_COMPLETED):
|
||||
# Mid-run snapshot — respect 300 s cooldown per execution.
|
||||
# TOOL_CALL_COMPLETED is only interesting for subagent calls;
|
||||
# regular tool completions are too frequent and too cheap.
|
||||
if event.type == _ET.TOOL_CALL_COMPLETED:
|
||||
tool_name = (event.data or {}).get("tool_name", "")
|
||||
if tool_name != "delegate_to_sub_agent":
|
||||
return
|
||||
if not exec_id:
|
||||
return
|
||||
now = _time.monotonic()
|
||||
if now - _last_digest.get(exec_id, 0.0) < _DIGEST_COOLDOWN:
|
||||
return
|
||||
run_id = _resolve_run_id(exec_id)
|
||||
if run_id:
|
||||
_last_digest[exec_id] = now
|
||||
asyncio.create_task(
|
||||
_consolidate_and_notify(run_id, None),
|
||||
name=f"worker-digest-{run_id}",
|
||||
)
|
||||
|
||||
session.worker_digest_sub = session.event_bus.subscribe(
|
||||
event_types=[
|
||||
_ET.EXECUTION_STARTED,
|
||||
_ET.NODE_LOOP_ITERATION,
|
||||
_ET.TOOL_CALL_COMPLETED,
|
||||
_ET.EXECUTION_COMPLETED,
|
||||
_ET.EXECUTION_FAILED,
|
||||
_ET.EXECUTION_PAUSED,
|
||||
],
|
||||
handler=_on_worker_event,
|
||||
)
|
||||
|
||||
def _subscribe_worker_handoffs(self, session: Session, executor: Any) -> None:
|
||||
"""Subscribe queen to worker/subagent escalation handoff events."""
|
||||
from framework.runtime.event_bus import EventType as _ET
|
||||
@@ -700,16 +941,21 @@ class SessionManager:
|
||||
else None
|
||||
)
|
||||
)
|
||||
_meta_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"agent_name": _agent_name,
|
||||
"agent_path": str(session.worker_path) if session.worker_path else None,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
# Merge into existing meta.json to preserve fields written by
|
||||
# _update_meta_json (e.g. phase, agent_path set during building).
|
||||
_existing_meta: dict = {}
|
||||
if _meta_path.exists():
|
||||
try:
|
||||
_existing_meta = json.loads(_meta_path.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, OSError):
|
||||
pass
|
||||
_new_meta: dict = {"created_at": time.time()}
|
||||
if _agent_name is not None:
|
||||
_new_meta["agent_name"] = _agent_name
|
||||
if session.worker_path is not None:
|
||||
_new_meta["agent_path"] = str(session.worker_path)
|
||||
_existing_meta.update(_new_meta)
|
||||
_meta_path.write_text(json.dumps(_existing_meta), encoding="utf-8")
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@@ -719,6 +965,7 @@ class SessionManager:
|
||||
# then use max+1 as offset so resumed sessions produce monotonically
|
||||
# increasing iteration values — preventing frontend message ID collisions.
|
||||
iteration_offset = 0
|
||||
last_phase = ""
|
||||
events_path = queen_dir / "events.jsonl"
|
||||
try:
|
||||
if events_path.exists():
|
||||
@@ -730,17 +977,25 @@ class SessionManager:
|
||||
continue
|
||||
try:
|
||||
evt = json.loads(line)
|
||||
it = evt.get("data", {}).get("iteration")
|
||||
data = evt.get("data", {})
|
||||
it = data.get("iteration")
|
||||
if isinstance(it, int) and it > max_iter:
|
||||
max_iter = it
|
||||
# Track the latest queen phase from QUEEN_PHASE_CHANGED events
|
||||
if evt.get("type") == "queen_phase_changed":
|
||||
phase = data.get("phase")
|
||||
if phase:
|
||||
last_phase = phase
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
if max_iter >= 0:
|
||||
iteration_offset = max_iter + 1
|
||||
logger.info(
|
||||
"Session '%s' resuming with iteration_offset=%d (from events.jsonl max)",
|
||||
"Session '%s' resuming with iteration_offset=%d"
|
||||
" (from events.jsonl max), last phase: %s",
|
||||
session.id,
|
||||
iteration_offset,
|
||||
last_phase or "unknown",
|
||||
)
|
||||
except OSError:
|
||||
pass
|
||||
@@ -762,11 +1017,27 @@ class SessionManager:
|
||||
try:
|
||||
_meta = json.loads(meta_path.read_text(encoding="utf-8"))
|
||||
_agent_path = _meta.get("agent_path")
|
||||
_phase = _meta.get("phase")
|
||||
|
||||
if _agent_path and Path(_agent_path).exists():
|
||||
await self.load_worker(session.id, _agent_path)
|
||||
if session.phase_state:
|
||||
await session.phase_state.switch_to_staging(source="auto")
|
||||
logger.info("Cold restore: auto-loaded worker from %s", _agent_path)
|
||||
if _phase in ("staging", "running", None):
|
||||
# Agent fully built — load worker and resume
|
||||
await self.load_worker(session.id, _agent_path)
|
||||
if session.phase_state:
|
||||
await session.phase_state.switch_to_staging(source="auto")
|
||||
# Emit flowchart overlay so frontend can display it
|
||||
await self._emit_flowchart_on_restore(session, _agent_path)
|
||||
logger.info("Cold restore: auto-loaded worker from %s", _agent_path)
|
||||
elif _phase == "building":
|
||||
# Agent folder exists but incomplete — resume building
|
||||
if session.phase_state:
|
||||
session.phase_state.agent_path = _agent_path
|
||||
await session.phase_state.switch_to_building(source="auto")
|
||||
logger.info("Cold restore: resumed BUILDING phase for %s", _agent_path)
|
||||
elif _phase == "planning":
|
||||
if session.phase_state:
|
||||
session.phase_state.agent_path = _agent_path
|
||||
logger.info("Cold restore: PLANNING phase for %s", _agent_path)
|
||||
except Exception:
|
||||
logger.warning("Cold restore: failed to auto-load worker", exc_info=True)
|
||||
|
||||
@@ -776,10 +1047,17 @@ class SessionManager:
|
||||
_consolidation_session_dir = queen_dir
|
||||
|
||||
async def _on_compaction(_event) -> None:
|
||||
# Only consolidate on queen compactions — worker and subagent
|
||||
# compactions are frequent and don't warrant a memory update.
|
||||
if getattr(_event, "stream_id", None) != "queen":
|
||||
return
|
||||
from framework.agents.queen.queen_memory import consolidate_queen_memory
|
||||
|
||||
await consolidate_queen_memory(
|
||||
session.id, _consolidation_session_dir, _consolidation_llm
|
||||
asyncio.create_task(
|
||||
consolidate_queen_memory(
|
||||
session.id, _consolidation_session_dir, _consolidation_llm
|
||||
),
|
||||
name=f"queen-memory-consolidation-{session.id}",
|
||||
)
|
||||
|
||||
from framework.runtime.event_bus import EventType as _ET
|
||||
@@ -841,6 +1119,29 @@ class SessionManager:
|
||||
)
|
||||
)
|
||||
|
||||
async def _emit_flowchart_on_restore(self, session: Session, agent_path: str | Path) -> None:
|
||||
"""Emit FLOWCHART_MAP_UPDATED from persisted flowchart file on cold restore."""
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
from framework.tools.flowchart_utils import load_flowchart_file
|
||||
|
||||
original_draft, flowchart_map = load_flowchart_file(agent_path)
|
||||
if original_draft is None:
|
||||
return
|
||||
# Cache in phase_state so the REST endpoint also returns it
|
||||
if session.phase_state:
|
||||
session.phase_state.original_draft_graph = original_draft
|
||||
session.phase_state.flowchart_map = flowchart_map
|
||||
await session.event_bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.FLOWCHART_MAP_UPDATED,
|
||||
stream_id="queen",
|
||||
data={
|
||||
"map": flowchart_map,
|
||||
"original_draft": original_draft,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def _notify_queen_worker_unloaded(self, session: Session) -> None:
|
||||
"""Notify the queen that the worker has been unloaded."""
|
||||
executor = session.queen_executor
|
||||
@@ -868,6 +1169,10 @@ class SessionManager:
|
||||
event_type = (
|
||||
EventType.TRIGGER_AVAILABLE if kind == "available" else EventType.TRIGGER_REMOVED
|
||||
)
|
||||
# Resolve graph entry node for trigger target
|
||||
runner = getattr(session, "runner", None)
|
||||
graph_entry = runner.graph.entry_node if runner else None
|
||||
|
||||
for t in triggers.values():
|
||||
await session.event_bus.publish(
|
||||
AgentEvent(
|
||||
@@ -877,6 +1182,8 @@ class SessionManager:
|
||||
"trigger_id": t.id,
|
||||
"trigger_type": t.trigger_type,
|
||||
"trigger_config": t.trigger_config,
|
||||
"name": t.description or t.id,
|
||||
**({"entry_node": graph_entry} if graph_entry else {}),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ Uses aiohttp TestClient with mocked sessions to test all endpoints
|
||||
without requiring actual LLM calls or agent loading.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
@@ -13,6 +14,7 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from framework.runtime.triggers import TriggerDefinition
|
||||
from framework.server.app import create_app
|
||||
from framework.server.session_manager import Session
|
||||
|
||||
@@ -172,6 +174,7 @@ def _make_session(
|
||||
runner.intro_message = "Test intro"
|
||||
|
||||
mock_event_bus = MagicMock()
|
||||
mock_event_bus.publish = AsyncMock()
|
||||
mock_llm = MagicMock()
|
||||
|
||||
queen_executor = _make_queen_executor() if with_queen else None
|
||||
@@ -484,6 +487,70 @@ class TestSessionCRUD:
|
||||
data = await resp.json()
|
||||
assert "primary" in data["graphs"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_trigger_task(self, tmp_path):
|
||||
session = _make_session(tmp_dir=tmp_path)
|
||||
session.available_triggers["daily"] = TriggerDefinition(
|
||||
id="daily",
|
||||
trigger_type="timer",
|
||||
trigger_config={"cron": "0 5 * * *"},
|
||||
task="Old task",
|
||||
)
|
||||
app = _make_app_with_session(session)
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.patch(
|
||||
"/api/sessions/test_agent/triggers/daily",
|
||||
json={"task": "New task"},
|
||||
)
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["task"] == "New task"
|
||||
assert data["trigger_config"]["cron"] == "0 5 * * *"
|
||||
assert session.available_triggers["daily"].task == "New task"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_trigger_cron_restarts_active_timer(self, tmp_path):
|
||||
session = _make_session(tmp_dir=tmp_path)
|
||||
session.available_triggers["daily"] = TriggerDefinition(
|
||||
id="daily",
|
||||
trigger_type="timer",
|
||||
trigger_config={"cron": "0 5 * * *"},
|
||||
task="Run task",
|
||||
active=True,
|
||||
)
|
||||
session.active_trigger_ids.add("daily")
|
||||
session.active_timer_tasks["daily"] = asyncio.create_task(asyncio.sleep(60))
|
||||
app = _make_app_with_session(session)
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.patch(
|
||||
"/api/sessions/test_agent/triggers/daily",
|
||||
json={"trigger_config": {"cron": "0 6 * * *"}},
|
||||
)
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["trigger_config"]["cron"] == "0 6 * * *"
|
||||
assert "daily" in session.active_timer_tasks
|
||||
assert session.active_timer_tasks["daily"] is not None
|
||||
assert session.available_triggers["daily"].trigger_config["cron"] == "0 6 * * *"
|
||||
session.active_timer_tasks["daily"].cancel()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_trigger_cron_rejects_invalid_expression(self, tmp_path):
|
||||
session = _make_session(tmp_dir=tmp_path)
|
||||
session.available_triggers["daily"] = TriggerDefinition(
|
||||
id="daily",
|
||||
trigger_type="timer",
|
||||
trigger_config={"cron": "0 5 * * *"},
|
||||
task="Run task",
|
||||
)
|
||||
app = _make_app_with_session(session)
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.patch(
|
||||
"/api/sessions/test_agent/triggers/daily",
|
||||
json={"trigger_config": {"cron": "not a cron"}},
|
||||
)
|
||||
assert resp.status == 400
|
||||
|
||||
|
||||
class TestExecution:
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Hive Agent Skills — discovery, parsing, and injection of SKILL.md packages.
|
||||
"""Hive Agent Skills — discovery, parsing, trust gating, and injection of SKILL.md packages.
|
||||
|
||||
Implements the open Agent Skills standard (agentskills.io) for portable
|
||||
skill discovery and activation, plus built-in default skills for runtime
|
||||
operational discipline.
|
||||
operational discipline, and AS-13 trust gating for project-scope skills.
|
||||
"""
|
||||
|
||||
from framework.skills.catalog import SkillCatalog
|
||||
@@ -10,7 +10,10 @@ from framework.skills.config import DefaultSkillConfig, SkillsConfig
|
||||
from framework.skills.defaults import DefaultSkillManager
|
||||
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
|
||||
from framework.skills.manager import SkillsManager, SkillsManagerConfig
|
||||
from framework.skills.models import TrustStatus
|
||||
from framework.skills.parser import ParsedSkill, parse_skill_md
|
||||
from framework.skills.skill_errors import SkillError, SkillErrorCode, log_skill_error
|
||||
from framework.skills.trust import TrustedRepoStore, TrustGate
|
||||
|
||||
__all__ = [
|
||||
"DefaultSkillConfig",
|
||||
@@ -22,5 +25,11 @@ __all__ = [
|
||||
"SkillsConfig",
|
||||
"SkillsManager",
|
||||
"SkillsManagerConfig",
|
||||
"TrustGate",
|
||||
"TrustedRepoStore",
|
||||
"TrustStatus",
|
||||
"parse_skill_md",
|
||||
"SkillError",
|
||||
"SkillErrorCode",
|
||||
"log_skill_error",
|
||||
]
|
||||
|
||||
@@ -10,6 +10,7 @@ import logging
|
||||
from xml.sax.saxutils import escape
|
||||
|
||||
from framework.skills.parser import ParsedSkill
|
||||
from framework.skills.skill_errors import SkillErrorCode, log_skill_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -76,6 +77,7 @@ class SkillCatalog:
|
||||
lines.append(f" <name>{escape(skill.name)}</name>")
|
||||
lines.append(f" <description>{escape(skill.description)}</description>")
|
||||
lines.append(f" <location>{escape(skill.location)}</location>")
|
||||
lines.append(f" <base_dir>{escape(skill.base_dir)}</base_dir>")
|
||||
lines.append(" </skill>")
|
||||
lines.append("</available_skills>")
|
||||
|
||||
@@ -96,7 +98,14 @@ class SkillCatalog:
|
||||
for name in skill_names:
|
||||
skill = self.get(name)
|
||||
if skill is None:
|
||||
logger.warning("Pre-activated skill '%s' not found in catalog", name)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"warning",
|
||||
SkillErrorCode.SKILL_NOT_FOUND,
|
||||
what=f"Pre-activated skill '{name}' not found in catalog",
|
||||
why="The skill was listed for pre-activation but was not discovered.",
|
||||
fix=f"Check that a SKILL.md for '{name}' exists in a scanned directory.",
|
||||
)
|
||||
continue
|
||||
if self.is_activated(name):
|
||||
continue # Already activated, skip duplicate
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
"""CLI commands for the Hive skill system.
|
||||
|
||||
Phase 1 commands (AS-13):
|
||||
hive skill list — list discovered skills across all scopes
|
||||
hive skill trust <path> — permanently trust a project repo's skills
|
||||
|
||||
Full CLI suite (CLI-1 through CLI-13) is Phase 2.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def register_skill_commands(subparsers) -> None:
|
||||
"""Register the ``hive skill`` subcommand group."""
|
||||
skill_parser = subparsers.add_parser("skill", help="Manage skills")
|
||||
skill_sub = skill_parser.add_subparsers(dest="skill_command", required=True)
|
||||
|
||||
# hive skill list
|
||||
list_parser = skill_sub.add_parser("list", help="List discovered skills across all scopes")
|
||||
list_parser.add_argument(
|
||||
"--project-dir",
|
||||
default=None,
|
||||
metavar="PATH",
|
||||
help="Project directory to scan (default: current directory)",
|
||||
)
|
||||
list_parser.set_defaults(func=cmd_skill_list)
|
||||
|
||||
# hive skill trust
|
||||
trust_parser = skill_sub.add_parser(
|
||||
"trust",
|
||||
help="Permanently trust a project repository so its skills load without prompting",
|
||||
)
|
||||
trust_parser.add_argument(
|
||||
"project_path",
|
||||
help="Path to the project directory (must contain a .git with a remote origin)",
|
||||
)
|
||||
trust_parser.set_defaults(func=cmd_skill_trust)
|
||||
|
||||
|
||||
def cmd_skill_list(args) -> int:
|
||||
"""List all discovered skills grouped by scope."""
|
||||
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
|
||||
|
||||
project_dir = Path(args.project_dir).resolve() if args.project_dir else Path.cwd()
|
||||
skills = SkillDiscovery(DiscoveryConfig(project_root=project_dir)).discover()
|
||||
|
||||
if not skills:
|
||||
print("No skills discovered.")
|
||||
return 0
|
||||
|
||||
scope_headers = {
|
||||
"project": "PROJECT SKILLS",
|
||||
"user": "USER SKILLS",
|
||||
"framework": "FRAMEWORK SKILLS",
|
||||
}
|
||||
|
||||
for scope in ("project", "user", "framework"):
|
||||
scope_skills = [s for s in skills if s.source_scope == scope]
|
||||
if not scope_skills:
|
||||
continue
|
||||
print(f"\n{scope_headers[scope]}")
|
||||
print("─" * 40)
|
||||
for skill in scope_skills:
|
||||
print(f" • {skill.name}")
|
||||
print(f" {skill.description}")
|
||||
print(f" {skill.location}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_skill_trust(args) -> int:
|
||||
"""Permanently trust a project repository's skills."""
|
||||
from framework.skills.trust import TrustedRepoStore, _normalize_remote_url
|
||||
|
||||
project_path = Path(args.project_path).resolve()
|
||||
|
||||
if not project_path.exists():
|
||||
print(f"Error: path does not exist: {project_path}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
if not (project_path / ".git").exists():
|
||||
print(
|
||||
f"Error: {project_path} is not a git repository (no .git directory).",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "-C", str(project_path), "remote", "get-url", "origin"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
print(
|
||||
"Error: no remote 'origin' configured in this repository.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
remote_url = result.stdout.strip()
|
||||
except subprocess.TimeoutExpired:
|
||||
print("Error: git remote lookup timed out.", file=sys.stderr)
|
||||
return 1
|
||||
except (FileNotFoundError, OSError) as e:
|
||||
print(f"Error reading git remote: {e}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
repo_key = _normalize_remote_url(remote_url)
|
||||
store = TrustedRepoStore()
|
||||
store.trust(repo_key, project_path=str(project_path))
|
||||
|
||||
print(f"✓ Trusted: {repo_key}")
|
||||
print(" Stored in ~/.hive/trusted_repos.json")
|
||||
print(" Skills from this repository will load without prompting in future runs.")
|
||||
return 0
|
||||
@@ -11,6 +11,7 @@ from pathlib import Path
|
||||
|
||||
from framework.skills.config import SkillsConfig
|
||||
from framework.skills.parser import ParsedSkill, parse_skill_md
|
||||
from framework.skills.skill_errors import SkillErrorCode, log_skill_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -60,12 +61,14 @@ class DefaultSkillManager:
|
||||
self._config = config or SkillsConfig()
|
||||
self._skills: dict[str, ParsedSkill] = {}
|
||||
self._loaded = False
|
||||
self._error_count = 0
|
||||
|
||||
def load(self) -> None:
|
||||
"""Load all enabled default skill SKILL.md files."""
|
||||
if self._loaded:
|
||||
return
|
||||
|
||||
error_count = 0
|
||||
for skill_name, dir_name in SKILL_REGISTRY.items():
|
||||
if not self._config.is_default_enabled(skill_name):
|
||||
logger.info("Default skill '%s' disabled by config", skill_name)
|
||||
@@ -73,17 +76,34 @@ class DefaultSkillManager:
|
||||
|
||||
skill_path = _DEFAULT_SKILLS_DIR / dir_name / "SKILL.md"
|
||||
if not skill_path.is_file():
|
||||
logger.error("Default skill SKILL.md not found: %s", skill_path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_NOT_FOUND,
|
||||
what=f"Default skill SKILL.md not found: '{skill_path}'",
|
||||
why=f"The framework skill '{skill_name}' is missing its SKILL.md file.",
|
||||
fix="Reinstall the hive framework — this file is part of the package.",
|
||||
)
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
parsed = parse_skill_md(skill_path, source_scope="framework")
|
||||
if parsed is None:
|
||||
logger.error("Failed to parse default skill: %s", skill_path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_PARSE_ERROR,
|
||||
what=f"Failed to parse default skill '{skill_name}'",
|
||||
why=f"parse_skill_md returned None for '{skill_path}'.",
|
||||
fix="Reinstall the hive framework — this file may be corrupted.",
|
||||
)
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
self._skills[skill_name] = parsed
|
||||
|
||||
self._loaded = True
|
||||
self._error_count = error_count
|
||||
|
||||
def build_protocols_prompt(self) -> str:
|
||||
"""Build the combined operational protocols section.
|
||||
@@ -127,8 +147,23 @@ class DefaultSkillManager:
|
||||
"""Log which default skills are active and their configuration."""
|
||||
if not self._skills:
|
||||
logger.info("Default skills: all disabled")
|
||||
return
|
||||
|
||||
# DX-3: Per-skill structured startup log
|
||||
for skill_name in SKILL_REGISTRY:
|
||||
if skill_name in self._skills:
|
||||
overrides = self._config.get_default_overrides(skill_name)
|
||||
status = f"loaded overrides={overrides}" if overrides else "loaded"
|
||||
elif not self._config.is_default_enabled(skill_name):
|
||||
status = "disabled"
|
||||
else:
|
||||
status = "error"
|
||||
logger.info(
|
||||
"skill_startup name=%s scope=framework status=%s",
|
||||
skill_name,
|
||||
status,
|
||||
)
|
||||
|
||||
# Original active skills log line (preserved for backward compatibility)
|
||||
active = []
|
||||
for skill_name in SKILL_REGISTRY:
|
||||
if skill_name in self._skills:
|
||||
@@ -138,7 +173,21 @@ class DefaultSkillManager:
|
||||
else:
|
||||
active.append(skill_name)
|
||||
|
||||
logger.info("Default skills active: %s", ", ".join(active))
|
||||
if active:
|
||||
logger.info("Default skills active: %s", ", ".join(active))
|
||||
|
||||
# DX-3: Summary line with error count
|
||||
total = len(SKILL_REGISTRY)
|
||||
active_count = len(self._skills)
|
||||
error_count = getattr(self, "_error_count", 0)
|
||||
disabled_count = total - active_count - error_count
|
||||
logger.info(
|
||||
"Skills: %d default (%d active, %d disabled, %d error)",
|
||||
total,
|
||||
active_count,
|
||||
disabled_count,
|
||||
error_count,
|
||||
)
|
||||
|
||||
@property
|
||||
def active_skill_names(self) -> list[str]:
|
||||
|
||||
@@ -11,6 +11,7 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from framework.skills.parser import ParsedSkill, parse_skill_md
|
||||
from framework.skills.skill_errors import SkillErrorCode, log_skill_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -172,11 +173,13 @@ class SkillDiscovery:
|
||||
for skill in skills:
|
||||
if skill.name in seen:
|
||||
existing = seen[skill.name]
|
||||
logger.warning(
|
||||
"Skill name collision: '%s' from %s overrides %s",
|
||||
skill.name,
|
||||
skill.location,
|
||||
existing.location,
|
||||
log_skill_error(
|
||||
logger,
|
||||
"warning",
|
||||
SkillErrorCode.SKILL_COLLISION,
|
||||
what=f"Skill name collision: '{skill.name}'",
|
||||
why=f"'{skill.location}' overrides '{existing.location}'.",
|
||||
fix="Rename one of the conflicting skill directories to use a unique name.",
|
||||
)
|
||||
seen[skill.name] = skill
|
||||
|
||||
|
||||
@@ -42,11 +42,14 @@ class SkillsManagerConfig:
|
||||
When ``None``, community discovery is skipped.
|
||||
skip_community_discovery: Explicitly skip community scanning
|
||||
even when ``project_root`` is set.
|
||||
interactive: Whether trust gating can prompt the user interactively.
|
||||
When ``False``, untrusted project skills are silently skipped.
|
||||
"""
|
||||
|
||||
skills_config: SkillsConfig = field(default_factory=SkillsConfig)
|
||||
project_root: Path | None = None
|
||||
skip_community_discovery: bool = False
|
||||
interactive: bool = True
|
||||
|
||||
|
||||
class SkillsManager:
|
||||
@@ -63,6 +66,7 @@ class SkillsManager:
|
||||
self._loaded = False
|
||||
self._catalog_prompt: str = ""
|
||||
self._protocols_prompt: str = ""
|
||||
self._allowlisted_dirs: list[str] = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Factory for backwards-compat bridge
|
||||
@@ -85,6 +89,7 @@ class SkillsManager:
|
||||
mgr._loaded = True # skip load()
|
||||
mgr._catalog_prompt = skills_catalog_prompt
|
||||
mgr._protocols_prompt = protocols_prompt
|
||||
mgr._allowlisted_dirs = []
|
||||
return mgr
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -113,9 +118,18 @@ class SkillsManager:
|
||||
# 1. Community skill discovery (when project_root is available)
|
||||
catalog_prompt = ""
|
||||
if self._config.project_root is not None and not self._config.skip_community_discovery:
|
||||
from framework.skills.trust import TrustGate
|
||||
|
||||
discovery = SkillDiscovery(DiscoveryConfig(project_root=self._config.project_root))
|
||||
discovered = discovery.discover()
|
||||
|
||||
# Trust-gate project-scope skills (AS-13)
|
||||
discovered = TrustGate(interactive=self._config.interactive).filter_and_gate(
|
||||
discovered, project_dir=self._config.project_root
|
||||
)
|
||||
|
||||
catalog = SkillCatalog(discovered)
|
||||
self._allowlisted_dirs = catalog.allowlisted_dirs
|
||||
catalog_prompt = catalog.to_prompt()
|
||||
|
||||
# Pre-activated community skills
|
||||
@@ -132,6 +146,16 @@ class SkillsManager:
|
||||
default_mgr.load()
|
||||
default_mgr.log_active_skills()
|
||||
protocols_prompt = default_mgr.build_protocols_prompt()
|
||||
# DX-3: Community skill startup summary
|
||||
if self._config.project_root is not None and not self._config.skip_community_discovery:
|
||||
community_count = len(catalog._skills) if catalog_prompt else 0
|
||||
pre_activated_count = len(skills_config.skills) if skills_config.skills else 0
|
||||
logger.info(
|
||||
"Skills: %d community (%d catalog, %d pre-activated)",
|
||||
community_count,
|
||||
community_count,
|
||||
pre_activated_count,
|
||||
)
|
||||
|
||||
# 3. Cache
|
||||
self._catalog_prompt = catalog_prompt
|
||||
@@ -160,6 +184,11 @@ class SkillsManager:
|
||||
"""Default skill operational protocols for system prompt injection."""
|
||||
return self._protocols_prompt
|
||||
|
||||
@property
|
||||
def allowlisted_dirs(self) -> list[str]:
|
||||
"""Skill base directories for Tier 3 resource access (AS-6)."""
|
||||
return self._allowlisted_dirs
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
return self._loaded
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Data models for the Hive skill system (Agent Skills standard)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class SkillScope(StrEnum):
|
||||
"""Where a skill was discovered."""
|
||||
|
||||
PROJECT = "project"
|
||||
USER = "user"
|
||||
FRAMEWORK = "framework"
|
||||
|
||||
|
||||
class TrustStatus(StrEnum):
|
||||
"""Trust state of a skill entry."""
|
||||
|
||||
TRUSTED = "trusted"
|
||||
PENDING_CONSENT = "pending_consent"
|
||||
DENIED = "denied"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillEntry:
|
||||
"""In-memory record for a discovered skill (PRD §4.2)."""
|
||||
|
||||
name: str
|
||||
"""Skill name from SKILL.md frontmatter."""
|
||||
|
||||
description: str
|
||||
"""Skill description from SKILL.md frontmatter."""
|
||||
|
||||
location: Path
|
||||
"""Absolute path to SKILL.md."""
|
||||
|
||||
base_dir: Path
|
||||
"""Parent directory of SKILL.md (skill root)."""
|
||||
|
||||
source_scope: SkillScope
|
||||
"""Which scope this skill was found in."""
|
||||
|
||||
trust_status: TrustStatus = TrustStatus.TRUSTED
|
||||
"""Trust state; project-scope skills start as PENDING_CONSENT before gating."""
|
||||
|
||||
# Optional frontmatter fields
|
||||
license: str | None = None
|
||||
compatibility: list[str] = field(default_factory=list)
|
||||
allowed_tools: list[str] = field(default_factory=list)
|
||||
metadata: dict = field(default_factory=dict)
|
||||
@@ -13,6 +13,8 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.skills.skill_errors import SkillErrorCode, log_skill_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum name length before a warning is logged
|
||||
@@ -74,17 +76,38 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
logger.error("Failed to read %s: %s", path, exc)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"Failed to read '{path}'",
|
||||
why=str(exc),
|
||||
fix="Check the file exists and has read permissions.",
|
||||
)
|
||||
return None
|
||||
|
||||
if not content.strip():
|
||||
logger.error("Empty SKILL.md: %s", path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_PARSE_ERROR,
|
||||
what=f"Invalid SKILL.md at '{path}'",
|
||||
why="The file exists but contains no content.",
|
||||
fix="Add valid YAML frontmatter and a markdown body to the SKILL.md.",
|
||||
)
|
||||
return None
|
||||
|
||||
# Split on --- delimiters (first two occurrences)
|
||||
parts = content.split("---", 2)
|
||||
if len(parts) < 3:
|
||||
logger.error("SKILL.md missing YAML frontmatter delimiters (---): %s", path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_PARSE_ERROR,
|
||||
what=f"Invalid SKILL.md at '{path}'",
|
||||
why="Missing YAML frontmatter (---).",
|
||||
fix="Wrap the frontmatter with --- on its own line at the top and bottom.",
|
||||
)
|
||||
return None
|
||||
|
||||
# parts[0] is content before first --- (should be empty or whitespace)
|
||||
@@ -94,7 +117,14 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
|
||||
body = parts[2].strip()
|
||||
|
||||
if not raw_yaml:
|
||||
logger.error("Empty YAML frontmatter in %s", path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_PARSE_ERROR,
|
||||
what=f"Invalid SKILL.md at '{path}'",
|
||||
why="The --- delimiters are present but the YAML block is empty.",
|
||||
fix="Add at least 'name' and 'description' fields to the frontmatter.",
|
||||
)
|
||||
return None
|
||||
|
||||
# Parse YAML
|
||||
@@ -108,19 +138,47 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
|
||||
try:
|
||||
fixed = _try_fix_yaml(raw_yaml)
|
||||
frontmatter = yaml.safe_load(fixed)
|
||||
logger.warning("Fixed YAML parse issues in %s (unquoted colons)", path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"warning",
|
||||
SkillErrorCode.SKILL_YAML_FIXUP,
|
||||
what=f"Auto-fixed YAML in '{path}'",
|
||||
why="Unquoted colon values detected in frontmatter.",
|
||||
fix='Wrap values containing colons in quotes e.g. description: "Use for: research"',
|
||||
)
|
||||
except yaml.YAMLError as exc:
|
||||
logger.error("Unparseable YAML in %s: %s", path, exc)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_PARSE_ERROR,
|
||||
what=f"Invalid SKILL.md at '{path}'",
|
||||
why=str(exc),
|
||||
fix="Validate the YAML frontmatter at https://yaml-online-parser.appspot.com/",
|
||||
)
|
||||
return None
|
||||
|
||||
if not isinstance(frontmatter, dict):
|
||||
logger.error("YAML frontmatter is not a mapping in %s", path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_PARSE_ERROR,
|
||||
what=f"Invalid SKILL.md at '{path}'",
|
||||
why="YAML frontmatter is not a key-value mapping.",
|
||||
fix="Ensure the frontmatter is valid YAML with key: value pairs.",
|
||||
)
|
||||
return None
|
||||
|
||||
# Required: description
|
||||
description = frontmatter.get("description")
|
||||
if not description or not str(description).strip():
|
||||
logger.error("Missing or empty 'description' in %s — skipping skill", path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_MISSING_DESCRIPTION,
|
||||
what=f"Missing 'description' in '{path}'",
|
||||
why="The 'description' field is required but is absent or empty.",
|
||||
fix="Add a non-empty 'description' field to the YAML frontmatter.",
|
||||
)
|
||||
return None
|
||||
|
||||
# Required: name (fallback to parent directory name)
|
||||
@@ -128,7 +186,14 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
|
||||
parent_dir_name = path.parent.name
|
||||
if not name or not str(name).strip():
|
||||
name = parent_dir_name
|
||||
logger.warning("Missing 'name' in %s — using directory name '%s'", path, name)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"warning",
|
||||
SkillErrorCode.SKILL_NAME_MISMATCH,
|
||||
what=f"Missing 'name' in '{path}' — using directory name '{name}'",
|
||||
why="The 'name' field is absent from the YAML frontmatter.",
|
||||
fix=f"Add 'name: {name}' to the frontmatter to make this explicit.",
|
||||
)
|
||||
else:
|
||||
name = str(name).strip()
|
||||
|
||||
@@ -137,11 +202,13 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
|
||||
logger.warning("Skill name exceeds %d chars in %s: '%s'", _MAX_NAME_LENGTH, path, name)
|
||||
|
||||
if name != parent_dir_name and not name.endswith(f".{parent_dir_name}"):
|
||||
logger.warning(
|
||||
"Skill name '%s' doesn't match parent directory '%s' in %s",
|
||||
name,
|
||||
parent_dir_name,
|
||||
path,
|
||||
log_skill_error(
|
||||
logger,
|
||||
"warning",
|
||||
SkillErrorCode.SKILL_NAME_MISMATCH,
|
||||
what=f"Name mismatch in '{path}'",
|
||||
why=f"Skill name '{name}' doesn't match directory '{parent_dir_name}'.",
|
||||
fix=f"Rename the directory to '{name}' or set name to '{parent_dir_name}'.",
|
||||
)
|
||||
|
||||
return ParsedSkill(
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Structured error codes and diagnostics for the Hive skill system.
|
||||
|
||||
Implements DX-1 (structured error codes) and DX-2 (what/why/fix format)
|
||||
from the skill system PRD §7.5.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SkillErrorCode(Enum):
|
||||
"""Standardized error codes for skill system operations (DX-1)."""
|
||||
|
||||
SKILL_NOT_FOUND = "SKILL_NOT_FOUND"
|
||||
SKILL_PARSE_ERROR = "SKILL_PARSE_ERROR"
|
||||
SKILL_ACTIVATION_FAILED = "SKILL_ACTIVATION_FAILED"
|
||||
SKILL_MISSING_DESCRIPTION = "SKILL_MISSING_DESCRIPTION"
|
||||
SKILL_YAML_FIXUP = "SKILL_YAML_FIXUP"
|
||||
SKILL_NAME_MISMATCH = "SKILL_NAME_MISMATCH"
|
||||
SKILL_COLLISION = "SKILL_COLLISION"
|
||||
|
||||
|
||||
class SkillError(Exception):
|
||||
"""Structured exception for skill system errors (DX-2).
|
||||
|
||||
Raised in strict validation paths. Also used as the base
|
||||
format contract for log_skill_error() log messages.
|
||||
"""
|
||||
|
||||
def __init__(self, code: SkillErrorCode, what: str, why: str, fix: str):
|
||||
self.code = code
|
||||
self.what = what
|
||||
self.why = why
|
||||
self.fix = fix
|
||||
self.message = (
|
||||
f"[{self.code.value}]\nWhat failed: {self.what}\nWhy: {self.why}\nFix: {self.fix}"
|
||||
)
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
def log_skill_error(
|
||||
logger: logging.Logger,
|
||||
level: str,
|
||||
code: SkillErrorCode,
|
||||
what: str,
|
||||
why: str,
|
||||
fix: str,
|
||||
) -> None:
|
||||
"""Emit a structured skill diagnostic log with consistent format (DX-2).
|
||||
|
||||
Args:
|
||||
logger: The module logger to emit to.
|
||||
level: Log level string — 'error', 'warning', or 'info'.
|
||||
code: Structured error code.
|
||||
what: What failed (specific skill name and path).
|
||||
why: Root cause.
|
||||
fix: Concrete next step for the developer.
|
||||
"""
|
||||
msg = f"[{code.value}] What failed: {what} | Why: {why} | Fix: {fix}"
|
||||
getattr(logger, level)(
|
||||
msg,
|
||||
extra={
|
||||
"skill_error_code": code.value,
|
||||
"what": what,
|
||||
"why": why,
|
||||
"fix": fix,
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,477 @@
|
||||
"""Trust gating for project-level skills (PRD AS-13).
|
||||
|
||||
Project-level skills from untrusted repositories require explicit user consent
|
||||
before their instructions are loaded into the agent's system prompt.
|
||||
Framework and user-scope skills are always trusted.
|
||||
|
||||
Trusted repos are persisted at ~/.hive/trusted_repos.json.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from framework.skills.parser import ParsedSkill
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Env var to bypass trust gating in CI/headless pipelines (opt-in).
|
||||
_ENV_TRUST_ALL = "HIVE_TRUST_PROJECT_SKILLS"
|
||||
|
||||
# Env var for comma-separated own-remote glob patterns (e.g. "github.com/myorg/*").
|
||||
_ENV_OWN_REMOTES = "HIVE_OWN_REMOTES"
|
||||
|
||||
_TRUSTED_REPOS_PATH = Path.home() / ".hive" / "trusted_repos.json"
|
||||
_NOTICE_SENTINEL_PATH = Path.home() / ".hive" / ".skill_trust_notice_shown"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trusted repo store
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrustedRepoEntry:
|
||||
repo_key: str
|
||||
added_at: datetime
|
||||
project_path: str = ""
|
||||
|
||||
|
||||
class TrustedRepoStore:
|
||||
"""Persists permanently-trusted repo keys to ~/.hive/trusted_repos.json."""
|
||||
|
||||
def __init__(self, path: Path | None = None) -> None:
|
||||
self._path = path or _TRUSTED_REPOS_PATH
|
||||
self._entries: dict[str, TrustedRepoEntry] = {}
|
||||
self._loaded = False
|
||||
|
||||
def is_trusted(self, repo_key: str) -> bool:
|
||||
self._ensure_loaded()
|
||||
return repo_key in self._entries
|
||||
|
||||
def trust(self, repo_key: str, project_path: str = "") -> None:
|
||||
self._ensure_loaded()
|
||||
self._entries[repo_key] = TrustedRepoEntry(
|
||||
repo_key=repo_key,
|
||||
added_at=datetime.now(tz=UTC),
|
||||
project_path=project_path,
|
||||
)
|
||||
self._save()
|
||||
logger.info("skill_trust_store: trusted repo_key=%s", repo_key)
|
||||
|
||||
def revoke(self, repo_key: str) -> bool:
|
||||
self._ensure_loaded()
|
||||
if repo_key in self._entries:
|
||||
del self._entries[repo_key]
|
||||
self._save()
|
||||
logger.info("skill_trust_store: revoked repo_key=%s", repo_key)
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_entries(self) -> list[TrustedRepoEntry]:
|
||||
self._ensure_loaded()
|
||||
return list(self._entries.values())
|
||||
|
||||
def _ensure_loaded(self) -> None:
|
||||
if not self._loaded:
|
||||
self._load()
|
||||
self._loaded = True
|
||||
|
||||
def _load(self) -> None:
|
||||
try:
|
||||
data = json.loads(self._path.read_text(encoding="utf-8"))
|
||||
for raw in data.get("entries", []):
|
||||
repo_key = raw.get("repo_key", "")
|
||||
if not repo_key:
|
||||
continue
|
||||
try:
|
||||
added_at = datetime.fromisoformat(raw["added_at"])
|
||||
except (KeyError, ValueError):
|
||||
added_at = datetime.now(tz=UTC)
|
||||
self._entries[repo_key] = TrustedRepoEntry(
|
||||
repo_key=repo_key,
|
||||
added_at=added_at,
|
||||
project_path=raw.get("project_path", ""),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"skill_trust_store: could not read %s (%s); treating as empty",
|
||||
self._path,
|
||||
e,
|
||||
)
|
||||
|
||||
def _save(self) -> None:
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
data = {
|
||||
"version": 1,
|
||||
"entries": [
|
||||
{
|
||||
"repo_key": e.repo_key,
|
||||
"added_at": e.added_at.isoformat(),
|
||||
"project_path": e.project_path,
|
||||
}
|
||||
for e in self._entries.values()
|
||||
],
|
||||
}
|
||||
# Atomic write: write to .tmp then rename
|
||||
tmp = self._path.with_suffix(".tmp")
|
||||
tmp.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
||||
tmp.replace(self._path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trust classification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ProjectTrustClassification(StrEnum):
|
||||
ALWAYS_TRUSTED = "always_trusted"
|
||||
TRUSTED_BY_USER = "trusted_by_user"
|
||||
UNTRUSTED = "untrusted"
|
||||
|
||||
|
||||
class ProjectTrustDetector:
|
||||
"""Classifies a project directory as trusted or untrusted.
|
||||
|
||||
Algorithm (PRD §4.1 trust note):
|
||||
1. No project_dir → ALWAYS_TRUSTED
|
||||
2. No .git directory → ALWAYS_TRUSTED (not a git repo)
|
||||
3. No remote 'origin' → ALWAYS_TRUSTED (local-only repo)
|
||||
4. Remote URL → repo_key; in TrustedRepoStore → TRUSTED_BY_USER
|
||||
5. Localhost remote → ALWAYS_TRUSTED
|
||||
6. ~/.hive/own_remotes match → ALWAYS_TRUSTED
|
||||
7. HIVE_OWN_REMOTES env match → ALWAYS_TRUSTED
|
||||
8. None of the above → UNTRUSTED
|
||||
"""
|
||||
|
||||
def __init__(self, store: TrustedRepoStore | None = None) -> None:
|
||||
self._store = store or TrustedRepoStore()
|
||||
|
||||
def classify(self, project_dir: Path | None) -> tuple[ProjectTrustClassification, str]:
|
||||
"""Return (classification, repo_key).
|
||||
|
||||
repo_key is empty string for ALWAYS_TRUSTED cases without a remote.
|
||||
"""
|
||||
if project_dir is None or not project_dir.exists():
|
||||
return ProjectTrustClassification.ALWAYS_TRUSTED, ""
|
||||
|
||||
if not (project_dir / ".git").exists():
|
||||
return ProjectTrustClassification.ALWAYS_TRUSTED, ""
|
||||
|
||||
remote_url = self._get_remote_origin(project_dir)
|
||||
if not remote_url:
|
||||
return ProjectTrustClassification.ALWAYS_TRUSTED, ""
|
||||
|
||||
repo_key = _normalize_remote_url(remote_url)
|
||||
|
||||
# Explicitly trusted by user
|
||||
if self._store.is_trusted(repo_key):
|
||||
return ProjectTrustClassification.TRUSTED_BY_USER, repo_key
|
||||
|
||||
# Localhost remotes are always trusted
|
||||
if _is_localhost_remote(remote_url):
|
||||
return ProjectTrustClassification.ALWAYS_TRUSTED, repo_key
|
||||
|
||||
# User-configured own-remote patterns
|
||||
if self._matches_own_remotes(repo_key):
|
||||
return ProjectTrustClassification.ALWAYS_TRUSTED, repo_key
|
||||
|
||||
return ProjectTrustClassification.UNTRUSTED, repo_key
|
||||
|
||||
def _get_remote_origin(self, project_dir: Path) -> str:
|
||||
"""Run git remote get-url origin. Returns empty string on any failure."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "-C", str(project_dir), "remote", "get-url", "origin"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout.strip()
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning(
|
||||
"skill_trust: git remote lookup timed out for %s; treating as trusted",
|
||||
project_dir,
|
||||
)
|
||||
except (FileNotFoundError, OSError):
|
||||
pass # git not found or other OS error
|
||||
return ""
|
||||
|
||||
def _matches_own_remotes(self, repo_key: str) -> bool:
|
||||
"""Check repo_key against user-configured own-remote glob patterns."""
|
||||
import fnmatch
|
||||
|
||||
patterns: list[str] = []
|
||||
|
||||
# From env var
|
||||
env_patterns = _ENV_OWN_REMOTES
|
||||
import os
|
||||
|
||||
raw = os.environ.get(env_patterns, "")
|
||||
if raw:
|
||||
patterns.extend(p.strip() for p in raw.split(",") if p.strip())
|
||||
|
||||
# From ~/.hive/own_remotes file
|
||||
own_remotes_file = Path.home() / ".hive" / "own_remotes"
|
||||
if own_remotes_file.is_file():
|
||||
try:
|
||||
for line in own_remotes_file.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#"):
|
||||
patterns.append(line)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return any(fnmatch.fnmatch(repo_key, p) for p in patterns)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# URL helpers (public so CLI can reuse)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _normalize_remote_url(url: str) -> str:
|
||||
"""Normalize a git remote URL to a canonical ``host/org/repo`` key.
|
||||
|
||||
Examples:
|
||||
git@github.com:org/repo.git → github.com/org/repo
|
||||
https://github.com/org/repo → github.com/org/repo
|
||||
ssh://git@github.com/org/repo.git → github.com/org/repo
|
||||
"""
|
||||
url = url.strip()
|
||||
|
||||
# SCP-style SSH: git@github.com:org/repo.git
|
||||
if url.startswith("git@") and ":" in url and "://" not in url:
|
||||
url = url[4:] # strip git@
|
||||
url = url.replace(":", "/", 1)
|
||||
elif "://" in url:
|
||||
parsed = urlparse(url)
|
||||
host = parsed.hostname or ""
|
||||
path = parsed.path.lstrip("/")
|
||||
url = f"{host}/{path}"
|
||||
|
||||
# Strip .git suffix
|
||||
if url.endswith(".git"):
|
||||
url = url[:-4]
|
||||
|
||||
return url.lower().strip("/")
|
||||
|
||||
|
||||
def _is_localhost_remote(remote_url: str) -> bool:
|
||||
"""Return True if the remote points to a local host."""
|
||||
local_hosts = {"localhost", "127.0.0.1", "::1"}
|
||||
try:
|
||||
if "://" in remote_url:
|
||||
parsed = urlparse(remote_url)
|
||||
return (parsed.hostname or "").lower() in local_hosts
|
||||
# SCP-style: git@localhost:org/repo
|
||||
if "@" in remote_url:
|
||||
host_part = remote_url.split("@", 1)[1].split(":")[0]
|
||||
return host_part.lower() in local_hosts
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trust gate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TrustGate:
|
||||
"""Filters skill list, running consent flow for untrusted project-scope skills.
|
||||
|
||||
Framework and user-scope skills are always allowed through.
|
||||
Project-scope skills from untrusted repos require consent.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
store: TrustedRepoStore | None = None,
|
||||
detector: ProjectTrustDetector | None = None,
|
||||
interactive: bool = True,
|
||||
print_fn: Callable[[str], None] | None = None,
|
||||
input_fn: Callable[[str], str] | None = None,
|
||||
) -> None:
|
||||
self._store = store or TrustedRepoStore()
|
||||
self._detector = detector or ProjectTrustDetector(self._store)
|
||||
self._interactive = interactive
|
||||
self._print = print_fn or print
|
||||
self._input = input_fn or input
|
||||
|
||||
def filter_and_gate(
|
||||
self,
|
||||
skills: list[ParsedSkill],
|
||||
project_dir: Path | None,
|
||||
) -> list[ParsedSkill]:
|
||||
"""Return the subset of skills that are trusted for loading.
|
||||
|
||||
- Framework and user-scope skills: always included.
|
||||
- Project-scope skills: classified; consent prompt shown if untrusted.
|
||||
"""
|
||||
import os
|
||||
|
||||
# Separate project skills from always-trusted scopes
|
||||
always_trusted = [s for s in skills if s.source_scope != "project"]
|
||||
project_skills = [s for s in skills if s.source_scope == "project"]
|
||||
|
||||
if not project_skills:
|
||||
return always_trusted
|
||||
|
||||
# Env-var CI override: trust all project skills for this invocation
|
||||
if os.environ.get(_ENV_TRUST_ALL, "").strip() == "1":
|
||||
logger.info(
|
||||
"skill_trust: %s=1 set; trusting %d project skill(s) without consent",
|
||||
_ENV_TRUST_ALL,
|
||||
len(project_skills),
|
||||
)
|
||||
return always_trusted + project_skills
|
||||
|
||||
classification, repo_key = self._detector.classify(project_dir)
|
||||
|
||||
if classification in (
|
||||
ProjectTrustClassification.ALWAYS_TRUSTED,
|
||||
ProjectTrustClassification.TRUSTED_BY_USER,
|
||||
):
|
||||
logger.info(
|
||||
"skill_trust: project skills trusted classification=%s repo=%s count=%d",
|
||||
classification,
|
||||
repo_key or "(no remote)",
|
||||
len(project_skills),
|
||||
)
|
||||
return always_trusted + project_skills
|
||||
|
||||
# UNTRUSTED — need consent
|
||||
if not self._interactive or not sys.stdin.isatty():
|
||||
logger.warning(
|
||||
"skill_trust: skipping %d project-scope skill(s) from untrusted repo "
|
||||
"'%s' (non-interactive mode). "
|
||||
"To trust permanently run: hive skill trust %s",
|
||||
len(project_skills),
|
||||
repo_key,
|
||||
project_dir or ".",
|
||||
)
|
||||
logger.info(
|
||||
"skill_trust_decision repo=%s skills=%d decision=denied mode=headless",
|
||||
repo_key,
|
||||
len(project_skills),
|
||||
)
|
||||
return always_trusted
|
||||
|
||||
# Interactive consent flow
|
||||
decision = self._run_consent_flow(project_skills, project_dir, repo_key)
|
||||
|
||||
logger.info(
|
||||
"skill_trust_decision repo=%s skills=%d decision=%s mode=interactive",
|
||||
repo_key,
|
||||
len(project_skills),
|
||||
decision,
|
||||
)
|
||||
|
||||
if decision == "session":
|
||||
return always_trusted + project_skills
|
||||
|
||||
if decision == "permanent":
|
||||
self._store.trust(repo_key, project_path=str(project_dir or ""))
|
||||
return always_trusted + project_skills
|
||||
|
||||
# denied
|
||||
return always_trusted
|
||||
|
||||
def _run_consent_flow(
|
||||
self,
|
||||
project_skills: list[ParsedSkill],
|
||||
project_dir: Path | None,
|
||||
repo_key: str,
|
||||
) -> str:
|
||||
"""Show the security notice (once) and consent prompt.
|
||||
Return 'session' | 'permanent' | 'denied'."""
|
||||
from framework.credentials.setup import Colors
|
||||
|
||||
if not sys.stdout.isatty():
|
||||
Colors.disable()
|
||||
|
||||
self._maybe_show_security_notice(Colors)
|
||||
self._print_consent_prompt(project_skills, project_dir, repo_key, Colors)
|
||||
return self._prompt_consent(Colors)
|
||||
|
||||
def _maybe_show_security_notice(self, Colors) -> None: # noqa: N803
|
||||
"""Show the one-time security notice if not already shown (NFR-5)."""
|
||||
if _NOTICE_SENTINEL_PATH.exists():
|
||||
return
|
||||
self._print("")
|
||||
self._print(
|
||||
f"{Colors.YELLOW}Security notice:{Colors.NC} Skills inject instructions "
|
||||
"into the agent's system prompt."
|
||||
)
|
||||
self._print(
|
||||
" Only load skills from sources you trust. "
|
||||
"Registry skills at tier 'verified' or 'official' have been audited."
|
||||
)
|
||||
self._print("")
|
||||
try:
|
||||
_NOTICE_SENTINEL_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
_NOTICE_SENTINEL_PATH.touch()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _print_consent_prompt(
|
||||
self,
|
||||
project_skills: list[ParsedSkill],
|
||||
project_dir: Path | None,
|
||||
repo_key: str,
|
||||
Colors, # noqa: N803
|
||||
) -> None:
|
||||
p = self._print
|
||||
p("")
|
||||
p(f"{Colors.YELLOW}{'=' * 60}{Colors.NC}")
|
||||
p(f"{Colors.BOLD} SKILL TRUST REQUIRED{Colors.NC}")
|
||||
p(f"{Colors.YELLOW}{'=' * 60}{Colors.NC}")
|
||||
p("")
|
||||
proj_label = str(project_dir) if project_dir else "this project"
|
||||
p(
|
||||
f" The project at {Colors.CYAN}{proj_label}{Colors.NC} wants to load "
|
||||
f"{len(project_skills)} skill(s)"
|
||||
)
|
||||
p(" that will inject instructions into the agent's system prompt.")
|
||||
if repo_key:
|
||||
p(f" Source: {Colors.BOLD}{repo_key}{Colors.NC}")
|
||||
p("")
|
||||
p(" Skills requesting access:")
|
||||
for skill in project_skills:
|
||||
p(f" {Colors.CYAN}•{Colors.NC} {Colors.BOLD}{skill.name}{Colors.NC}")
|
||||
p(f' "{skill.description}"')
|
||||
p(f" {Colors.DIM}{skill.location}{Colors.NC}")
|
||||
p("")
|
||||
p(" Options:")
|
||||
p(f" {Colors.CYAN}1){Colors.NC} Trust this session only")
|
||||
p(f" {Colors.CYAN}2){Colors.NC} Trust permanently — remember for future runs")
|
||||
p(
|
||||
f" {Colors.DIM}3) Deny"
|
||||
f" — skip all project-scope skills from this repo{Colors.NC}"
|
||||
)
|
||||
p(f"{Colors.YELLOW}{'─' * 60}{Colors.NC}")
|
||||
|
||||
def _prompt_consent(self, Colors) -> str: # noqa: N803
|
||||
"""Prompt until a valid choice is entered. Returns 'session'|'permanent'|'denied'."""
|
||||
mapping = {"1": "session", "2": "permanent", "3": "denied"}
|
||||
while True:
|
||||
try:
|
||||
choice = self._input("Select option (1-3): ").strip()
|
||||
if choice in mapping:
|
||||
return mapping[choice]
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return "denied"
|
||||
self._print(f"{Colors.RED}Invalid choice. Enter 1, 2, or 3.{Colors.NC}")
|
||||
@@ -727,6 +727,25 @@ def _dissolve_planning_nodes(
|
||||
return converted, flowchart_map
|
||||
|
||||
|
||||
def _update_meta_json(session_manager, manager_session_id, updates: dict) -> None:
|
||||
"""Merge updates into the queen session's meta.json."""
|
||||
if session_manager is None or not manager_session_id:
|
||||
return
|
||||
srv_session = session_manager.get_session(manager_session_id)
|
||||
if not srv_session:
|
||||
return
|
||||
storage_sid = getattr(srv_session, "queen_resume_from", None) or srv_session.id
|
||||
meta_path = Path.home() / ".hive" / "queen" / "session" / storage_sid / "meta.json"
|
||||
try:
|
||||
existing = {}
|
||||
if meta_path.exists():
|
||||
existing = json.loads(meta_path.read_text(encoding="utf-8"))
|
||||
existing.update(updates)
|
||||
meta_path.write_text(json.dumps(existing), encoding="utf-8")
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def register_queen_lifecycle_tools(
|
||||
registry: ToolRegistry,
|
||||
session: Any = None,
|
||||
@@ -975,6 +994,7 @@ def register_queen_lifecycle_tools(
|
||||
# Switch to building phase
|
||||
if phase_state is not None:
|
||||
await phase_state.switch_to_building()
|
||||
_update_meta_json(session_manager, manager_session_id, {"phase": "building"})
|
||||
|
||||
result = json.loads(stop_result)
|
||||
result["phase"] = "building"
|
||||
@@ -1559,12 +1579,22 @@ def register_queen_lifecycle_tools(
|
||||
# Find edges where this leaf node is the source
|
||||
out_edges = [e for e in validated_edges if e["source"] == leaf_id]
|
||||
in_edges = [e for e in validated_edges if e["target"] == leaf_id]
|
||||
if not out_edges:
|
||||
continue # already a proper leaf
|
||||
|
||||
# Identify the parent (predecessor that connects IN)
|
||||
parent_ids = [e["source"] for e in in_edges]
|
||||
|
||||
if not out_edges:
|
||||
# Already a proper leaf — still ensure sub_agents is set
|
||||
for pid in parent_ids:
|
||||
parent = node_by_id_v.get(pid)
|
||||
if parent is None:
|
||||
continue
|
||||
existing = parent.get("sub_agents") or []
|
||||
if leaf_id not in existing:
|
||||
existing.append(leaf_id)
|
||||
parent["sub_agents"] = existing
|
||||
continue
|
||||
|
||||
# Strip all outgoing edges from the leaf node that
|
||||
# don't go back to a parent (report edges are OK)
|
||||
illegal_targets: list[str] = []
|
||||
@@ -1978,6 +2008,17 @@ def register_queen_lifecycle_tools(
|
||||
"type": "string",
|
||||
"description": "What success looks like for this node",
|
||||
},
|
||||
"sub_agents": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"IDs of GCU/browser sub-agent nodes managed by this node. "
|
||||
"At build time, sub-agent nodes are dissolved into this list. "
|
||||
"Set this on the PARENT node — e.g. the orchestrator that "
|
||||
"delegates to GCU leaves. Visual delegation edges are "
|
||||
"synthesized automatically."
|
||||
),
|
||||
},
|
||||
"decision_clause": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
@@ -2095,8 +2136,22 @@ def register_queen_lifecycle_tools(
|
||||
phase_state.draft_graph = converted
|
||||
phase_state.flowchart_map = fmap
|
||||
|
||||
# Note: flowchart file is persisted later, in initialize_and_build_agent
|
||||
# (after the agent folder is scaffolded) or in load_built_agent.
|
||||
# Create agent folder early so flowchart and agent_path are available
|
||||
# throughout the entire BUILDING phase.
|
||||
_agent_name = phase_state.draft_graph.get("agent_name", "").strip()
|
||||
if _agent_name:
|
||||
_agent_folder = Path("exports") / _agent_name
|
||||
_agent_folder.mkdir(parents=True, exist_ok=True)
|
||||
_save_flowchart_file(_agent_folder, original_copy, fmap)
|
||||
phase_state.agent_path = str(_agent_folder)
|
||||
_update_meta_json(
|
||||
session_manager,
|
||||
manager_session_id,
|
||||
{
|
||||
"agent_path": str(_agent_folder),
|
||||
"agent_name": _agent_name.replace("_", " ").title(),
|
||||
},
|
||||
)
|
||||
|
||||
dissolved_count = len(original_nodes) - len(converted.get("nodes", []))
|
||||
decision_count = sum(1 for n in original_nodes if n.get("flowchart_type") == "decision")
|
||||
@@ -2228,6 +2283,7 @@ def register_queen_lifecycle_tools(
|
||||
if fallback_path:
|
||||
phase_state.agent_path = str(fallback_path)
|
||||
await phase_state.switch_to_building(source="tool")
|
||||
_update_meta_json(session_manager, manager_session_id, {"phase": "building"})
|
||||
if phase_state.inject_notification:
|
||||
await phase_state.inject_notification(
|
||||
"[PHASE CHANGE] Switched to BUILDING phase. "
|
||||
@@ -2270,8 +2326,13 @@ def register_queen_lifecycle_tools(
|
||||
if parsed.get("success", True):
|
||||
if phase_state is not None:
|
||||
# Set agent_path so the frontend can query credentials
|
||||
phase_state.agent_path = str(Path("exports") / agent_name)
|
||||
phase_state.agent_path = phase_state.agent_path or str(
|
||||
Path("exports") / agent_name
|
||||
)
|
||||
await phase_state.switch_to_building(source="tool")
|
||||
_update_meta_json(
|
||||
session_manager, manager_session_id, {"phase": "building"}
|
||||
)
|
||||
# Reset draft state after successful scaffolding
|
||||
phase_state.build_confirmed = False
|
||||
# Persist flowchart now that the agent folder exists
|
||||
@@ -2319,6 +2380,7 @@ def register_queen_lifecycle_tools(
|
||||
# Switch to staging phase
|
||||
if phase_state is not None:
|
||||
await phase_state.switch_to_staging()
|
||||
_update_meta_json(session_manager, manager_session_id, {"phase": "staging"})
|
||||
|
||||
result = json.loads(stop_result)
|
||||
result["phase"] = "staging"
|
||||
@@ -2347,6 +2409,30 @@ def register_queen_lifecycle_tools(
|
||||
"""Get the session's event bus for querying history."""
|
||||
return getattr(session, "event_bus", None)
|
||||
|
||||
def _get_worker_name() -> str | None:
|
||||
"""Return the worker agent directory name, used for diary lookups."""
|
||||
p = getattr(session, "worker_path", None)
|
||||
return p.name if p else None
|
||||
|
||||
def _format_diary(max_runs: int) -> str:
|
||||
"""Read recent run digests from disk — no EventBus required."""
|
||||
agent_name = _get_worker_name()
|
||||
if not agent_name:
|
||||
return "No worker loaded — diary unavailable."
|
||||
from framework.agents.worker_memory import read_recent_digests
|
||||
|
||||
entries = read_recent_digests(agent_name, max_runs)
|
||||
if not entries:
|
||||
return (
|
||||
f"No run digests for '{agent_name}' yet. "
|
||||
"Digests are written at the end of each completed run."
|
||||
)
|
||||
lines = [f"Worker '{agent_name}' — {len(entries)} recent run digest(s):", ""]
|
||||
for _run_id, content in entries:
|
||||
lines.append(content)
|
||||
lines.append("")
|
||||
return "\n".join(lines).rstrip()
|
||||
|
||||
# Tiered cooldowns: summary is free, detail has short cooldown, full keeps 30s
|
||||
_COOLDOWN_FULL = 30.0
|
||||
_COOLDOWN_DETAIL = 10.0
|
||||
@@ -2949,16 +3035,17 @@ def register_queen_lifecycle_tools(
|
||||
import time as _time
|
||||
|
||||
# --- Tiered cooldown ---
|
||||
# diary is free (file reads only), summary is free, detail has 10s, full has 30s
|
||||
now = _time.monotonic()
|
||||
if focus == "full":
|
||||
cooldown = _COOLDOWN_FULL
|
||||
tier = "full"
|
||||
elif focus is not None:
|
||||
elif focus == "diary" or focus is None:
|
||||
cooldown = 0.0
|
||||
tier = focus or "summary"
|
||||
else:
|
||||
cooldown = _COOLDOWN_DETAIL
|
||||
tier = "detail"
|
||||
else:
|
||||
cooldown = 0.0
|
||||
tier = "summary"
|
||||
|
||||
elapsed_since = now - _status_last_called.get(tier, 0.0)
|
||||
if elapsed_since < cooldown:
|
||||
@@ -2974,6 +3061,10 @@ def register_queen_lifecycle_tools(
|
||||
)
|
||||
_status_last_called[tier] = now
|
||||
|
||||
# --- Diary: pure file reads, no runtime required ---
|
||||
if focus == "diary":
|
||||
return _format_diary(last_n)
|
||||
|
||||
# --- Runtime check ---
|
||||
runtime = _get_runtime()
|
||||
if runtime is None:
|
||||
@@ -3023,7 +3114,7 @@ def register_queen_lifecycle_tools(
|
||||
else:
|
||||
return (
|
||||
f"Unknown focus '{focus}'. "
|
||||
"Valid options: activity, memory, tools, issues, progress, full."
|
||||
"Valid options: diary, activity, memory, tools, issues, progress, full."
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("get_worker_status error")
|
||||
@@ -3034,6 +3125,8 @@ def register_queen_lifecycle_tools(
|
||||
description=(
|
||||
"Check on the worker. Returns a brief prose summary by default. "
|
||||
"Use 'focus' to drill into specifics:\n"
|
||||
"- diary: persistent run digests from past executions — read this first "
|
||||
"before digging into live runtime logs\n"
|
||||
"- activity: current node, transitions, latest LLM output\n"
|
||||
"- memory: worker's accumulated knowledge and state\n"
|
||||
"- tools: running and recent tool calls\n"
|
||||
@@ -3046,8 +3139,11 @@ def register_queen_lifecycle_tools(
|
||||
"properties": {
|
||||
"focus": {
|
||||
"type": "string",
|
||||
"enum": ["activity", "memory", "tools", "issues", "progress", "full"],
|
||||
"description": ("Aspect to inspect. Omit for a brief summary."),
|
||||
"enum": ["diary", "activity", "memory", "tools", "issues", "progress", "full"],
|
||||
"description": (
|
||||
"Aspect to inspect. Omit for a brief summary. "
|
||||
"Use 'diary' to read persistent run history before checking live logs."
|
||||
),
|
||||
},
|
||||
"last_n": {
|
||||
"type": "integer",
|
||||
@@ -3446,6 +3542,7 @@ def register_queen_lifecycle_tools(
|
||||
if phase_state is not None:
|
||||
phase_state.agent_path = str(resolved_path)
|
||||
await phase_state.switch_to_staging()
|
||||
_update_meta_json(session_manager, manager_session_id, {"phase": "staging"})
|
||||
|
||||
worker_name = info.name if info else updated_session.worker_id
|
||||
return json.dumps(
|
||||
@@ -3565,6 +3662,7 @@ def register_queen_lifecycle_tools(
|
||||
# Switch to running phase
|
||||
if phase_state is not None:
|
||||
await phase_state.switch_to_running()
|
||||
_update_meta_json(session_manager, manager_session_id, {"phase": "running"})
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
@@ -3702,6 +3800,8 @@ def register_queen_lifecycle_tools(
|
||||
_save_trigger_to_agent(session, trigger_id, tdef)
|
||||
bus = getattr(session, "event_bus", None)
|
||||
if bus:
|
||||
_runner = getattr(session, "runner", None)
|
||||
_graph_entry = _runner.graph.entry_node if _runner else None
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.TRIGGER_ACTIVATED,
|
||||
@@ -3710,6 +3810,8 @@ def register_queen_lifecycle_tools(
|
||||
"trigger_id": trigger_id,
|
||||
"trigger_type": t_type,
|
||||
"trigger_config": t_config,
|
||||
"name": tdef.description or trigger_id,
|
||||
**({"entry_node": _graph_entry} if _graph_entry else {}),
|
||||
},
|
||||
)
|
||||
)
|
||||
@@ -3762,6 +3864,8 @@ def register_queen_lifecycle_tools(
|
||||
# Emit event
|
||||
bus = getattr(session, "event_bus", None)
|
||||
if bus:
|
||||
_runner = getattr(session, "runner", None)
|
||||
_graph_entry = _runner.graph.entry_node if _runner else None
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.TRIGGER_ACTIVATED,
|
||||
@@ -3770,6 +3874,8 @@ def register_queen_lifecycle_tools(
|
||||
"trigger_id": trigger_id,
|
||||
"trigger_type": t_type,
|
||||
"trigger_config": t_config,
|
||||
"name": tdef.description or trigger_id,
|
||||
**({"entry_node": _graph_entry} if _graph_entry else {}),
|
||||
},
|
||||
)
|
||||
)
|
||||
@@ -3868,7 +3974,10 @@ def register_queen_lifecycle_tools(
|
||||
AgentEvent(
|
||||
type=EventType.TRIGGER_DEACTIVATED,
|
||||
stream_id="queen",
|
||||
data={"trigger_id": trigger_id},
|
||||
data={
|
||||
"trigger_id": trigger_id,
|
||||
"name": tdef.description or trigger_id if tdef else trigger_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -34,8 +34,8 @@ export const executionApi = {
|
||||
graph_id: graphId,
|
||||
}),
|
||||
|
||||
chat: (sessionId: string, message: string) =>
|
||||
api.post<ChatResult>(`/sessions/${sessionId}/chat`, { message }),
|
||||
chat: (sessionId: string, message: string, images?: { type: string; image_url: { url: string } }[]) =>
|
||||
api.post<ChatResult>(`/sessions/${sessionId}/chat`, { message, ...(images?.length ? { images } : {}) }),
|
||||
|
||||
/** Queue context for the queen without triggering an LLM response. */
|
||||
queenContext: (sessionId: string, message: string) =>
|
||||
|
||||
@@ -64,10 +64,14 @@ export const sessionsApi = {
|
||||
`/sessions/${sessionId}/entry-points`,
|
||||
),
|
||||
|
||||
updateTriggerTask: (sessionId: string, triggerId: string, task: string) =>
|
||||
api.patch<{ trigger_id: string; task: string }>(
|
||||
updateTrigger: (
|
||||
sessionId: string,
|
||||
triggerId: string,
|
||||
patch: { task?: string; trigger_config?: Record<string, unknown> },
|
||||
) =>
|
||||
api.patch<{ trigger_id: string; task: string; trigger_config: Record<string, unknown> }>(
|
||||
`/sessions/${sessionId}/triggers/${triggerId}`,
|
||||
{ task },
|
||||
patch,
|
||||
),
|
||||
|
||||
graphs: (sessionId: string) =>
|
||||
@@ -77,6 +81,10 @@ export const sessionsApi = {
|
||||
eventsHistory: (sessionId: string) =>
|
||||
api.get<{ events: AgentEvent[]; session_id: string }>(`/sessions/${sessionId}/events/history`),
|
||||
|
||||
/** Open the session's data folder in the OS file manager. */
|
||||
revealFolder: (sessionId: string) =>
|
||||
api.post<{ path: string }>(`/sessions/${sessionId}/reveal`),
|
||||
|
||||
/** List all queen sessions on disk — live + cold (post-restart). */
|
||||
history: () =>
|
||||
api.get<{ sessions: Array<{ session_id: string; cold: boolean; live: boolean; has_messages: boolean; created_at: number; agent_name?: string | null; agent_path?: string | null }> }>("/sessions/history"),
|
||||
|
||||
@@ -14,6 +14,8 @@ export interface LiveSession {
|
||||
intro_message?: string;
|
||||
/** Queen operating phase — "planning", "building", "staging", or "running" */
|
||||
queen_phase?: "planning" | "building" | "staging" | "running";
|
||||
/** Whether the queen's LLM supports image content in messages */
|
||||
queen_supports_images?: boolean;
|
||||
/** Present in 409 conflict responses when worker is still loading */
|
||||
loading?: boolean;
|
||||
}
|
||||
@@ -324,6 +326,7 @@ export type EventTypeName =
|
||||
| "node_retry"
|
||||
| "edge_traversed"
|
||||
| "context_compacted"
|
||||
| "context_usage_updated"
|
||||
| "webhook_received"
|
||||
| "custom"
|
||||
| "escalation_requested"
|
||||
@@ -337,7 +340,8 @@ export type EventTypeName =
|
||||
| "trigger_activated"
|
||||
| "trigger_deactivated"
|
||||
| "trigger_fired"
|
||||
| "trigger_removed";
|
||||
| "trigger_removed"
|
||||
| "trigger_updated";
|
||||
|
||||
export interface AgentEvent {
|
||||
type: EventTypeName;
|
||||
|
||||
@@ -1,8 +1,32 @@
|
||||
import { memo, useState, useRef, useEffect } from "react";
|
||||
import { Send, Square, Crown, Cpu, Check, Loader2 } from "lucide-react";
|
||||
import { memo, useState, useRef, useEffect, useMemo } from "react";
|
||||
import {
|
||||
Send,
|
||||
Square,
|
||||
Crown,
|
||||
Cpu,
|
||||
Check,
|
||||
Loader2,
|
||||
Paperclip,
|
||||
X,
|
||||
} from "lucide-react";
|
||||
|
||||
export interface ImageContent {
|
||||
type: "image_url";
|
||||
image_url: { url: string };
|
||||
}
|
||||
|
||||
export interface ContextUsageEntry {
|
||||
usagePct: number;
|
||||
messageCount: number;
|
||||
estimatedTokens: number;
|
||||
maxTokens: number;
|
||||
}
|
||||
import MarkdownContent from "@/components/MarkdownContent";
|
||||
import QuestionWidget from "@/components/QuestionWidget";
|
||||
import MultiQuestionWidget from "@/components/MultiQuestionWidget";
|
||||
import ParallelSubagentBubble, {
|
||||
type SubagentGroup,
|
||||
} from "@/components/ParallelSubagentBubble";
|
||||
|
||||
export interface ChatMessage {
|
||||
id: string;
|
||||
@@ -10,7 +34,13 @@ export interface ChatMessage {
|
||||
agentColor: string;
|
||||
content: string;
|
||||
timestamp: string;
|
||||
type?: "system" | "agent" | "user" | "tool_status" | "worker_input_request" | "run_divider";
|
||||
type?:
|
||||
| "system"
|
||||
| "agent"
|
||||
| "user"
|
||||
| "tool_status"
|
||||
| "worker_input_request"
|
||||
| "run_divider";
|
||||
role?: "queen" | "worker";
|
||||
/** Which worker thread this message belongs to (worker agent name) */
|
||||
thread?: string;
|
||||
@@ -18,11 +48,17 @@ export interface ChatMessage {
|
||||
createdAt?: number;
|
||||
/** Queen phase active when this message was created */
|
||||
phase?: "planning" | "building" | "staging" | "running";
|
||||
/** Images attached to a user message */
|
||||
images?: ImageContent[];
|
||||
/** Backend node_id that produced this message — used for subagent grouping */
|
||||
nodeId?: string;
|
||||
/** Backend execution_id for this message */
|
||||
executionId?: string;
|
||||
}
|
||||
|
||||
interface ChatPanelProps {
|
||||
messages: ChatMessage[];
|
||||
onSend: (message: string, thread: string) => void;
|
||||
onSend: (message: string, thread: string, images?: ImageContent[]) => void;
|
||||
isWaiting?: boolean;
|
||||
/** When true a worker is thinking (not yet streaming) */
|
||||
isWorkerWaiting?: boolean;
|
||||
@@ -31,6 +67,8 @@ interface ChatPanelProps {
|
||||
activeThread: string;
|
||||
/** When true, the input is disabled (e.g. during loading) */
|
||||
disabled?: boolean;
|
||||
/** When false, the image attach button is hidden (model lacks vision support) */
|
||||
supportsImages?: boolean;
|
||||
/** Called when user clicks the stop button to cancel the queen's current turn */
|
||||
onCancel?: () => void;
|
||||
/** Pending question from ask_user — replaces textarea when present */
|
||||
@@ -38,7 +76,9 @@ interface ChatPanelProps {
|
||||
/** Options for the pending question */
|
||||
pendingOptions?: string[] | null;
|
||||
/** Multiple questions from ask_user_multiple */
|
||||
pendingQuestions?: { id: string; prompt: string; options?: string[] }[] | null;
|
||||
pendingQuestions?:
|
||||
| { id: string; prompt: string; options?: string[] }[]
|
||||
| null;
|
||||
/** Called when user submits an answer to the pending question */
|
||||
onQuestionSubmit?: (answer: string, isOther: boolean) => void;
|
||||
/** Called when user submits answers to multiple questions */
|
||||
@@ -47,6 +87,8 @@ interface ChatPanelProps {
|
||||
onQuestionDismiss?: () => void;
|
||||
/** Queen operating phase — shown as a tag on queen messages */
|
||||
queenPhase?: "planning" | "building" | "staging" | "running";
|
||||
/** Context window usage for queen and workers */
|
||||
contextUsage?: Record<string, ContextUsageEntry>;
|
||||
}
|
||||
|
||||
const queenColor = "hsl(45,95%,58%)";
|
||||
@@ -72,7 +114,8 @@ const TOOL_HEX = [
|
||||
|
||||
function toolHex(name: string): string {
|
||||
let hash = 0;
|
||||
for (let i = 0; i < name.length; i++) hash = (hash * 31 + name.charCodeAt(i)) | 0;
|
||||
for (let i = 0; i < name.length; i++)
|
||||
hash = (hash * 31 + name.charCodeAt(i)) | 0;
|
||||
return TOOL_HEX[Math.abs(hash) % TOOL_HEX.length];
|
||||
}
|
||||
|
||||
@@ -120,12 +163,18 @@ function ToolActivityRow({ content }: { content: string }) {
|
||||
<span
|
||||
key={`run-${p.name}`}
|
||||
className="inline-flex items-center gap-1 text-[11px] px-2.5 py-0.5 rounded-full"
|
||||
style={{ color: hex, backgroundColor: `${hex}18`, border: `1px solid ${hex}35` }}
|
||||
style={{
|
||||
color: hex,
|
||||
backgroundColor: `${hex}18`,
|
||||
border: `1px solid ${hex}35`,
|
||||
}}
|
||||
>
|
||||
<Loader2 className="w-2.5 h-2.5 animate-spin" />
|
||||
{p.name}
|
||||
{p.count > 1 && (
|
||||
<span className="text-[10px] font-medium opacity-70">×{p.count}</span>
|
||||
<span className="text-[10px] font-medium opacity-70">
|
||||
×{p.count}
|
||||
</span>
|
||||
)}
|
||||
</span>
|
||||
);
|
||||
@@ -136,7 +185,11 @@ function ToolActivityRow({ content }: { content: string }) {
|
||||
<span
|
||||
key={`done-${p.name}`}
|
||||
className="inline-flex items-center gap-1 text-[11px] px-2.5 py-0.5 rounded-full"
|
||||
style={{ color: hex, backgroundColor: `${hex}18`, border: `1px solid ${hex}35` }}
|
||||
style={{
|
||||
color: hex,
|
||||
backgroundColor: `${hex}18`,
|
||||
border: `1px solid ${hex}35`,
|
||||
}}
|
||||
>
|
||||
<Check className="w-2.5 h-2.5" />
|
||||
{p.name}
|
||||
@@ -151,109 +204,249 @@ function ToolActivityRow({ content }: { content: string }) {
|
||||
);
|
||||
}
|
||||
|
||||
const MessageBubble = memo(function MessageBubble({ msg, queenPhase }: { msg: ChatMessage; queenPhase?: "planning" | "building" | "staging" | "running" }) {
|
||||
const isUser = msg.type === "user";
|
||||
const isQueen = msg.role === "queen";
|
||||
const color = getColor(msg.agent, msg.role);
|
||||
const MessageBubble = memo(
|
||||
function MessageBubble({
|
||||
msg,
|
||||
queenPhase,
|
||||
}: {
|
||||
msg: ChatMessage;
|
||||
queenPhase?: "planning" | "building" | "staging" | "running";
|
||||
}) {
|
||||
const isUser = msg.type === "user";
|
||||
const isQueen = msg.role === "queen";
|
||||
const color = getColor(msg.agent, msg.role);
|
||||
|
||||
if (msg.type === "run_divider") {
|
||||
return (
|
||||
<div className="flex items-center gap-3 py-2 my-1">
|
||||
<div className="flex-1 h-px bg-border/60" />
|
||||
<span className="text-[10px] text-muted-foreground font-medium uppercase tracking-wider">
|
||||
{msg.content}
|
||||
</span>
|
||||
<div className="flex-1 h-px bg-border/60" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (msg.type === "system") {
|
||||
return (
|
||||
<div className="flex justify-center py-1">
|
||||
<span className="text-[11px] text-muted-foreground bg-muted/60 px-3 py-1.5 rounded-full">
|
||||
{msg.content}
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (msg.type === "tool_status") {
|
||||
return <ToolActivityRow content={msg.content} />;
|
||||
}
|
||||
|
||||
if (isUser) {
|
||||
return (
|
||||
<div className="flex justify-end">
|
||||
<div className="max-w-[75%] bg-primary text-primary-foreground text-sm leading-relaxed rounded-2xl rounded-br-md px-4 py-3">
|
||||
<p className="whitespace-pre-wrap break-words">{msg.content}</p>
|
||||
if (msg.type === "run_divider") {
|
||||
return (
|
||||
<div className="flex items-center gap-3 py-2 my-1">
|
||||
<div className="flex-1 h-px bg-border/60" />
|
||||
<span className="text-[10px] text-muted-foreground font-medium uppercase tracking-wider">
|
||||
{msg.content}
|
||||
</span>
|
||||
<div className="flex-1 h-px bg-border/60" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex gap-3">
|
||||
<div
|
||||
className={`flex-shrink-0 ${isQueen ? "w-9 h-9" : "w-7 h-7"} rounded-xl flex items-center justify-center`}
|
||||
style={{
|
||||
backgroundColor: `${color}18`,
|
||||
border: `1.5px solid ${color}35`,
|
||||
boxShadow: isQueen ? `0 0 12px ${color}20` : undefined,
|
||||
}}
|
||||
>
|
||||
{isQueen ? (
|
||||
<Crown className="w-4 h-4" style={{ color }} />
|
||||
) : (
|
||||
<Cpu className="w-3.5 h-3.5" style={{ color }} />
|
||||
)}
|
||||
</div>
|
||||
<div className={`flex-1 min-w-0 ${isQueen ? "max-w-[85%]" : "max-w-[75%]"}`}>
|
||||
<div className="flex items-center gap-2 mb-1">
|
||||
<span className={`font-medium ${isQueen ? "text-sm" : "text-xs"}`} style={{ color }}>
|
||||
{msg.agent}
|
||||
</span>
|
||||
<span
|
||||
className={`text-[10px] font-medium px-1.5 py-0.5 rounded-md ${
|
||||
isQueen ? "bg-primary/15 text-primary" : "bg-muted text-muted-foreground"
|
||||
}`}
|
||||
>
|
||||
{isQueen
|
||||
? ((msg.phase ?? queenPhase) === "running"
|
||||
? "running"
|
||||
: (msg.phase ?? queenPhase) === "staging"
|
||||
? "staging"
|
||||
: (msg.phase ?? queenPhase) === "planning"
|
||||
? "planning"
|
||||
: "building")
|
||||
: "Worker"}
|
||||
if (msg.type === "system") {
|
||||
return (
|
||||
<div className="flex justify-center py-1">
|
||||
<span className="text-[11px] text-muted-foreground bg-muted/60 px-3 py-1.5 rounded-full">
|
||||
{msg.content}
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (msg.type === "tool_status") {
|
||||
return <ToolActivityRow content={msg.content} />;
|
||||
}
|
||||
|
||||
if (isUser) {
|
||||
return (
|
||||
<div className="flex justify-end">
|
||||
<div className="max-w-[75%] bg-primary text-primary-foreground text-sm leading-relaxed rounded-2xl rounded-br-md px-4 py-3">
|
||||
{msg.images && msg.images.length > 0 && (
|
||||
<div className="flex flex-wrap gap-2 mb-2">
|
||||
{msg.images.map((img, i) => (
|
||||
<img
|
||||
key={i}
|
||||
src={img.image_url.url}
|
||||
alt={`attachment ${i + 1}`}
|
||||
className="max-h-48 max-w-full rounded-lg object-contain"
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
{msg.content && (
|
||||
<p className="whitespace-pre-wrap break-words">{msg.content}</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex gap-3">
|
||||
<div
|
||||
className={`flex-shrink-0 ${isQueen ? "w-9 h-9" : "w-7 h-7"} rounded-xl flex items-center justify-center`}
|
||||
style={{
|
||||
backgroundColor: `${color}18`,
|
||||
border: `1.5px solid ${color}35`,
|
||||
boxShadow: isQueen ? `0 0 12px ${color}20` : undefined,
|
||||
}}
|
||||
>
|
||||
{isQueen ? (
|
||||
<Crown className="w-4 h-4" style={{ color }} />
|
||||
) : (
|
||||
<Cpu className="w-3.5 h-3.5" style={{ color }} />
|
||||
)}
|
||||
</div>
|
||||
<div
|
||||
className={`text-sm leading-relaxed rounded-2xl rounded-tl-md px-4 py-3 ${
|
||||
isQueen ? "border border-primary/20 bg-primary/5" : "bg-muted/60"
|
||||
}`}
|
||||
className={`flex-1 min-w-0 ${isQueen ? "max-w-[85%]" : "max-w-[75%]"}`}
|
||||
>
|
||||
<MarkdownContent content={msg.content} />
|
||||
<div className="flex items-center gap-2 mb-1">
|
||||
<span
|
||||
className={`font-medium ${isQueen ? "text-sm" : "text-xs"}`}
|
||||
style={{ color }}
|
||||
>
|
||||
{msg.agent}
|
||||
</span>
|
||||
<span
|
||||
className={`text-[10px] font-medium px-1.5 py-0.5 rounded-md ${
|
||||
isQueen
|
||||
? "bg-primary/15 text-primary"
|
||||
: "bg-muted text-muted-foreground"
|
||||
}`}
|
||||
>
|
||||
{isQueen
|
||||
? (msg.phase ?? queenPhase) === "running"
|
||||
? "running"
|
||||
: (msg.phase ?? queenPhase) === "staging"
|
||||
? "staging"
|
||||
: (msg.phase ?? queenPhase) === "planning"
|
||||
? "planning"
|
||||
: "building"
|
||||
: "Worker"}
|
||||
</span>
|
||||
</div>
|
||||
<div
|
||||
className={`text-sm leading-relaxed rounded-2xl rounded-tl-md px-4 py-3 ${
|
||||
isQueen ? "border border-primary/20 bg-primary/5" : "bg-muted/60"
|
||||
}`}
|
||||
>
|
||||
<MarkdownContent content={msg.content} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}, (prev, next) => prev.msg.id === next.msg.id && prev.msg.content === next.msg.content && prev.msg.phase === next.msg.phase && prev.queenPhase === next.queenPhase);
|
||||
);
|
||||
},
|
||||
(prev, next) =>
|
||||
prev.msg.id === next.msg.id &&
|
||||
prev.msg.content === next.msg.content &&
|
||||
prev.msg.phase === next.msg.phase &&
|
||||
prev.queenPhase === next.queenPhase,
|
||||
);
|
||||
|
||||
export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting, isBusy, activeThread, disabled, onCancel, pendingQuestion, pendingOptions, pendingQuestions, onQuestionSubmit, onMultiQuestionSubmit, onQuestionDismiss, queenPhase }: ChatPanelProps) {
|
||||
export default function ChatPanel({
|
||||
messages,
|
||||
onSend,
|
||||
isWaiting,
|
||||
isWorkerWaiting,
|
||||
isBusy,
|
||||
activeThread,
|
||||
disabled,
|
||||
onCancel,
|
||||
pendingQuestion,
|
||||
pendingOptions,
|
||||
pendingQuestions,
|
||||
onQuestionSubmit,
|
||||
onMultiQuestionSubmit,
|
||||
onQuestionDismiss,
|
||||
queenPhase,
|
||||
contextUsage,
|
||||
supportsImages = true,
|
||||
}: ChatPanelProps) {
|
||||
const [input, setInput] = useState("");
|
||||
const [pendingImages, setPendingImages] = useState<ImageContent[]>([]);
|
||||
const [readMap, setReadMap] = useState<Record<string, number>>({});
|
||||
const bottomRef = useRef<HTMLDivElement>(null);
|
||||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
const stickToBottom = useRef(true);
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const threadMessages = messages.filter((m) => {
|
||||
if (m.type === "system" && !m.thread) return false;
|
||||
return m.thread === activeThread;
|
||||
if (m.thread !== activeThread) return false;
|
||||
// Hide queen messages whose content is whitespace-only — these are
|
||||
// tool-use-only turns that have no visible text. During live operation
|
||||
// tool pills provide context, but on resume the pills are gone so
|
||||
// the empty bubble is meaningless.
|
||||
if (m.role === "queen" && !m.type && (!m.content || !m.content.trim()))
|
||||
return false;
|
||||
return true;
|
||||
});
|
||||
|
||||
// Group subagent messages into parallel bubbles.
|
||||
// A subagent message has nodeId containing ":subagent:".
|
||||
// The run only ends on hard boundaries (user messages, run_dividers)
|
||||
// so interleaved queen/tool/system messages don't fragment the bubble.
|
||||
type RenderItem =
|
||||
| { kind: "message"; msg: ChatMessage }
|
||||
| { kind: "parallel"; groupId: string; groups: SubagentGroup[] };
|
||||
|
||||
const renderItems = useMemo<RenderItem[]>(() => {
|
||||
const items: RenderItem[] = [];
|
||||
let i = 0;
|
||||
while (i < threadMessages.length) {
|
||||
const msg = threadMessages[i];
|
||||
const isSubagent = msg.nodeId?.includes(":subagent:");
|
||||
if (!isSubagent) {
|
||||
items.push({ kind: "message", msg });
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Start a subagent run. Collect all subagent messages, allowing
|
||||
// non-subagent messages in between (they render as normal items
|
||||
// before the bubble). Only break on hard boundaries.
|
||||
const subagentMsgs: ChatMessage[] = [];
|
||||
const interleaved: { idx: number; msg: ChatMessage }[] = [];
|
||||
const firstId = msg.id;
|
||||
|
||||
while (i < threadMessages.length) {
|
||||
const m = threadMessages[i];
|
||||
const isSa = m.nodeId?.includes(":subagent:");
|
||||
|
||||
if (isSa) {
|
||||
subagentMsgs.push(m);
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Hard boundary — stop the run
|
||||
if (m.type === "user" || m.type === "run_divider") break;
|
||||
|
||||
// Worker message from a non-subagent node means the graph has
|
||||
// moved on to the next stage. Close the bubble even if some
|
||||
// subagents are still streaming in the background.
|
||||
if (m.role === "worker" && m.nodeId && !m.nodeId.includes(":subagent:"))
|
||||
break;
|
||||
|
||||
// Soft interruption (queen output, system, tool_status without
|
||||
// nodeId) — render it normally but keep the subagent run going
|
||||
interleaved.push({ idx: items.length + interleaved.length, msg: m });
|
||||
i++;
|
||||
}
|
||||
|
||||
// Emit interleaved messages first (before the bubble)
|
||||
for (const { msg: im } of interleaved) {
|
||||
items.push({ kind: "message", msg: im });
|
||||
}
|
||||
|
||||
// Build the single parallel bubble from all collected subagent msgs
|
||||
if (subagentMsgs.length > 0) {
|
||||
const byNode = new Map<string, ChatMessage[]>();
|
||||
for (const m of subagentMsgs) {
|
||||
const nid = m.nodeId!;
|
||||
if (!byNode.has(nid)) byNode.set(nid, []);
|
||||
byNode.get(nid)!.push(m);
|
||||
}
|
||||
const groups: SubagentGroup[] = [];
|
||||
for (const [nodeId, msgs] of byNode) {
|
||||
groups.push({
|
||||
nodeId,
|
||||
messages: msgs,
|
||||
contextUsage: contextUsage?.[nodeId],
|
||||
});
|
||||
}
|
||||
items.push({ kind: "parallel", groupId: `par-${firstId}`, groups });
|
||||
}
|
||||
}
|
||||
return items;
|
||||
}, [threadMessages, contextUsage]);
|
||||
|
||||
// Mark current thread as read
|
||||
useEffect(() => {
|
||||
const count = messages.filter((m) => m.thread === activeThread).length;
|
||||
@@ -284,26 +477,64 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
|
||||
|
||||
const handleSubmit = (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
if (!input.trim()) return;
|
||||
onSend(input.trim(), activeThread);
|
||||
if (!input.trim() && pendingImages.length === 0) return;
|
||||
onSend(
|
||||
input.trim(),
|
||||
activeThread,
|
||||
pendingImages.length > 0 ? pendingImages : undefined,
|
||||
);
|
||||
setInput("");
|
||||
setPendingImages([]);
|
||||
if (textareaRef.current) textareaRef.current.style.height = "auto";
|
||||
};
|
||||
|
||||
const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const files = Array.from(e.target.files ?? []);
|
||||
if (files.length === 0) return;
|
||||
files.forEach((file) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (ev) => {
|
||||
const url = ev.target?.result as string;
|
||||
setPendingImages((prev) => [
|
||||
...prev,
|
||||
{ type: "image_url", image_url: { url } },
|
||||
]);
|
||||
};
|
||||
reader.readAsDataURL(file);
|
||||
});
|
||||
// Reset so the same file can be re-selected
|
||||
e.target.value = "";
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full min-w-0">
|
||||
{/* Compact sub-header */}
|
||||
<div className="px-5 pt-4 pb-2 flex items-center gap-2">
|
||||
<p className="text-[11px] text-muted-foreground font-medium uppercase tracking-wider">Conversation</p>
|
||||
<p className="text-[11px] text-muted-foreground font-medium uppercase tracking-wider">
|
||||
Conversation
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Messages */}
|
||||
<div ref={scrollRef} onScroll={handleScroll} className="flex-1 overflow-auto px-5 py-4 space-y-3">
|
||||
{threadMessages.map((msg) => (
|
||||
<div key={msg.id}>
|
||||
<MessageBubble msg={msg} queenPhase={queenPhase} />
|
||||
</div>
|
||||
))}
|
||||
<div
|
||||
ref={scrollRef}
|
||||
onScroll={handleScroll}
|
||||
className="flex-1 overflow-auto px-5 py-4 space-y-3"
|
||||
>
|
||||
{renderItems.map((item) =>
|
||||
item.kind === "parallel" ? (
|
||||
<div key={item.groupId}>
|
||||
<ParallelSubagentBubble
|
||||
groupId={item.groupId}
|
||||
groups={item.groups}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<div key={item.msg.id}>
|
||||
<MessageBubble msg={item.msg} queenPhase={queenPhase} />
|
||||
</div>
|
||||
),
|
||||
)}
|
||||
|
||||
{/* Show typing indicator while waiting for first queen response (disabled + empty chat) */}
|
||||
{(isWaiting || (disabled && threadMessages.length === 0)) && (
|
||||
@@ -320,9 +551,18 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
|
||||
</div>
|
||||
<div className="border border-primary/20 bg-primary/5 rounded-2xl rounded-tl-md px-4 py-3">
|
||||
<div className="flex gap-1.5">
|
||||
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "0ms" }} />
|
||||
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "150ms" }} />
|
||||
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "300ms" }} />
|
||||
<span
|
||||
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
|
||||
style={{ animationDelay: "0ms" }}
|
||||
/>
|
||||
<span
|
||||
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
|
||||
style={{ animationDelay: "150ms" }}
|
||||
/>
|
||||
<span
|
||||
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
|
||||
style={{ animationDelay: "300ms" }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -340,9 +580,18 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
|
||||
</div>
|
||||
<div className="bg-muted/60 rounded-2xl rounded-tl-md px-4 py-3">
|
||||
<div className="flex gap-1.5">
|
||||
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "0ms" }} />
|
||||
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "150ms" }} />
|
||||
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "300ms" }} />
|
||||
<span
|
||||
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
|
||||
style={{ animationDelay: "0ms" }}
|
||||
/>
|
||||
<span
|
||||
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
|
||||
style={{ animationDelay: "150ms" }}
|
||||
/>
|
||||
<span
|
||||
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
|
||||
style={{ animationDelay: "300ms" }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -350,8 +599,99 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
|
||||
<div ref={bottomRef} />
|
||||
</div>
|
||||
|
||||
{/* Context window usage bar — sits between messages and input */}
|
||||
{(() => {
|
||||
if (!contextUsage) return null;
|
||||
const queenUsage = contextUsage["__queen__"];
|
||||
const workerEntries = Object.entries(contextUsage).filter(
|
||||
([k]) => k !== "__queen__",
|
||||
);
|
||||
const workerUsage =
|
||||
workerEntries.length > 0
|
||||
? workerEntries.reduce(
|
||||
(best, [, v]) => (v.usagePct > best.usagePct ? v : best),
|
||||
workerEntries[0][1],
|
||||
)
|
||||
: undefined;
|
||||
if (!queenUsage && !workerUsage) return null;
|
||||
return (
|
||||
<div className="flex items-center gap-3 mx-4 px-3 py-1 rounded-lg bg-muted/30 border border-border/20 group/ctx flex-shrink-0">
|
||||
{queenUsage && (
|
||||
<div
|
||||
className="flex items-center gap-2 flex-1 min-w-0"
|
||||
title={`Queen: ${(queenUsage.estimatedTokens / 1000).toFixed(1)}k / ${(queenUsage.maxTokens / 1000).toFixed(0)}k tokens \u00b7 ${queenUsage.messageCount} messages`}
|
||||
>
|
||||
<Crown
|
||||
className="w-3 h-3 flex-shrink-0"
|
||||
style={{ color: "hsl(45,95%,58%)" }}
|
||||
/>
|
||||
<div className="flex-1 h-1.5 rounded-full bg-muted/50 overflow-hidden min-w-[60px]">
|
||||
<div
|
||||
className="h-full rounded-full transition-all duration-500 ease-out"
|
||||
style={{
|
||||
width: `${Math.min(queenUsage.usagePct, 100)}%`,
|
||||
backgroundColor:
|
||||
queenUsage.usagePct >= 90
|
||||
? "hsl(0,65%,55%)"
|
||||
: queenUsage.usagePct >= 70
|
||||
? "hsl(35,90%,55%)"
|
||||
: "hsl(45,95%,58%)",
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<span className="text-[10px] text-muted-foreground/70 flex-shrink-0 tabular-nums">
|
||||
<span className="group-hover/ctx:hidden">
|
||||
{queenUsage.usagePct}%
|
||||
</span>
|
||||
<span className="hidden group-hover/ctx:inline">
|
||||
{(queenUsage.estimatedTokens / 1000).toFixed(1)}k /{" "}
|
||||
{(queenUsage.maxTokens / 1000).toFixed(0)}k
|
||||
</span>
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
{workerUsage && (
|
||||
<div
|
||||
className="flex items-center gap-2 flex-1 min-w-0"
|
||||
title={`Worker: ${(workerUsage.estimatedTokens / 1000).toFixed(1)}k / ${(workerUsage.maxTokens / 1000).toFixed(0)}k tokens \u00b7 ${workerUsage.messageCount} messages`}
|
||||
>
|
||||
<Cpu
|
||||
className="w-3 h-3 flex-shrink-0"
|
||||
style={{ color: "hsl(220,60%,55%)" }}
|
||||
/>
|
||||
<div className="flex-1 h-1.5 rounded-full bg-muted/50 overflow-hidden min-w-[60px]">
|
||||
<div
|
||||
className="h-full rounded-full transition-all duration-500 ease-out"
|
||||
style={{
|
||||
width: `${Math.min(workerUsage.usagePct, 100)}%`,
|
||||
backgroundColor:
|
||||
workerUsage.usagePct >= 90
|
||||
? "hsl(0,65%,55%)"
|
||||
: workerUsage.usagePct >= 70
|
||||
? "hsl(35,90%,55%)"
|
||||
: "hsl(220,60%,55%)",
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<span className="text-[10px] text-muted-foreground/70 flex-shrink-0 tabular-nums">
|
||||
<span className="group-hover/ctx:hidden">
|
||||
{workerUsage.usagePct}%
|
||||
</span>
|
||||
<span className="hidden group-hover/ctx:inline">
|
||||
{(workerUsage.estimatedTokens / 1000).toFixed(1)}k /{" "}
|
||||
{(workerUsage.maxTokens / 1000).toFixed(0)}k
|
||||
</span>
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})()}
|
||||
|
||||
{/* Input area — question widget replaces textarea when a question is pending */}
|
||||
{pendingQuestions && pendingQuestions.length >= 2 && onMultiQuestionSubmit ? (
|
||||
{pendingQuestions &&
|
||||
pendingQuestions.length >= 2 &&
|
||||
onMultiQuestionSubmit ? (
|
||||
<MultiQuestionWidget
|
||||
questions={pendingQuestions}
|
||||
onSubmit={onMultiQuestionSubmit}
|
||||
@@ -366,7 +706,47 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
|
||||
/>
|
||||
) : (
|
||||
<form onSubmit={handleSubmit} className="p-4">
|
||||
{/* Image preview strip */}
|
||||
{pendingImages.length > 0 && (
|
||||
<div className="flex flex-wrap gap-2 mb-2 px-1">
|
||||
{pendingImages.map((img, i) => (
|
||||
<div key={i} className="relative group">
|
||||
<img
|
||||
src={img.image_url.url}
|
||||
alt={`preview ${i + 1}`}
|
||||
className="h-16 w-16 object-cover rounded-lg border border-border"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setPendingImages((prev) => prev.filter((_, j) => j !== i))
|
||||
}
|
||||
className="absolute -top-1.5 -right-1.5 w-4 h-4 rounded-full bg-destructive text-destructive-foreground flex items-center justify-center opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
>
|
||||
<X className="w-2.5 h-2.5" />
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
<div className="flex items-center gap-3 bg-muted/40 rounded-xl px-4 py-2.5 border border-border focus-within:border-primary/40 transition-colors">
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
accept="image/*"
|
||||
multiple
|
||||
className="hidden"
|
||||
onChange={handleFileChange}
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
disabled={disabled || !supportsImages}
|
||||
onClick={() => supportsImages && fileInputRef.current?.click()}
|
||||
className="flex-shrink-0 p-1 rounded-md text-muted-foreground hover:text-foreground disabled:opacity-30 transition-colors"
|
||||
title={supportsImages ? "Attach image" : "Image not supported by the current model"}
|
||||
>
|
||||
<Paperclip className="w-4 h-4" />
|
||||
</button>
|
||||
<textarea
|
||||
ref={textareaRef}
|
||||
rows={1}
|
||||
@@ -383,7 +763,9 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
|
||||
handleSubmit(e);
|
||||
}
|
||||
}}
|
||||
placeholder={disabled ? "Connecting to agent..." : "Message Queen Bee..."}
|
||||
placeholder={
|
||||
disabled ? "Connecting to agent..." : "Message Queen Bee..."
|
||||
}
|
||||
disabled={disabled}
|
||||
className="flex-1 bg-transparent text-sm text-foreground outline-none placeholder:text-muted-foreground disabled:opacity-50 disabled:cursor-not-allowed resize-none overflow-y-auto"
|
||||
/>
|
||||
@@ -398,7 +780,9 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
|
||||
) : (
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!input.trim() || disabled}
|
||||
disabled={
|
||||
(!input.trim() && pendingImages.length === 0) || disabled
|
||||
}
|
||||
className="p-2 rounded-lg bg-primary text-primary-foreground disabled:opacity-30 hover:opacity-90 transition-opacity"
|
||||
>
|
||||
<Send className="w-4 h-4" />
|
||||
|
||||
@@ -3,11 +3,23 @@ import { Loader2 } from "lucide-react";
|
||||
import type { DraftGraph as DraftGraphData, DraftNode } from "@/api/types";
|
||||
import { RunButton } from "./RunButton";
|
||||
import type { GraphNode, RunState } from "./graph-types";
|
||||
import {
|
||||
cssVar,
|
||||
truncateLabel,
|
||||
TRIGGER_ICONS,
|
||||
ACTIVE_TRIGGER_COLORS,
|
||||
useTriggerColors,
|
||||
} from "@/lib/graphUtils";
|
||||
|
||||
// Read a CSS custom property value (space-separated HSL components)
|
||||
function cssVar(name: string): string {
|
||||
return getComputedStyle(document.documentElement).getPropertyValue(name).trim();
|
||||
}
|
||||
// ── Trigger layout constants ──
|
||||
const TRIGGER_H = 38; // pill height
|
||||
const TRIGGER_PILL_GAP_X = 16; // horizontal gap between multiple trigger pills
|
||||
const TRIGGER_ICON_X = 16; // icon center offset from pill left edge
|
||||
const TRIGGER_LABEL_X = 30; // label start offset from pill left edge
|
||||
const TRIGGER_LABEL_INSET = 38; // icon + padding subtracted from pill width for label space
|
||||
const TRIGGER_TEXT_Y = 11; // y-offset below pill for first text line (countdown or status)
|
||||
const TRIGGER_TEXT_STEP = 11; // additional y-offset for second text line when countdown present
|
||||
const TRIGGER_CLEARANCE = 30; // vertical space below pill for countdown + status text
|
||||
|
||||
interface DraftChromeColors {
|
||||
edge: string;
|
||||
@@ -107,13 +119,6 @@ function formatNodeId(id: string): string {
|
||||
return id.split("-").map(w => w.charAt(0).toUpperCase() + w.slice(1)).join(" ");
|
||||
}
|
||||
|
||||
function truncateLabel(label: string, availablePx: number, fontSize: number): string {
|
||||
const avgCharW = fontSize * 0.58;
|
||||
const maxChars = Math.floor(availablePx / avgCharW);
|
||||
if (label.length <= maxChars) return label;
|
||||
return label.slice(0, Math.max(maxChars - 1, 1)) + "\u2026";
|
||||
}
|
||||
|
||||
/** Return the bounding-rect corner radius for a given flowchart shape. */
|
||||
/**
|
||||
* Render an ISO 5807 flowchart shape as an SVG element.
|
||||
@@ -240,6 +245,13 @@ export default function DraftGraph({ draft, originalDraft, onNodeClick, flowchar
|
||||
const runBtnRef = useRef<HTMLButtonElement>(null);
|
||||
const [containerW, setContainerW] = useState(484);
|
||||
const chrome = useDraftChromeColors();
|
||||
const triggerColors = useTriggerColors();
|
||||
|
||||
// Extract trigger nodes from runtimeNodes
|
||||
const triggerNodes = useMemo(
|
||||
() => (runtimeNodes ?? []).filter(n => n.nodeType === "trigger"),
|
||||
[runtimeNodes],
|
||||
);
|
||||
|
||||
// ── Entrance animation — fires when originalDraft becomes a new non-null value ──
|
||||
// This covers: agent loaded, build finished, queen modifies flowchart.
|
||||
@@ -709,12 +721,17 @@ export default function DraftGraph({ draft, originalDraft, onNodeClick, flowchar
|
||||
return { nodeYOffset: offsets, totalExtraY: totalExtra, groupBoxMaxX: maxGroupX };
|
||||
}, [nodes, maxLayer, flowchartMap, idxMap, layers, nodeXPositions, nodeW]);
|
||||
|
||||
// When triggers are present, push the entire draft graph down to make room
|
||||
const triggerOffsetY = triggerNodes.length > 0
|
||||
? TRIGGER_H + TRIGGER_TEXT_Y + TRIGGER_TEXT_STEP + TRIGGER_CLEARANCE
|
||||
: 0;
|
||||
|
||||
const nodePos = (i: number) => ({
|
||||
x: nodeXPositions[i],
|
||||
y: TOP_Y + layers[i] * (NODE_H + GAP_Y) + nodeYOffset[i],
|
||||
y: TOP_Y + triggerOffsetY + layers[i] * (NODE_H + GAP_Y) + nodeYOffset[i],
|
||||
});
|
||||
|
||||
const svgHeight = TOP_Y + (maxLayer + 1) * NODE_H + maxLayer * GAP_Y + totalExtraY + 16;
|
||||
const svgHeight = TOP_Y + triggerOffsetY + (maxLayer + 1) * NODE_H + maxLayer * GAP_Y + totalExtraY + 16;
|
||||
|
||||
// Compute group areas for runtime node boundaries on the draft
|
||||
const groupAreas = useMemo(() => {
|
||||
@@ -847,6 +864,131 @@ export default function DraftGraph({ draft, originalDraft, onNodeClick, flowchar
|
||||
pending: "",
|
||||
};
|
||||
|
||||
// ── Trigger node rendering ──
|
||||
|
||||
const triggerW = Math.min(nodeW, 180);
|
||||
|
||||
// Shared trigger pill X position (used by both node and edge renderers)
|
||||
const triggerPillX = (idx: number) => {
|
||||
const totalW = triggerNodes.length * triggerW + (triggerNodes.length - 1) * TRIGGER_PILL_GAP_X;
|
||||
return (containerW - totalW) / 2 + idx * (triggerW + TRIGGER_PILL_GAP_X);
|
||||
};
|
||||
|
||||
const renderTriggerNode = (node: GraphNode, triggerIdx: number) => {
|
||||
const icon = TRIGGER_ICONS[node.triggerType || ""] || "\u26A1";
|
||||
const isActive = node.status === "running" || node.status === "complete";
|
||||
const colors = isActive ? ACTIVE_TRIGGER_COLORS : triggerColors;
|
||||
const nextFireIn = node.triggerConfig?.next_fire_in as number | undefined;
|
||||
|
||||
const tx = triggerPillX(triggerIdx);
|
||||
const ty = TOP_Y;
|
||||
|
||||
const fontSize = triggerW < 140 ? 10.5 : 11.5;
|
||||
const displayLabel = truncateLabel(node.label, triggerW - TRIGGER_LABEL_INSET, fontSize);
|
||||
|
||||
// Countdown
|
||||
let countdownLabel: string | null = null;
|
||||
if (isActive && nextFireIn != null && nextFireIn > 0) {
|
||||
const h = Math.floor(nextFireIn / 3600);
|
||||
const m = Math.floor((nextFireIn % 3600) / 60);
|
||||
const s = Math.floor(nextFireIn % 60);
|
||||
countdownLabel = h > 0
|
||||
? `next in ${h}h ${String(m).padStart(2, "0")}m`
|
||||
: `next in ${m}m ${String(s).padStart(2, "0")}s`;
|
||||
}
|
||||
|
||||
const statusLabel = isActive ? "active" : "inactive";
|
||||
const statusColor = isActive ? "hsl(140,40%,50%)" : "hsl(210,20%,40%)";
|
||||
|
||||
return (
|
||||
<g
|
||||
key={node.id}
|
||||
onClick={() => onRuntimeNodeClick?.(node.id)}
|
||||
style={{ cursor: onRuntimeNodeClick ? "pointer" : "default" }}
|
||||
>
|
||||
<title>{node.label}</title>
|
||||
{/* Pill-shaped background */}
|
||||
<rect
|
||||
x={tx} y={ty}
|
||||
width={triggerW} height={TRIGGER_H}
|
||||
rx={TRIGGER_H / 2}
|
||||
fill={colors.bg}
|
||||
stroke={colors.border}
|
||||
strokeWidth={isActive ? 1.5 : 1}
|
||||
strokeDasharray={isActive ? undefined : "4 2"}
|
||||
/>
|
||||
{/* Icon */}
|
||||
<text
|
||||
x={tx + TRIGGER_ICON_X} y={ty + TRIGGER_H / 2}
|
||||
fill={colors.icon} fontSize={13}
|
||||
textAnchor="middle" dominantBaseline="middle"
|
||||
>
|
||||
{icon}
|
||||
</text>
|
||||
{/* Label */}
|
||||
<text
|
||||
x={tx + TRIGGER_LABEL_X} y={ty + TRIGGER_H / 2}
|
||||
fill={colors.text}
|
||||
fontSize={fontSize}
|
||||
fontWeight={500}
|
||||
dominantBaseline="middle"
|
||||
letterSpacing="0.01em"
|
||||
>
|
||||
{displayLabel}
|
||||
</text>
|
||||
{/* Countdown */}
|
||||
{countdownLabel && (
|
||||
<text
|
||||
x={tx + triggerW / 2} y={ty + TRIGGER_H + TRIGGER_TEXT_Y}
|
||||
fill={colors.text} fontSize={9}
|
||||
textAnchor="middle" fontStyle="italic" opacity={0.7}
|
||||
>
|
||||
{countdownLabel}
|
||||
</text>
|
||||
)}
|
||||
{/* Status */}
|
||||
<text
|
||||
x={tx + triggerW / 2} y={ty + TRIGGER_H + (countdownLabel ? TRIGGER_TEXT_Y + TRIGGER_TEXT_STEP : TRIGGER_TEXT_Y)}
|
||||
fill={statusColor} fontSize={8.5}
|
||||
textAnchor="middle" opacity={0.8}
|
||||
>
|
||||
{statusLabel}
|
||||
</text>
|
||||
</g>
|
||||
);
|
||||
};
|
||||
|
||||
const renderTriggerEdge = (triggerIdx: number) => {
|
||||
if (nodes.length === 0) return null;
|
||||
const triggerNode = triggerNodes[triggerIdx];
|
||||
const runtimeTargetId = triggerNode?.next?.[0];
|
||||
const targetDraftId = runtimeTargetId
|
||||
? flowchartMap?.[runtimeTargetId]?.[0] ?? runtimeTargetId
|
||||
: draft?.entry_node;
|
||||
const targetIdx = targetDraftId ? idxMap[targetDraftId] ?? 0 : 0;
|
||||
const targetPos = nodePos(targetIdx);
|
||||
const targetX = targetPos.x + nodeW / 2;
|
||||
const targetY = targetPos.y;
|
||||
|
||||
const tx = triggerPillX(triggerIdx) + triggerW / 2;
|
||||
const ty = TOP_Y + TRIGGER_H + TRIGGER_TEXT_Y + TRIGGER_TEXT_STEP + 4;
|
||||
|
||||
const midY = (ty + targetY) / 2;
|
||||
const d = Math.abs(tx - targetX) < 2
|
||||
? `M ${tx} ${ty} L ${targetX} ${targetY}`
|
||||
: `M ${tx} ${ty} L ${tx} ${midY} L ${targetX} ${midY} L ${targetX} ${targetY}`;
|
||||
|
||||
return (
|
||||
<g key={`trigger-edge-${triggerIdx}`}>
|
||||
<path d={d} fill="none" stroke={chrome.edge} strokeWidth={1.2} strokeDasharray="4 3" />
|
||||
<polygon
|
||||
points={`${targetX - 3},${targetY - 5} ${targetX + 3},${targetY - 5} ${targetX},${targetY - 1}`}
|
||||
fill={chrome.edgeArrow}
|
||||
/>
|
||||
</g>
|
||||
);
|
||||
};
|
||||
|
||||
const renderNode = (node: DraftNode, i: number) => {
|
||||
const pos = nodePos(i);
|
||||
const isHovered = hoveredNode === node.id;
|
||||
@@ -994,7 +1136,7 @@ export default function DraftGraph({ draft, originalDraft, onNodeClick, flowchar
|
||||
>
|
||||
<svg
|
||||
width="100%"
|
||||
viewBox={`0 0 ${Math.max((maxContentRight ?? 0), groupBoxMaxX) + (backEdgeOverflow ?? 0)} ${totalH}`}
|
||||
viewBox={`0 0 ${Math.max((maxContentRight ?? 0), groupBoxMaxX, triggerNodes.length > 0 ? triggerPillX(triggerNodes.length - 1) + triggerW : 0) + (backEdgeOverflow ?? 0)} ${totalH}`}
|
||||
preserveAspectRatio="xMidYMin meet"
|
||||
className="select-none"
|
||||
style={{
|
||||
@@ -1078,6 +1220,11 @@ export default function DraftGraph({ draft, originalDraft, onNodeClick, flowchar
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Trigger edges (dashed lines from trigger pills to first draft node) */}
|
||||
{triggerNodes.map((_, i) => renderTriggerEdge(i))}
|
||||
{/* Trigger pill nodes */}
|
||||
{triggerNodes.map((tn, i) => renderTriggerNode(tn, i))}
|
||||
|
||||
{forwardEdges.map((e, i) => renderEdge(e, i))}
|
||||
{backEdges.map((e, i) => renderBackEdge(e, i))}
|
||||
{nodes.map((n, i) => renderNode(n, i))}
|
||||
|
||||
@@ -28,6 +28,13 @@ export interface SubagentReport {
|
||||
status?: "running" | "complete" | "error";
|
||||
}
|
||||
|
||||
interface ContextUsage {
|
||||
usagePct: number;
|
||||
messageCount: number;
|
||||
estimatedTokens: number;
|
||||
maxTokens: number;
|
||||
}
|
||||
|
||||
interface NodeDetailPanelProps {
|
||||
node: GraphNode | null;
|
||||
nodeSpec?: NodeSpec | null;
|
||||
@@ -38,6 +45,7 @@ interface NodeDetailPanelProps {
|
||||
workerSessionId?: string | null;
|
||||
nodeLogs?: string[];
|
||||
actionPlan?: string;
|
||||
contextUsage?: ContextUsage;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
@@ -309,7 +317,7 @@ const tabs: { id: Tab; label: string; Icon: React.FC<{ className?: string }> }[]
|
||||
{ id: "subagents", label: "Subagents", Icon: ({ className }) => <Bot className={className} /> },
|
||||
];
|
||||
|
||||
export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagentReports, sessionId, graphId, workerSessionId, nodeLogs, actionPlan, onClose }: NodeDetailPanelProps) {
|
||||
export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagentReports, sessionId, graphId, workerSessionId, nodeLogs, actionPlan, contextUsage, onClose }: NodeDetailPanelProps) {
|
||||
const [activeTab, setActiveTab] = useState<Tab>("overview");
|
||||
const [realTools, setRealTools] = useState<ToolInfo[] | null>(null);
|
||||
const [realCriteria, setRealCriteria] = useState<NodeCriteria | null>(null);
|
||||
@@ -389,6 +397,43 @@ export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagent
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Context window usage */}
|
||||
{contextUsage && (
|
||||
<div className="px-4 py-2 border-b border-border/20 flex-shrink-0">
|
||||
<div className="flex items-center gap-2 mb-1">
|
||||
<span className="text-[10px] text-muted-foreground font-medium">Context</span>
|
||||
<span className="text-[10px] text-muted-foreground/70 ml-auto">
|
||||
{(contextUsage.estimatedTokens / 1000).toFixed(1)}k / {(contextUsage.maxTokens / 1000).toFixed(0)}k tokens
|
||||
</span>
|
||||
</div>
|
||||
<div className="w-full h-1.5 rounded-full bg-muted/50 overflow-hidden">
|
||||
<div
|
||||
className="h-full rounded-full transition-all duration-500 ease-out"
|
||||
style={{
|
||||
width: `${Math.min(contextUsage.usagePct, 100)}%`,
|
||||
backgroundColor: contextUsage.usagePct >= 90
|
||||
? "hsl(0,65%,55%)"
|
||||
: contextUsage.usagePct >= 70
|
||||
? "hsl(35,90%,55%)"
|
||||
: "hsl(45,95%,58%)",
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center gap-2 mt-1">
|
||||
<span className="text-[10px] text-muted-foreground/60">{contextUsage.messageCount} messages</span>
|
||||
<span className="text-[10px] font-medium ml-auto" style={{
|
||||
color: contextUsage.usagePct >= 90
|
||||
? "hsl(0,65%,55%)"
|
||||
: contextUsage.usagePct >= 70
|
||||
? "hsl(35,90%,55%)"
|
||||
: "hsl(45,95%,58%)",
|
||||
}}>
|
||||
{contextUsage.usagePct}%
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Tab bar */}
|
||||
<div className="flex border-b border-border/30 flex-shrink-0 px-2 pt-1 overflow-x-auto scrollbar-hide">
|
||||
{tabs.filter(t => t.id !== "subagents" || (nodeSpec?.sub_agents && nodeSpec.sub_agents.length > 0)).map(tab => (
|
||||
|
||||
@@ -0,0 +1,413 @@
|
||||
import { memo, useState, useRef, useEffect } from "react";
|
||||
import { ChevronDown, ChevronUp, Cpu } from "lucide-react";
|
||||
import type { ChatMessage, ContextUsageEntry } from "@/components/ChatPanel";
|
||||
import MarkdownContent from "@/components/MarkdownContent";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const workerColor = "hsl(220,60%,55%)";
|
||||
|
||||
const SUBAGENT_COLORS = [
|
||||
"hsl(220,60%,55%)",
|
||||
"hsl(260,50%,55%)",
|
||||
"hsl(180,50%,45%)",
|
||||
"hsl(30,70%,50%)",
|
||||
"hsl(340,55%,50%)",
|
||||
"hsl(150,45%,45%)",
|
||||
"hsl(45,80%,50%)",
|
||||
"hsl(290,45%,55%)",
|
||||
];
|
||||
|
||||
function colorForIndex(i: number): string {
|
||||
return SUBAGENT_COLORS[i % SUBAGENT_COLORS.length];
|
||||
}
|
||||
|
||||
function subagentLabel(nodeId: string): string {
|
||||
const parts = nodeId.split(":subagent:");
|
||||
const raw = parts.length >= 2 ? parts[1] : nodeId;
|
||||
return raw
|
||||
.replace(/:\d+$/, "") // strip instance suffix like ":3"
|
||||
.replace(/[_-]/g, " ")
|
||||
.replace(/\b\w/g, (c) => c.toUpperCase())
|
||||
.trim();
|
||||
}
|
||||
|
||||
function last<T>(arr: T[]): T | undefined {
|
||||
return arr[arr.length - 1];
|
||||
}
|
||||
|
||||
export interface SubagentGroup {
|
||||
nodeId: string;
|
||||
messages: ChatMessage[];
|
||||
contextUsage?: ContextUsageEntry;
|
||||
}
|
||||
|
||||
interface ParallelSubagentBubbleProps {
|
||||
groups: SubagentGroup[];
|
||||
groupId: string;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Thermometer — vertical context gauge on right edge of each pane
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tool overlay — shown when a tool_status message is active (not all done)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function ToolOverlay({
|
||||
toolName,
|
||||
color,
|
||||
visible,
|
||||
}: {
|
||||
toolName: string;
|
||||
color: string;
|
||||
visible: boolean;
|
||||
}) {
|
||||
return (
|
||||
<div
|
||||
className="absolute inset-0 top-[22px] flex items-center justify-center transition-opacity duration-200 z-10"
|
||||
style={{
|
||||
background: "rgba(8,8,14,0.82)",
|
||||
opacity: visible ? 1 : 0,
|
||||
pointerEvents: visible ? "auto" : "none",
|
||||
}}
|
||||
>
|
||||
<div className="text-center px-3 py-2 rounded-md border" style={{ borderColor: `${color}40` }}>
|
||||
<div className="text-[10px] font-medium" style={{ color }}>
|
||||
{toolName}
|
||||
</div>
|
||||
<div className="text-[11px] mt-0.5" style={{ color }}>
|
||||
{visible ? "..." : "\u2713"}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Single tmux pane
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function MuxPane({
|
||||
group,
|
||||
index,
|
||||
label,
|
||||
isFocused,
|
||||
isZoomed,
|
||||
onClickTitle,
|
||||
}: {
|
||||
group: SubagentGroup;
|
||||
index: number;
|
||||
label: string;
|
||||
isFocused: boolean;
|
||||
isZoomed: boolean;
|
||||
onClickTitle: () => void;
|
||||
}) {
|
||||
const bodyRef = useRef<HTMLDivElement>(null);
|
||||
const stickRef = useRef(true);
|
||||
const color = colorForIndex(index);
|
||||
const pct = group.contextUsage?.usagePct ?? 0;
|
||||
|
||||
const streamMsgs = group.messages.filter((m) => m.type !== "tool_status");
|
||||
const latestContent = last(streamMsgs)?.content ?? "";
|
||||
const msgCount = streamMsgs.length;
|
||||
|
||||
// Detect active tool and finished state from latest tool_status
|
||||
const latestTool = last(
|
||||
group.messages.filter((m) => m.type === "tool_status")
|
||||
);
|
||||
let activeToolName = "";
|
||||
let toolRunning = false;
|
||||
let isFinished = false;
|
||||
if (latestTool) {
|
||||
try {
|
||||
const parsed = JSON.parse(latestTool.content);
|
||||
const tools: { name: string; done: boolean }[] = parsed.tools || [];
|
||||
const allDone = parsed.allDone as boolean | undefined;
|
||||
const running = tools.find((t) => !t.done);
|
||||
if (running) {
|
||||
activeToolName = running.name;
|
||||
toolRunning = true;
|
||||
}
|
||||
// Finished when all tools are done and one of them is set_output
|
||||
// or report_to_parent (terminal tool calls)
|
||||
if (allDone && tools.length > 0) {
|
||||
const hasTerminal = tools.some(
|
||||
(t) =>
|
||||
t.done &&
|
||||
(t.name === "set_output" || t.name === "report_to_parent")
|
||||
);
|
||||
if (hasTerminal) isFinished = true;
|
||||
}
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-scroll
|
||||
useEffect(() => {
|
||||
if (stickRef.current && bodyRef.current) {
|
||||
bodyRef.current.scrollTop = bodyRef.current.scrollHeight;
|
||||
}
|
||||
}, [latestContent]);
|
||||
|
||||
const handleScroll = () => {
|
||||
const el = bodyRef.current;
|
||||
if (!el) return;
|
||||
stickRef.current = el.scrollHeight - el.scrollTop - el.clientHeight < 30;
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex flex-col min-h-0 overflow-hidden relative transition-all duration-200"
|
||||
style={{
|
||||
borderWidth: 1,
|
||||
borderStyle: "solid",
|
||||
borderColor: isFocused && !isFinished ? `${color}60` : "transparent",
|
||||
opacity: isFinished ? 0.4 : isFocused || isZoomed ? 1 : 0.55,
|
||||
...(isZoomed
|
||||
? { gridColumn: "1 / -1", gridRow: "1 / -1", zIndex: 10 }
|
||||
: {}),
|
||||
}}
|
||||
>
|
||||
{/* Title bar */}
|
||||
<div
|
||||
className="flex items-center gap-1.5 px-2 py-[3px] flex-shrink-0 cursor-pointer select-none"
|
||||
style={{ background: "#0e0e16", borderBottom: "1px solid #1a1a2a" }}
|
||||
onClick={onClickTitle}
|
||||
>
|
||||
{isFinished ? (
|
||||
<span className="text-[8px] flex-shrink-0 leading-none" style={{ color: "#4a4" }}>✓</span>
|
||||
) : (
|
||||
<div
|
||||
className="w-[6px] h-[6px] rounded-full flex-shrink-0"
|
||||
style={{ background: color }}
|
||||
/>
|
||||
)}
|
||||
<span className="text-[9px] flex-shrink-0" style={{ color: isFinished ? "#555" : color }}>
|
||||
{label}
|
||||
</span>
|
||||
<span className="flex-1" />
|
||||
<span className="text-[8px] tabular-nums flex-shrink-0" style={{ color: "#555" }}>
|
||||
{msgCount}
|
||||
</span>
|
||||
<div
|
||||
className="w-[36px] h-[3px] rounded-full overflow-hidden flex-shrink-0"
|
||||
style={{ background: "#1a1a2a" }}
|
||||
>
|
||||
<div
|
||||
className="h-full rounded-full transition-all duration-500"
|
||||
style={{
|
||||
width: `${Math.min(pct, 100)}%`,
|
||||
backgroundColor:
|
||||
pct >= 80 ? "hsl(0,65%,55%)" : pct >= 50 ? "hsl(35,90%,55%)" : color,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<span className="text-[8px] tabular-nums flex-shrink-0" style={{ color: "#555" }}>
|
||||
{pct}%
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{/* Body */}
|
||||
<div
|
||||
ref={bodyRef}
|
||||
onScroll={handleScroll}
|
||||
className="flex-1 min-h-0 overflow-y-auto px-2 py-1 text-[10px] leading-[1.7]"
|
||||
style={{ background: "#08080e", color: "#555", fontFamily: "monospace" }}
|
||||
>
|
||||
{latestContent ? (
|
||||
<div style={{ color: "#ccc" }}>
|
||||
<MarkdownContent content={latestContent} />
|
||||
</div>
|
||||
) : (
|
||||
<span style={{ color: "#333" }}>waiting...</span>
|
||||
)}
|
||||
{/* Blinking cursor — hidden when finished */}
|
||||
{!isFinished && (
|
||||
<span
|
||||
className="inline-block w-[6px] h-[11px] align-middle ml-0.5"
|
||||
style={{
|
||||
background: color,
|
||||
animation: "cursorBlink 1s step-end infinite",
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Tool overlay */}
|
||||
<ToolOverlay
|
||||
toolName={activeToolName}
|
||||
color={color}
|
||||
visible={toolRunning}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Main component
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const ParallelSubagentBubble = memo(
|
||||
function ParallelSubagentBubble({ groups }: ParallelSubagentBubbleProps) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const [zoomedIdx, setZoomedIdx] = useState<number | null>(null);
|
||||
|
||||
// Labels with instance numbers for duplicates
|
||||
const labels: string[] = (() => {
|
||||
const countByBase = new Map<string, number>();
|
||||
const bases = groups.map((g) => subagentLabel(g.nodeId));
|
||||
for (const b of bases)
|
||||
countByBase.set(b, (countByBase.get(b) ?? 0) + 1);
|
||||
const idxByBase = new Map<string, number>();
|
||||
return bases.map((b) => {
|
||||
if ((countByBase.get(b) ?? 1) <= 1) return b;
|
||||
const idx = (idxByBase.get(b) ?? 0) + 1;
|
||||
idxByBase.set(b, idx);
|
||||
return `${b} #${idx}`;
|
||||
});
|
||||
})();
|
||||
|
||||
// Latest-active pane
|
||||
const latestIdx = groups.reduce<number>((best, g, i) => {
|
||||
const filtered = g.messages.filter((m) => m.type !== "tool_status");
|
||||
const lm = last(filtered);
|
||||
if (!lm) return best;
|
||||
if (best < 0) return i;
|
||||
const bm = last(
|
||||
groups[best].messages.filter((m) => m.type !== "tool_status")
|
||||
);
|
||||
if (!bm) return i;
|
||||
return (lm.createdAt ?? 0) >= (bm.createdAt ?? 0) ? i : best;
|
||||
}, -1);
|
||||
|
||||
// Per-group finished detection (same logic as MuxPane)
|
||||
const finishedFlags = groups.map((g) => {
|
||||
const lt = last(g.messages.filter((m) => m.type === "tool_status"));
|
||||
if (!lt) return false;
|
||||
try {
|
||||
const p = JSON.parse(lt.content);
|
||||
const tools: { name: string; done: boolean }[] = p.tools || [];
|
||||
if (!p.allDone || tools.length === 0) return false;
|
||||
return tools.some(
|
||||
(t) => t.done && (t.name === "set_output" || t.name === "report_to_parent")
|
||||
);
|
||||
} catch { return false; }
|
||||
});
|
||||
const activeCount = finishedFlags.filter((f) => !f).length;
|
||||
|
||||
if (groups.length === 0) return null;
|
||||
|
||||
// Grid sizing: 2 columns, auto rows capped at a fixed height
|
||||
const rows = Math.ceil(groups.length / 2);
|
||||
const gridHeight = expanded
|
||||
? Math.min(rows * 200, 480)
|
||||
: Math.min(rows * 100, 240);
|
||||
|
||||
return (
|
||||
<div className="flex gap-3">
|
||||
{/* Left icon */}
|
||||
<div
|
||||
className="flex-shrink-0 w-7 h-7 rounded-xl flex items-center justify-center mt-1"
|
||||
style={{
|
||||
backgroundColor: `${workerColor}18`,
|
||||
border: `1.5px solid ${workerColor}35`,
|
||||
}}
|
||||
>
|
||||
<Cpu className="w-3.5 h-3.5" style={{ color: workerColor }} />
|
||||
</div>
|
||||
|
||||
<div className="flex-1 min-w-0 max-w-[90%]">
|
||||
{/* Header */}
|
||||
<div className="flex items-center gap-2 mb-1">
|
||||
<span className="font-medium text-xs" style={{ color: workerColor }}>
|
||||
{groups.length === 1 ? "Sub-agent" : "Parallel Agents"}
|
||||
</span>
|
||||
<span className="text-[10px] font-medium px-1.5 py-0.5 rounded-md bg-muted text-muted-foreground">
|
||||
{activeCount > 0 ? `${activeCount} running` : `${groups.length} done`}
|
||||
</span>
|
||||
<button
|
||||
onClick={() => {
|
||||
setExpanded((v) => !v);
|
||||
setZoomedIdx(null);
|
||||
}}
|
||||
className="ml-auto text-muted-foreground/60 hover:text-muted-foreground transition-colors p-0.5 rounded"
|
||||
title={expanded ? "Collapse" : "Expand"}
|
||||
>
|
||||
{expanded ? (
|
||||
<ChevronUp className="w-3.5 h-3.5" />
|
||||
) : (
|
||||
<ChevronDown className="w-3.5 h-3.5" />
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Mux frame */}
|
||||
<div
|
||||
className="rounded-lg overflow-hidden"
|
||||
style={{
|
||||
border: "2px solid #1a1a2a",
|
||||
background: "#08080e",
|
||||
}}
|
||||
>
|
||||
{/* Grid */}
|
||||
<div
|
||||
className="grid gap-px"
|
||||
style={{
|
||||
gridTemplateColumns:
|
||||
groups.length === 1 ? "1fr" : "1fr 1fr",
|
||||
gridTemplateRows: `repeat(${rows}, 1fr)`,
|
||||
height: gridHeight,
|
||||
background: "#111",
|
||||
}}
|
||||
>
|
||||
{groups.map((group, i) => (
|
||||
<MuxPane
|
||||
key={group.nodeId}
|
||||
group={group}
|
||||
index={i}
|
||||
label={labels[i]}
|
||||
isFocused={latestIdx === i}
|
||||
isZoomed={zoomedIdx === i}
|
||||
onClickTitle={() =>
|
||||
setZoomedIdx(zoomedIdx === i ? null : i)
|
||||
}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
(prev, next) =>
|
||||
prev.groupId === next.groupId &&
|
||||
prev.groups.length === next.groups.length &&
|
||||
prev.groups.every(
|
||||
(g, i) =>
|
||||
g.nodeId === next.groups[i].nodeId &&
|
||||
g.messages.length === next.groups[i].messages.length &&
|
||||
last(g.messages)?.content === last(next.groups[i].messages)?.content &&
|
||||
g.contextUsage?.usagePct === next.groups[i].contextUsage?.usagePct
|
||||
)
|
||||
);
|
||||
|
||||
export default ParallelSubagentBubble;
|
||||
|
||||
// Injected as a global style (keyframes can't be inline)
|
||||
if (typeof document !== "undefined") {
|
||||
const id = "parallel-subagent-keyframes";
|
||||
if (!document.getElementById(id)) {
|
||||
const style = document.createElement("style");
|
||||
style.id = id;
|
||||
style.textContent = `
|
||||
@keyframes cursorBlink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } }
|
||||
@keyframes thermoPulse { 0%, 100% { opacity: 1; } 50% { opacity: 0.4; } }
|
||||
`;
|
||||
document.head.appendChild(style);
|
||||
}
|
||||
}
|
||||
@@ -62,7 +62,7 @@ export function sseEventToChatMessage(
|
||||
const innerSuffix = innerTurn != null && innerTurn > 0 ? `-t${innerTurn}` : "";
|
||||
|
||||
const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
|
||||
if (!snapshot) return null;
|
||||
if (!snapshot.trim()) return null;
|
||||
return {
|
||||
id: `stream-${iterIdKey}${innerSuffix}-${event.node_id}`,
|
||||
agent: agentDisplayName || event.node_id || "Agent",
|
||||
@@ -72,6 +72,8 @@ export function sseEventToChatMessage(
|
||||
role: "worker",
|
||||
thread,
|
||||
createdAt,
|
||||
nodeId: event.node_id || undefined,
|
||||
executionId: event.execution_id || undefined,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -100,7 +102,7 @@ export function sseEventToChatMessage(
|
||||
const llmInnerSuffix = llmInnerTurn != null && llmInnerTurn > 0 ? `-t${llmInnerTurn}` : "";
|
||||
|
||||
const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
|
||||
if (!snapshot) return null;
|
||||
if (!snapshot.trim()) return null;
|
||||
return {
|
||||
id: `stream-${idKey}${llmInnerSuffix}-${event.node_id}`,
|
||||
agent: event.node_id || "Agent",
|
||||
@@ -110,6 +112,8 @@ export function sseEventToChatMessage(
|
||||
role: "worker",
|
||||
thread,
|
||||
createdAt,
|
||||
nodeId: event.node_id || undefined,
|
||||
executionId: event.execution_id || undefined,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
// ── Shared graph utilities ──
|
||||
// Common helpers used by both AgentGraph and DraftGraph.
|
||||
// AgentGraph still has its own copies for now (separate cleanup PR).
|
||||
|
||||
/** Read a CSS custom property value (space-separated HSL components). */
|
||||
export function cssVar(name: string): string {
|
||||
return getComputedStyle(document.documentElement).getPropertyValue(name).trim();
|
||||
}
|
||||
|
||||
/** Truncate label to fit within `availablePx` at the given fontSize. */
|
||||
export function truncateLabel(label: string, availablePx: number, fontSize: number): string {
|
||||
const avgCharW = fontSize * 0.58;
|
||||
const maxChars = Math.floor(availablePx / avgCharW);
|
||||
if (label.length <= maxChars) return label;
|
||||
return label.slice(0, Math.max(maxChars - 1, 1)) + "\u2026";
|
||||
}
|
||||
|
||||
// ── Trigger styling ──
|
||||
|
||||
export type TriggerColorSet = { bg: string; border: string; text: string; icon: string };
|
||||
|
||||
export function buildTriggerColors(): TriggerColorSet {
|
||||
const bg = cssVar("--trigger-bg") || "210 25% 14%";
|
||||
const border = cssVar("--trigger-border") || "210 30% 30%";
|
||||
const text = cssVar("--trigger-text") || "210 30% 65%";
|
||||
const icon = cssVar("--trigger-icon") || "210 40% 55%";
|
||||
return {
|
||||
bg: `hsl(${bg})`,
|
||||
border: `hsl(${border})`,
|
||||
text: `hsl(${text})`,
|
||||
icon: `hsl(${icon})`,
|
||||
};
|
||||
}
|
||||
|
||||
export const ACTIVE_TRIGGER_COLORS: TriggerColorSet = {
|
||||
bg: "hsl(210,30%,18%)",
|
||||
border: "hsl(210,50%,50%)",
|
||||
text: "hsl(210,40%,75%)",
|
||||
icon: "hsl(210,60%,65%)",
|
||||
};
|
||||
|
||||
export const TRIGGER_ICONS: Record<string, string> = {
|
||||
webhook: "\u26A1", // lightning bolt
|
||||
timer: "\u23F1", // stopwatch
|
||||
api: "\u2192", // right arrow
|
||||
event: "\u223F", // sine wave
|
||||
};
|
||||
|
||||
/** Format a cron expression into a human-readable schedule label. */
|
||||
export function cronToLabel(cron: string): string {
|
||||
const parts = cron.trim().split(/\s+/);
|
||||
if (parts.length !== 5) return cron;
|
||||
const [min, hour, dom, mon, dow] = parts;
|
||||
|
||||
// */N * * * * -> "Every Nm"
|
||||
if (min.startsWith("*/") && hour === "*" && dom === "*" && mon === "*" && dow === "*") {
|
||||
return `Every ${min.slice(2)}m`;
|
||||
}
|
||||
// 0 */N * * * -> "Every Nh"
|
||||
if (min === "0" && hour.startsWith("*/") && dom === "*" && mon === "*" && dow === "*") {
|
||||
return `Every ${hour.slice(2)}h`;
|
||||
}
|
||||
// 0 H * * * -> "Daily at Ham/pm"
|
||||
if (dom === "*" && mon === "*" && dow === "*" && !min.includes("*") && !hour.includes("*")) {
|
||||
const h = parseInt(hour, 10);
|
||||
const m = parseInt(min, 10);
|
||||
const suffix = h >= 12 ? "PM" : "AM";
|
||||
const h12 = h % 12 || 12;
|
||||
return m === 0 ? `Daily at ${h12}${suffix}` : `Daily at ${h12}:${String(m).padStart(2, "0")}${suffix}`;
|
||||
}
|
||||
return cron;
|
||||
}
|
||||
|
||||
/** Theme-reactive hook for inactive trigger colors. */
|
||||
export function useTriggerColors(): TriggerColorSet {
|
||||
const [colors, setColors] = useState<TriggerColorSet>(buildTriggerColors);
|
||||
|
||||
useEffect(() => {
|
||||
const rebuild = () => setColors(buildTriggerColors());
|
||||
const obs = new MutationObserver(rebuild);
|
||||
obs.observe(document.documentElement, { attributes: true, attributeFilter: ["class", "style"] });
|
||||
return () => obs.disconnect();
|
||||
}, []);
|
||||
|
||||
return colors;
|
||||
}
|
||||
@@ -27,7 +27,14 @@ export default function MyAgents() {
|
||||
agentsApi
|
||||
.discover()
|
||||
.then((result) => {
|
||||
setAgents(result["Your Agents"] || []);
|
||||
const entries = result["Your Agents"] || [];
|
||||
entries.sort((a, b) => {
|
||||
if (!a.last_active && !b.last_active) return 0;
|
||||
if (!a.last_active) return 1;
|
||||
if (!b.last_active) return -1;
|
||||
return b.last_active.localeCompare(a.last_active);
|
||||
});
|
||||
setAgents(entries);
|
||||
})
|
||||
.catch((err) => {
|
||||
setError(err.message || "Failed to load agents");
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { useState, useCallback, useRef, useEffect, useMemo } from "react";
|
||||
import ReactDOM from "react-dom";
|
||||
import { useSearchParams, useNavigate } from "react-router-dom";
|
||||
import { Plus, KeyRound, Sparkles, Layers, ChevronLeft, Bot, Loader2, WifiOff, X } from "lucide-react";
|
||||
import { Plus, KeyRound, Sparkles, Layers, ChevronLeft, Bot, Loader2, WifiOff, X, FolderOpen } from "lucide-react";
|
||||
import type { GraphNode, NodeStatus } from "@/components/graph-types";
|
||||
import DraftGraph from "@/components/DraftGraph";
|
||||
import ChatPanel, { type ChatMessage } from "@/components/ChatPanel";
|
||||
@@ -17,6 +17,7 @@ import { useMultiSSE } from "@/hooks/use-sse";
|
||||
import type { LiveSession, AgentEvent, DiscoverEntry, NodeSpec, DraftGraph as DraftGraphData } from "@/api/types";
|
||||
import { sseEventToChatMessage, formatAgentDisplayName } from "@/lib/chat-helpers";
|
||||
import { topologyToGraphNodes } from "@/lib/graph-converter";
|
||||
import { cronToLabel } from "@/lib/graphUtils";
|
||||
import { ApiError } from "@/api/client";
|
||||
|
||||
const makeId = () => Math.random().toString(36).slice(2, 9);
|
||||
@@ -251,6 +252,10 @@ function truncate(s: string, max: number): string {
|
||||
type SessionRestoreResult = {
|
||||
messages: ChatMessage[];
|
||||
restoredPhase: "planning" | "building" | "staging" | "running" | null;
|
||||
/** Last flowchart map from events — used to restore flowchart overlay on cold resume. */
|
||||
flowchartMap: Record<string, string[]> | null;
|
||||
/** Last original draft from events — used to restore flowchart overlay on cold resume. */
|
||||
originalDraft: DraftGraphData | null;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -267,6 +272,8 @@ async function restoreSessionMessages(
|
||||
if (events.length > 0) {
|
||||
const messages: ChatMessage[] = [];
|
||||
let runningPhase: ChatMessage["phase"] = undefined;
|
||||
let flowchartMap: Record<string, string[]> | null = null;
|
||||
let originalDraft: DraftGraphData | null = null;
|
||||
for (const evt of events) {
|
||||
// Track phase transitions so each message gets the phase it was created in
|
||||
const p = evt.type === "queen_phase_changed" ? evt.data?.phase as string
|
||||
@@ -275,6 +282,12 @@ async function restoreSessionMessages(
|
||||
if (p && ["planning", "building", "staging", "running"].includes(p)) {
|
||||
runningPhase = p as ChatMessage["phase"];
|
||||
}
|
||||
// Track last flowchart state for cold restore
|
||||
if (evt.type === "flowchart_map_updated" && evt.data) {
|
||||
const mapData = evt.data as { map?: Record<string, string[]>; original_draft?: DraftGraphData };
|
||||
flowchartMap = mapData.map ?? null;
|
||||
originalDraft = mapData.original_draft ?? null;
|
||||
}
|
||||
const msg = sseEventToChatMessage(evt, thread, agentDisplayName);
|
||||
if (!msg) continue;
|
||||
if (evt.stream_id === "queen") {
|
||||
@@ -283,12 +296,12 @@ async function restoreSessionMessages(
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
return { messages, restoredPhase: runningPhase ?? null };
|
||||
return { messages, restoredPhase: runningPhase ?? null, flowchartMap, originalDraft };
|
||||
}
|
||||
} catch {
|
||||
// Event log not available — session will start fresh.
|
||||
}
|
||||
return { messages: [], restoredPhase: null };
|
||||
return { messages: [], restoredPhase: null, flowchartMap: null, originalDraft: null };
|
||||
}
|
||||
|
||||
// --- Per-agent backend state (consolidated) ---
|
||||
@@ -339,6 +352,10 @@ interface AgentBackendState {
|
||||
pendingQuestions: { id: string; prompt: string; options?: string[] }[] | null;
|
||||
/** Whether the pending question came from queen or worker */
|
||||
pendingQuestionSource: "queen" | "worker" | null;
|
||||
/** Per-node context window usage (from context_usage_updated events) */
|
||||
contextUsage: Record<string, { usagePct: number; messageCount: number; estimatedTokens: number; maxTokens: number }>;
|
||||
/** Whether the queen's LLM supports image content — false disables the attach button */
|
||||
queenSupportsImages: boolean;
|
||||
}
|
||||
|
||||
function defaultAgentState(): AgentBackendState {
|
||||
@@ -376,6 +393,8 @@ function defaultAgentState(): AgentBackendState {
|
||||
pendingOptions: null,
|
||||
pendingQuestions: null,
|
||||
pendingQuestionSource: null,
|
||||
contextUsage: {},
|
||||
queenSupportsImages: true,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -557,7 +576,11 @@ export default function Workspace() {
|
||||
const [dismissedBanner, setDismissedBanner] = useState<string | null>(null);
|
||||
const [selectedNode, setSelectedNode] = useState<GraphNode | null>(null);
|
||||
const [triggerTaskDraft, setTriggerTaskDraft] = useState("");
|
||||
const [triggerCronDraft, setTriggerCronDraft] = useState("");
|
||||
const [triggerTaskSaving, setTriggerTaskSaving] = useState(false);
|
||||
const [triggerScheduleSaving, setTriggerScheduleSaving] = useState(false);
|
||||
const [triggerCronSaved, setTriggerCronSaved] = useState(false);
|
||||
const [triggerTaskSaved, setTriggerTaskSaved] = useState(false);
|
||||
const [newTabOpen, setNewTabOpen] = useState(false);
|
||||
const newTabBtnRef = useRef<HTMLButtonElement>(null);
|
||||
const [graphPanelPct, setGraphPanelPct] = useState(30);
|
||||
@@ -613,6 +636,10 @@ export default function Workspace() {
|
||||
// it was created in (avoids stale-closure when phase change and message
|
||||
// events arrive in the same React batch).
|
||||
const queenPhaseRef = useRef<Record<string, string>>({});
|
||||
// Accumulated queen text across inner_turns within the same iteration.
|
||||
// Key: `${agentType}:${execution_id}:${iteration}`, value: { [inner_turn]: snapshot }.
|
||||
// This lets us merge all inner_turn text into one chat bubble per iteration.
|
||||
const queenIterTextRef = useRef<Record<string, Record<number, string>>>({});
|
||||
// Timestamp when designingDraft was set — used to enforce minimum spinner duration.
|
||||
const designingDraftSinceRef = useRef<Record<string, number>>({});
|
||||
const designingDraftTimerRef = useRef<Record<string, ReturnType<typeof setTimeout>>>({});
|
||||
@@ -794,6 +821,8 @@ export default function Workspace() {
|
||||
}
|
||||
|
||||
let restoredPhase: "planning" | "building" | "staging" | "running" | null = null;
|
||||
let restoredFlowchartMap: Record<string, string[]> | null = null;
|
||||
let restoredOriginalDraft: DraftGraphData | null = null;
|
||||
if (!liveSession) {
|
||||
// Fetch conversation history from disk BEFORE creating the new session.
|
||||
// SKIP if messages were already pre-populated by handleHistoryOpen.
|
||||
@@ -805,9 +834,22 @@ export default function Workspace() {
|
||||
const restored = await restoreSessionMessages(restoreFrom, agentType, "Queen Bee");
|
||||
preRestoredMsgs.push(...restored.messages);
|
||||
restoredPhase = restored.restoredPhase;
|
||||
restoredFlowchartMap = restored.flowchartMap;
|
||||
restoredOriginalDraft = restored.originalDraft;
|
||||
} catch {
|
||||
// Not available — will start fresh
|
||||
}
|
||||
} else if (restoreFrom && alreadyHasMessages) {
|
||||
// Messages already cached in localStorage — still fetch events for
|
||||
// non-message state (phase, flowchart) that isn't cached.
|
||||
try {
|
||||
const restored = await restoreSessionMessages(restoreFrom, agentType, "Queen Bee");
|
||||
restoredPhase = restored.restoredPhase;
|
||||
restoredFlowchartMap = restored.flowchartMap;
|
||||
restoredOriginalDraft = restored.originalDraft;
|
||||
} catch {
|
||||
// Not critical — UI will still show cached messages
|
||||
}
|
||||
}
|
||||
|
||||
// Suppress the queen's intro cycle whenever we are about to restore a
|
||||
@@ -830,7 +872,7 @@ export default function Workspace() {
|
||||
}));
|
||||
}
|
||||
restoredMessageCount = preRestoredMsgs.length;
|
||||
} else if (restoreFrom && activeId) {
|
||||
} else if (restoreFrom && activeId && !alreadyHasMessages) {
|
||||
// We had a stored session but no messages on disk — wipe stale localStorage cache
|
||||
setSessionsByAgent(prev => ({
|
||||
...prev,
|
||||
@@ -884,6 +926,10 @@ export default function Workspace() {
|
||||
queenReady: true,
|
||||
queenPhase: qPhase,
|
||||
queenBuilding: qPhase === "building",
|
||||
queenSupportsImages: liveSession.queen_supports_images !== false,
|
||||
// Restore flowchart overlay from persisted events
|
||||
...(restoredFlowchartMap ? { flowchartMap: restoredFlowchartMap } : {}),
|
||||
...(restoredOriginalDraft ? { originalDraft: restoredOriginalDraft, draftGraph: null } : {}),
|
||||
});
|
||||
} catch (err: unknown) {
|
||||
const msg = err instanceof Error ? err.message : String(err);
|
||||
@@ -958,6 +1004,8 @@ export default function Workspace() {
|
||||
|
||||
// Track the last queen phase seen in the event log for cold restore
|
||||
let restoredPhase: "planning" | "building" | "staging" | "running" | null = null;
|
||||
let restoredFlowchartMap: Record<string, string[]> | null = null;
|
||||
let restoredOriginalDraft: DraftGraphData | null = null;
|
||||
|
||||
if (!liveSession) {
|
||||
// Reconnect failed — clear stale cached messages from localStorage restore.
|
||||
@@ -985,6 +1033,19 @@ export default function Workspace() {
|
||||
const restored = await restoreSessionMessages(coldRestoreId, agentType, displayNameTemp);
|
||||
preQueenMsgs = restored.messages;
|
||||
restoredPhase = restored.restoredPhase;
|
||||
restoredFlowchartMap = restored.flowchartMap;
|
||||
restoredOriginalDraft = restored.originalDraft;
|
||||
} else if (coldRestoreId && alreadyHasMessages) {
|
||||
// Messages already cached — still fetch events for non-message state (phase, flowchart)
|
||||
try {
|
||||
const displayNameTemp = formatAgentDisplayName(agentPath);
|
||||
const restored = await restoreSessionMessages(coldRestoreId, agentType, displayNameTemp);
|
||||
restoredPhase = restored.restoredPhase;
|
||||
restoredFlowchartMap = restored.flowchartMap;
|
||||
restoredOriginalDraft = restored.originalDraft;
|
||||
} catch {
|
||||
// Not critical — UI will still show cached messages
|
||||
}
|
||||
}
|
||||
|
||||
// Suppress intro whenever we are about to restore a previous conversation.
|
||||
@@ -1065,6 +1126,10 @@ export default function Workspace() {
|
||||
displayName,
|
||||
queenPhase: initialPhase,
|
||||
queenBuilding: initialPhase === "building",
|
||||
queenSupportsImages: session.queen_supports_images !== false,
|
||||
// Restore flowchart overlay from persisted events
|
||||
...(restoredFlowchartMap ? { flowchartMap: restoredFlowchartMap } : {}),
|
||||
...(restoredOriginalDraft ? { originalDraft: restoredOriginalDraft, draftGraph: null } : {}),
|
||||
});
|
||||
|
||||
// Update the session label + backendSessionId. Also set historySourceId
|
||||
@@ -1102,6 +1167,11 @@ export default function Workspace() {
|
||||
if (historyId && !coldRestoreId) {
|
||||
const restored = await restoreSessionMessages(historyId, agentType, displayName);
|
||||
restoredMsgs.push(...restored.messages);
|
||||
// Use flowchart from event log if not already set
|
||||
if (restored.flowchartMap && !restoredFlowchartMap) {
|
||||
restoredFlowchartMap = restored.flowchartMap;
|
||||
restoredOriginalDraft = restored.originalDraft;
|
||||
}
|
||||
|
||||
// Check worker status (needed for isWorkerRunning flag)
|
||||
try {
|
||||
@@ -1144,6 +1214,9 @@ export default function Workspace() {
|
||||
loading: false,
|
||||
queenReady: !!(isResumedSession || hasRestoredContent),
|
||||
...(isWorkerRunning ? { workerRunState: "running" } : {}),
|
||||
// Restore flowchart overlay from persisted events
|
||||
...(restoredFlowchartMap ? { flowchartMap: restoredFlowchartMap } : {}),
|
||||
...(restoredOriginalDraft ? { originalDraft: restoredOriginalDraft, draftGraph: null } : {}),
|
||||
});
|
||||
} catch (err: unknown) {
|
||||
const msg = err instanceof Error ? err.message : String(err);
|
||||
@@ -1260,12 +1333,28 @@ export default function Workspace() {
|
||||
|
||||
const fireMap = new Map<string, number>();
|
||||
const taskMap = new Map<string, string>();
|
||||
const labelMap = new Map<string, string>();
|
||||
const targetMap = new Map<string, string>();
|
||||
for (const ep of triggerEps) {
|
||||
const nodeId = `__trigger_${ep.id}`;
|
||||
if (ep.next_fire_in != null) {
|
||||
fireMap.set(`__trigger_${ep.id}`, ep.next_fire_in);
|
||||
fireMap.set(nodeId, ep.next_fire_in);
|
||||
}
|
||||
if (ep.task != null) {
|
||||
taskMap.set(`__trigger_${ep.id}`, ep.task);
|
||||
taskMap.set(nodeId, ep.task);
|
||||
}
|
||||
const cron = ep.trigger_config?.cron as string | undefined;
|
||||
const interval = ep.trigger_config?.interval_minutes as number | undefined;
|
||||
const epLabel = cron
|
||||
? cronToLabel(cron)
|
||||
: interval
|
||||
? `Every ${interval >= 60 ? `${interval / 60}h` : `${interval}m`}`
|
||||
: ep.name || undefined;
|
||||
if (epLabel) {
|
||||
labelMap.set(nodeId, epLabel);
|
||||
}
|
||||
if (ep.entry_node) {
|
||||
targetMap.set(nodeId, ep.entry_node);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1274,14 +1363,18 @@ export default function Workspace() {
|
||||
if (!ss?.length) return prev;
|
||||
const existingIds = new Set(ss[0].graphNodes.map(n => n.id));
|
||||
|
||||
// Update existing trigger nodes
|
||||
// Update existing trigger nodes (countdown, task, label, target)
|
||||
let updated = ss[0].graphNodes.map((n) => {
|
||||
if (n.nodeType !== "trigger") return n;
|
||||
const nfi = fireMap.get(n.id);
|
||||
const task = taskMap.get(n.id);
|
||||
if (nfi == null && task == null) return n;
|
||||
const label = labelMap.get(n.id);
|
||||
const target = targetMap.get(n.id);
|
||||
if (nfi == null && task == null && !label && !target) return n;
|
||||
return {
|
||||
...n,
|
||||
...(label && label !== n.label ? { label } : {}),
|
||||
...(target ? { next: [target] } : {}),
|
||||
triggerConfig: {
|
||||
...n.triggerConfig,
|
||||
...(nfi != null ? { next_fire_in: nfi } : {}),
|
||||
@@ -1291,14 +1384,15 @@ export default function Workspace() {
|
||||
});
|
||||
|
||||
// Discover new triggers not yet in the graph
|
||||
const entryNode = ss[0].graphNodes.find(n => n.nodeType !== "trigger")?.id;
|
||||
const fallbackEntry = ss[0].graphNodes.find(n => n.nodeType !== "trigger")?.id;
|
||||
const newNodes: GraphNode[] = [];
|
||||
for (const ep of triggerEps) {
|
||||
const nodeId = `__trigger_${ep.id}`;
|
||||
if (existingIds.has(nodeId)) continue;
|
||||
const target = ep.entry_node || fallbackEntry;
|
||||
newNodes.push({
|
||||
id: nodeId,
|
||||
label: ep.name || ep.id,
|
||||
label: labelMap.get(nodeId) || ep.name || ep.id,
|
||||
status: "pending",
|
||||
nodeType: "trigger",
|
||||
triggerType: ep.trigger_type,
|
||||
@@ -1307,7 +1401,7 @@ export default function Workspace() {
|
||||
...(ep.next_fire_in != null ? { next_fire_in: ep.next_fire_in } : {}),
|
||||
...(ep.task ? { task: ep.task } : {}),
|
||||
},
|
||||
...(entryNode ? { next: [entryNode] } : {}),
|
||||
...(target ? { next: [target] } : {}),
|
||||
});
|
||||
}
|
||||
if (newNodes.length > 0) {
|
||||
@@ -1625,14 +1719,29 @@ export default function Workspace() {
|
||||
if (isQueen) console.log('[QUEEN] chatMsg:', chatMsg?.id, chatMsg?.content?.slice(0, 50), 'turn:', currentTurn);
|
||||
if (chatMsg && !suppressQueenMessages) {
|
||||
// Queen emits multiple client_output_delta / llm_text_delta snapshots
|
||||
// across iterations and inner tool-loop turns. Build a stable ID that
|
||||
// groups streaming deltas for the *same* output (same execution +
|
||||
// iteration + inner_turn) into one bubble, while keeping distinct
|
||||
// outputs as separate bubbles so earlier text isn't overwritten.
|
||||
// across iterations and inner tool-loop turns. Merge all inner_turns
|
||||
// within the same iteration into ONE bubble so the queen's multi-step
|
||||
// tool loop (text → tool → text → tool → text) appears as one cohesive
|
||||
// message rather than many small fragments.
|
||||
if (isQueen && (event.type === "client_output_delta" || event.type === "llm_text_delta") && event.execution_id) {
|
||||
const iter = event.data?.iteration ?? 0;
|
||||
const inner = event.data?.inner_turn ?? 0;
|
||||
chatMsg.id = `queen-stream-${event.execution_id}-${iter}-${inner}`;
|
||||
const inner = (event.data?.inner_turn as number) ?? 0;
|
||||
const iterKey = `${agentType}:${event.execution_id}:${iter}`;
|
||||
|
||||
// Store the latest snapshot for this inner_turn
|
||||
if (!queenIterTextRef.current[iterKey]) {
|
||||
queenIterTextRef.current[iterKey] = {};
|
||||
}
|
||||
const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
|
||||
queenIterTextRef.current[iterKey][inner] = snapshot;
|
||||
|
||||
// Concatenate all inner_turn snapshots in order
|
||||
const parts = queenIterTextRef.current[iterKey];
|
||||
const sortedInners = Object.keys(parts).map(Number).sort((a, b) => a - b);
|
||||
chatMsg.content = sortedInners.map(k => parts[k]).join("\n");
|
||||
|
||||
// Single ID per iteration — no inner_turn in the ID
|
||||
chatMsg.id = `queen-stream-${event.execution_id}-${iter}`;
|
||||
}
|
||||
if (isQueen) {
|
||||
chatMsg.role = role;
|
||||
@@ -1907,6 +2016,8 @@ export default function Workspace() {
|
||||
role,
|
||||
thread: agentType,
|
||||
createdAt: eventCreatedAt,
|
||||
nodeId: event.node_id || undefined,
|
||||
executionId: event.execution_id || undefined,
|
||||
});
|
||||
return {
|
||||
...prev,
|
||||
@@ -1978,6 +2089,8 @@ export default function Workspace() {
|
||||
role,
|
||||
thread: agentType,
|
||||
createdAt: eventCreatedAt,
|
||||
nodeId: event.node_id || undefined,
|
||||
executionId: event.execution_id || undefined,
|
||||
});
|
||||
return {
|
||||
...prev,
|
||||
@@ -2054,6 +2167,29 @@ export default function Workspace() {
|
||||
}
|
||||
break;
|
||||
|
||||
case "context_usage_updated": {
|
||||
const streamKey = isQueen ? "__queen__" : (event.node_id || streamId);
|
||||
const usagePct = (event.data?.usage_pct as number) ?? 0;
|
||||
const messageCount = (event.data?.message_count as number) ?? 0;
|
||||
const estimatedTokens = (event.data?.estimated_tokens as number) ?? 0;
|
||||
const maxTokens = (event.data?.max_context_tokens as number) ?? 0;
|
||||
setAgentStates(prev => {
|
||||
const state = prev[agentType];
|
||||
if (!state) return prev;
|
||||
return {
|
||||
...prev,
|
||||
[agentType]: {
|
||||
...state,
|
||||
contextUsage: {
|
||||
...state.contextUsage,
|
||||
[streamKey]: { usagePct, messageCount, estimatedTokens, maxTokens },
|
||||
},
|
||||
},
|
||||
};
|
||||
});
|
||||
}
|
||||
break;
|
||||
|
||||
case "node_action_plan":
|
||||
if (!isQueen && event.node_id) {
|
||||
const plan = (event.data?.plan as string) || "";
|
||||
@@ -2237,10 +2373,18 @@ export default function Workspace() {
|
||||
// Synthesize new trigger node at the front of the graph
|
||||
const triggerType = (event.data?.trigger_type as string) || "timer";
|
||||
const triggerConfig = (event.data?.trigger_config as Record<string, unknown>) || {};
|
||||
const entryNode = s.graphNodes.find(n => n.nodeType !== "trigger")?.id;
|
||||
const entryNode = (event.data?.entry_node as string) || s.graphNodes.find(n => n.nodeType !== "trigger")?.id;
|
||||
const triggerName = (event.data?.name as string) || triggerId;
|
||||
const _cron = triggerConfig.cron as string | undefined;
|
||||
const _interval = triggerConfig.interval_minutes as number | undefined;
|
||||
const computedLabel = _cron
|
||||
? cronToLabel(_cron)
|
||||
: _interval
|
||||
? `Every ${_interval >= 60 ? `${_interval / 60}h` : `${_interval}m`}`
|
||||
: triggerName;
|
||||
const newNode: GraphNode = {
|
||||
id: nodeId,
|
||||
label: triggerId,
|
||||
label: computedLabel,
|
||||
status: "running",
|
||||
nodeType: "trigger",
|
||||
triggerType,
|
||||
@@ -2305,10 +2449,18 @@ export default function Workspace() {
|
||||
if (s.graphNodes.some(n => n.id === nodeId)) return s;
|
||||
const triggerType = (event.data?.trigger_type as string) || "timer";
|
||||
const triggerConfig = (event.data?.trigger_config as Record<string, unknown>) || {};
|
||||
const entryNode = s.graphNodes.find(n => n.nodeType !== "trigger")?.id;
|
||||
const entryNode = (event.data?.entry_node as string) || s.graphNodes.find(n => n.nodeType !== "trigger")?.id;
|
||||
const triggerName = (event.data?.name as string) || triggerId;
|
||||
const _cron2 = triggerConfig.cron as string | undefined;
|
||||
const _interval2 = triggerConfig.interval_minutes as number | undefined;
|
||||
const computedLabel2 = _cron2
|
||||
? cronToLabel(_cron2)
|
||||
: _interval2
|
||||
? `Every ${_interval2 >= 60 ? `${_interval2 / 60}h` : `${_interval2}m`}`
|
||||
: triggerName;
|
||||
const newNode: GraphNode = {
|
||||
id: nodeId,
|
||||
label: triggerId,
|
||||
label: computedLabel2,
|
||||
status: "pending",
|
||||
nodeType: "trigger",
|
||||
triggerType,
|
||||
@@ -2323,6 +2475,43 @@ export default function Workspace() {
|
||||
break;
|
||||
}
|
||||
|
||||
case "trigger_updated": {
|
||||
const triggerId = event.data?.trigger_id as string;
|
||||
if (triggerId) {
|
||||
const nodeId = `__trigger_${triggerId}`;
|
||||
const triggerConfig = (event.data?.trigger_config as Record<string, unknown>) || {};
|
||||
const cron = triggerConfig.cron as string | undefined;
|
||||
const interval = triggerConfig.interval_minutes as number | undefined;
|
||||
const newLabel = cron
|
||||
? cronToLabel(cron)
|
||||
: interval
|
||||
? `Every ${interval >= 60 ? `${interval / 60}h` : `${interval}m`}`
|
||||
: undefined;
|
||||
setSessionsByAgent(prev => {
|
||||
const sessions = prev[agentType] || [];
|
||||
const activeId = activeSessionRef.current[agentType] || sessions[0]?.id;
|
||||
return {
|
||||
...prev,
|
||||
[agentType]: sessions.map(s => {
|
||||
if (s.id !== activeId) return s;
|
||||
return {
|
||||
...s,
|
||||
graphNodes: s.graphNodes.map(n => {
|
||||
if (n.id !== nodeId) return n;
|
||||
return {
|
||||
...n,
|
||||
...(newLabel ? { label: newLabel } : {}),
|
||||
triggerConfig: { ...n.triggerConfig, ...triggerConfig },
|
||||
};
|
||||
}),
|
||||
};
|
||||
}),
|
||||
};
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case "trigger_removed": {
|
||||
const triggerId = event.data?.trigger_id as string;
|
||||
if (triggerId) {
|
||||
@@ -2376,14 +2565,43 @@ export default function Workspace() {
|
||||
const liveSelectedNode = selectedNode && currentGraph.nodes.find(n => n.id === selectedNode.id);
|
||||
const resolvedSelectedNode = liveSelectedNode || selectedNode;
|
||||
|
||||
// Sync trigger task draft when selected trigger node changes
|
||||
// Sync trigger drafts when selected trigger node changes
|
||||
useEffect(() => {
|
||||
if (resolvedSelectedNode?.nodeType === "trigger") {
|
||||
const tc = resolvedSelectedNode.triggerConfig as Record<string, unknown> | undefined;
|
||||
setTriggerTaskDraft((tc?.task as string) || "");
|
||||
setTriggerCronDraft((tc?.cron as string) || "");
|
||||
}
|
||||
}, [resolvedSelectedNode?.id]);
|
||||
|
||||
const patchTriggerNode = useCallback((agentType: string, triggerNodeId: string, patch: { task?: string; trigger_config?: Record<string, unknown>; label?: string }) => {
|
||||
setSessionsByAgent(prev => {
|
||||
const sessions = prev[agentType] || [];
|
||||
const activeId = activeSessionRef.current[agentType] || sessions[0]?.id;
|
||||
return {
|
||||
...prev,
|
||||
[agentType]: sessions.map(s => {
|
||||
if (s.id !== activeId) return s;
|
||||
return {
|
||||
...s,
|
||||
graphNodes: s.graphNodes.map(n => {
|
||||
if (n.id !== triggerNodeId) return n;
|
||||
return {
|
||||
...n,
|
||||
...(patch.label !== undefined ? { label: patch.label } : {}),
|
||||
triggerConfig: {
|
||||
...n.triggerConfig,
|
||||
...(patch.trigger_config || {}),
|
||||
...(patch.task !== undefined ? { task: patch.task } : {}),
|
||||
},
|
||||
};
|
||||
}),
|
||||
};
|
||||
}),
|
||||
};
|
||||
});
|
||||
}, []);
|
||||
|
||||
// Build a flat list of all agent-type tabs for the tab bar
|
||||
const agentTabs = Object.entries(sessionsByAgent)
|
||||
.filter(([, sessions]) => sessions.length > 0)
|
||||
@@ -2400,7 +2618,7 @@ export default function Workspace() {
|
||||
});
|
||||
|
||||
// --- handleSend ---
|
||||
const handleSend = useCallback((text: string, thread: string) => {
|
||||
const handleSend = useCallback((text: string, thread: string, images?: import("@/components/ChatPanel").ImageContent[]) => {
|
||||
if (!activeSession) return;
|
||||
const state = agentStates[activeWorker];
|
||||
|
||||
@@ -2466,6 +2684,7 @@ export default function Workspace() {
|
||||
const userMsg: ChatMessage = {
|
||||
id: makeId(), agent: "You", agentColor: "",
|
||||
content: text, timestamp: "", type: "user", thread, createdAt: Date.now(),
|
||||
images,
|
||||
};
|
||||
setSessionsByAgent(prev => ({
|
||||
...prev,
|
||||
@@ -2477,7 +2696,7 @@ export default function Workspace() {
|
||||
updateAgentState(activeWorker, { isTyping: true, queenIsTyping: true });
|
||||
|
||||
if (state?.sessionId && state?.ready) {
|
||||
executionApi.chat(state.sessionId, text).catch((err: unknown) => {
|
||||
executionApi.chat(state.sessionId, text, images).catch((err: unknown) => {
|
||||
const errMsg = err instanceof Error ? err.message : String(err);
|
||||
const errorChatMsg: ChatMessage = {
|
||||
id: makeId(), agent: "System", agentColor: "",
|
||||
@@ -2893,6 +3112,16 @@ export default function Workspace() {
|
||||
<KeyRound className="w-3.5 h-3.5" />
|
||||
Credentials
|
||||
</button>
|
||||
{activeAgentState?.sessionId && (
|
||||
<button
|
||||
onClick={() => sessionsApi.revealFolder(activeAgentState.sessionId!).catch(() => {})}
|
||||
className="flex items-center gap-1.5 px-3 py-1.5 rounded-md text-xs font-medium text-muted-foreground hover:text-foreground hover:bg-muted/50 transition-colors flex-shrink-0"
|
||||
title="Open session data folder"
|
||||
>
|
||||
<FolderOpen className="w-3.5 h-3.5" />
|
||||
Data
|
||||
</button>
|
||||
)}
|
||||
</TopBar>
|
||||
|
||||
{/* Main content area */}
|
||||
@@ -3010,6 +3239,8 @@ export default function Workspace() {
|
||||
}
|
||||
onMultiQuestionSubmit={handleMultiQuestionAnswer}
|
||||
onQuestionDismiss={handleQuestionDismiss}
|
||||
contextUsage={activeAgentState?.contextUsage}
|
||||
supportsImages={activeAgentState?.queenSupportsImages ?? true}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
@@ -3052,18 +3283,64 @@ export default function Workspace() {
|
||||
const interval = tc?.interval_minutes as number | undefined;
|
||||
const eventTypes = tc?.event_types as string[] | undefined;
|
||||
const scheduleLabel = cron
|
||||
? `cron: ${cron}`
|
||||
? cronToLabel(cron)
|
||||
: interval
|
||||
? `Every ${interval >= 60 ? `${interval / 60}h` : `${interval}m`}`
|
||||
: eventTypes?.length
|
||||
? eventTypes.join(", ")
|
||||
: null;
|
||||
return scheduleLabel ? (
|
||||
const canEditCron = resolvedSelectedNode.triggerType === "timer";
|
||||
const cronChanged = canEditCron && triggerCronDraft.trim() !== (cron || "");
|
||||
return scheduleLabel || canEditCron ? (
|
||||
<div>
|
||||
<p className="text-[10px] font-medium text-muted-foreground uppercase tracking-wider mb-1.5">Schedule</p>
|
||||
<p className="text-xs text-foreground/80 font-mono bg-muted/30 rounded-lg px-3 py-2 border border-border/20">
|
||||
{scheduleLabel}
|
||||
</p>
|
||||
{scheduleLabel && (
|
||||
<p className="text-xs text-foreground/80 font-mono bg-muted/30 rounded-lg px-3 py-2 border border-border/20">
|
||||
{scheduleLabel}
|
||||
</p>
|
||||
)}
|
||||
{canEditCron && (
|
||||
<>
|
||||
<input
|
||||
value={triggerCronDraft}
|
||||
onChange={(e) => setTriggerCronDraft(e.target.value)}
|
||||
placeholder="0 5 * * *"
|
||||
className="mt-1.5 w-full text-xs text-foreground/80 bg-muted/30 rounded-lg px-3 py-2 border border-border/20 font-mono focus:outline-none focus:border-primary/40"
|
||||
/>
|
||||
<p className="text-[10px] text-muted-foreground/60 mt-1">
|
||||
Edit the cron expression for this timer trigger.
|
||||
</p>
|
||||
{(cronChanged || triggerCronSaved) && (
|
||||
<button
|
||||
disabled={triggerScheduleSaving || !cronChanged}
|
||||
onClick={async () => {
|
||||
const sessionId = activeAgentState?.sessionId;
|
||||
const triggerId = resolvedSelectedNode.id.replace("__trigger_", "");
|
||||
const nextCron = triggerCronDraft.trim();
|
||||
if (!sessionId || !nextCron) return;
|
||||
const nextTriggerConfig: Record<string, unknown> = { cron: nextCron };
|
||||
setTriggerScheduleSaving(true);
|
||||
try {
|
||||
await sessionsApi.updateTrigger(sessionId, triggerId, {
|
||||
trigger_config: nextTriggerConfig,
|
||||
});
|
||||
patchTriggerNode(activeWorker, resolvedSelectedNode.id, {
|
||||
trigger_config: nextTriggerConfig,
|
||||
label: cronToLabel(nextCron),
|
||||
});
|
||||
setTriggerCronSaved(true);
|
||||
setTimeout(() => setTriggerCronSaved(false), 2000);
|
||||
} finally {
|
||||
setTriggerScheduleSaving(false);
|
||||
}
|
||||
}}
|
||||
className="mt-1.5 w-full text-[11px] px-3 py-1.5 rounded-lg border border-primary/30 text-primary hover:bg-primary/10 transition-colors disabled:opacity-50"
|
||||
>
|
||||
{triggerScheduleSaving ? "Saving..." : triggerCronSaved ? "Saved" : "Save Cron"}
|
||||
</button>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
) : null;
|
||||
})()}
|
||||
@@ -3090,24 +3367,27 @@ export default function Workspace() {
|
||||
{(() => {
|
||||
const currentTask = (resolvedSelectedNode.triggerConfig as Record<string, unknown> | undefined)?.task as string || "";
|
||||
const hasChanged = triggerTaskDraft !== currentTask;
|
||||
if (!hasChanged) return null;
|
||||
if (!hasChanged && !triggerTaskSaved) return null;
|
||||
return (
|
||||
<button
|
||||
disabled={triggerTaskSaving}
|
||||
disabled={triggerTaskSaving || !hasChanged}
|
||||
onClick={async () => {
|
||||
const sessionId = activeAgentState?.sessionId;
|
||||
const triggerId = resolvedSelectedNode.id.replace("__trigger_", "");
|
||||
if (!sessionId) return;
|
||||
setTriggerTaskSaving(true);
|
||||
try {
|
||||
await sessionsApi.updateTriggerTask(sessionId, triggerId, triggerTaskDraft);
|
||||
await sessionsApi.updateTrigger(sessionId, triggerId, { task: triggerTaskDraft });
|
||||
patchTriggerNode(activeWorker, resolvedSelectedNode.id, { task: triggerTaskDraft });
|
||||
setTriggerTaskSaved(true);
|
||||
setTimeout(() => setTriggerTaskSaved(false), 2000);
|
||||
} finally {
|
||||
setTriggerTaskSaving(false);
|
||||
}
|
||||
}}
|
||||
className="mt-1.5 w-full text-[11px] px-3 py-1.5 rounded-lg border border-primary/30 text-primary hover:bg-primary/10 transition-colors disabled:opacity-50"
|
||||
>
|
||||
{triggerTaskSaving ? "Saving..." : "Save Task"}
|
||||
{triggerTaskSaving ? "Saving..." : triggerTaskSaved ? "Saved" : "Save Task"}
|
||||
</button>
|
||||
);
|
||||
})()}
|
||||
@@ -3164,6 +3444,7 @@ export default function Workspace() {
|
||||
workerSessionId={null}
|
||||
nodeLogs={activeAgentState?.nodeLogs[resolvedSelectedNode.id] || []}
|
||||
actionPlan={activeAgentState?.nodeActionPlans[resolvedSelectedNode.id]}
|
||||
contextUsage={activeAgentState?.contextUsage[resolvedSelectedNode.id]}
|
||||
onClose={() => setSelectedNode(null)}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -62,6 +62,10 @@ lint.isort.section-order = [
|
||||
"first-party",
|
||||
"local-folder",
|
||||
]
|
||||
[tool.pytest.ini_options]
|
||||
filterwarnings = [
|
||||
"ignore::DeprecationWarning:litellm.*"
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
|
||||
@@ -0,0 +1,172 @@
|
||||
"""Integration test: Run a real EventLoopNode against the Antigravity backend.
|
||||
|
||||
Run: .venv/bin/python core/tests/test_antigravity_eventloop.py
|
||||
|
||||
Requires:
|
||||
- ~/.hive/antigravity-accounts.json with valid credentials
|
||||
(run 'uv run python core/antigravity_auth.py auth account add' to authenticate)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
sys.path.insert(0, "core")
|
||||
|
||||
logging.basicConfig(level=logging.WARNING, format="%(levelname)s %(name)s: %(message)s")
|
||||
# Show our provider's retry/stream logs
|
||||
logging.getLogger("framework.llm.litellm").setLevel(logging.DEBUG)
|
||||
|
||||
from framework.config import RuntimeConfig # noqa: E402
|
||||
from framework.graph.event_loop_node import EventLoopNode, LoopConfig # noqa: E402
|
||||
from framework.graph.node import NodeContext, NodeResult, NodeSpec, SharedMemory # noqa: E402
|
||||
from framework.llm.litellm import LiteLLMProvider # noqa: E402
|
||||
|
||||
|
||||
def make_provider() -> LiteLLMProvider:
|
||||
cfg = RuntimeConfig()
|
||||
if not cfg.api_key:
|
||||
print("ERROR: No Antigravity token found.")
|
||||
print(" 1. Run 'antigravity-auth accounts add' to authenticate.")
|
||||
print(" 2. Run 'antigravity-auth serve' to start the local proxy.")
|
||||
print(" 3. Configure Hive: run quickstart.sh and select option 7 (Antigravity).")
|
||||
sys.exit(1)
|
||||
print(f"Model : {cfg.model}")
|
||||
print(f"Base : {cfg.api_base}")
|
||||
print(f"Antigravity : {'localhost:8069' in (cfg.api_base or '')}")
|
||||
return LiteLLMProvider(
|
||||
model=cfg.model,
|
||||
api_key=cfg.api_key,
|
||||
api_base=cfg.api_base,
|
||||
**cfg.extra_kwargs,
|
||||
)
|
||||
|
||||
|
||||
def make_context(
|
||||
llm: LiteLLMProvider,
|
||||
*,
|
||||
node_id: str = "test",
|
||||
system_prompt: str = "You are a helpful assistant.",
|
||||
output_keys: list[str] | None = None,
|
||||
) -> NodeContext:
|
||||
if output_keys is None:
|
||||
output_keys = ["answer"]
|
||||
|
||||
spec = NodeSpec(
|
||||
id=node_id,
|
||||
name="Test Node",
|
||||
description="Integration test node",
|
||||
node_type="event_loop",
|
||||
output_keys=output_keys,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
|
||||
runtime = MagicMock()
|
||||
runtime.start_run = MagicMock(return_value="run-1")
|
||||
runtime.decide = MagicMock(return_value="dec-1")
|
||||
runtime.record_outcome = MagicMock()
|
||||
runtime.end_run = MagicMock()
|
||||
|
||||
memory = SharedMemory()
|
||||
|
||||
return NodeContext(
|
||||
runtime=runtime,
|
||||
node_id=node_id,
|
||||
node_spec=spec,
|
||||
memory=memory,
|
||||
input_data={},
|
||||
llm=llm,
|
||||
available_tools=[],
|
||||
max_tokens=4096,
|
||||
)
|
||||
|
||||
|
||||
async def run_test(
|
||||
name: str, llm: LiteLLMProvider, system: str, output_keys: list[str]
|
||||
) -> NodeResult:
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"TEST: {name}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
ctx = make_context(llm, system_prompt=system, output_keys=output_keys)
|
||||
node = EventLoopNode(config=LoopConfig(max_iterations=3))
|
||||
|
||||
try:
|
||||
result = await node.execute(ctx)
|
||||
print(f" Success : {result.success}")
|
||||
print(f" Output : {result.output}")
|
||||
if result.error:
|
||||
print(f" Error : {result.error}")
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f" EXCEPTION: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return NodeResult(success=False, error=str(e))
|
||||
|
||||
|
||||
async def main():
|
||||
llm = make_provider()
|
||||
print()
|
||||
|
||||
# Test 1: Simple text output — the node should call set_output to fill "answer"
|
||||
r1 = await run_test(
|
||||
name="Simple text generation",
|
||||
llm=llm,
|
||||
system=(
|
||||
"You are a helpful assistant. When asked a question, use the "
|
||||
"set_output tool to store your answer in the 'answer' key. "
|
||||
"Keep answers short (1-2 sentences)."
|
||||
),
|
||||
output_keys=["answer"],
|
||||
)
|
||||
|
||||
# Test 2: If test 1 failed, try bare stream() to isolate the issue
|
||||
if not r1.success:
|
||||
print(f"\n{'=' * 60}")
|
||||
print("FALLBACK: Testing bare provider.stream() directly")
|
||||
print(f"{'=' * 60}")
|
||||
try:
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
|
||||
text = ""
|
||||
events = []
|
||||
async for event in llm.stream(
|
||||
messages=[{"role": "user", "content": "Say hello in 3 words."}],
|
||||
):
|
||||
events.append(type(event).__name__)
|
||||
if isinstance(event, TextDeltaEvent):
|
||||
text = event.snapshot
|
||||
elif isinstance(event, FinishEvent):
|
||||
print(
|
||||
f" Finish: stop={event.stop_reason}"
|
||||
f" in={event.input_tokens}"
|
||||
f" out={event.output_tokens}"
|
||||
)
|
||||
elif isinstance(event, StreamErrorEvent):
|
||||
print(f" StreamError: {event.error} (recoverable={event.recoverable})")
|
||||
elif isinstance(event, ToolCallEvent):
|
||||
print(f" ToolCall: {event.tool_name}")
|
||||
print(f" Text : {text!r}")
|
||||
print(f" Events : {events}")
|
||||
print(f" RESULT : {'OK' if text else 'EMPTY'}")
|
||||
except Exception as e:
|
||||
print(f" EXCEPTION: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print("DONE")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,58 @@
|
||||
"""Tests for LLM model capability checks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.llm.capabilities import supports_image_tool_results
|
||||
|
||||
|
||||
class TestSupportsImageToolResults:
|
||||
"""Verify the deny-list correctly identifies models that can't handle images."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4-turbo",
|
||||
"openai/gpt-4o",
|
||||
"anthropic/claude-sonnet-4-20250514",
|
||||
"claude-haiku-4-5-20251001",
|
||||
"gemini/gemini-1.5-pro",
|
||||
"google/gemini-1.5-flash",
|
||||
"mistral/mistral-large",
|
||||
"groq/llama3-70b",
|
||||
"together/meta-llama/Llama-3-70b",
|
||||
"fireworks_ai/llama-v3-70b",
|
||||
"azure/gpt-4o",
|
||||
"kimi/claude-sonnet-4-20250514",
|
||||
"hive/claude-sonnet-4-20250514",
|
||||
],
|
||||
)
|
||||
def test_supported_models(self, model: str):
|
||||
assert supports_image_tool_results(model) is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"deepseek/deepseek-chat",
|
||||
"deepseek/deepseek-coder",
|
||||
"deepseek-chat",
|
||||
"deepseek-reasoner",
|
||||
"ollama/llama3",
|
||||
"ollama/mistral",
|
||||
"ollama_chat/llama3",
|
||||
"lm_studio/my-model",
|
||||
"vllm/meta-llama/Llama-3-70b",
|
||||
"llamacpp/model",
|
||||
"cerebras/llama3-70b",
|
||||
],
|
||||
)
|
||||
def test_unsupported_models(self, model: str):
|
||||
assert supports_image_tool_results(model) is False
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert supports_image_tool_results("DeepSeek/deepseek-chat") is False
|
||||
assert supports_image_tool_results("OLLAMA/llama3") is False
|
||||
assert supports_image_tool_results("GPT-4o") is True
|
||||
@@ -0,0 +1,209 @@
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _load_check_llm_key_module():
|
||||
module_path = Path(__file__).resolve().parents[2] / "scripts" / "check_llm_key.py"
|
||||
spec = importlib.util.spec_from_file_location("check_llm_key_script", module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def _run_openrouter_check(monkeypatch, status_code: int):
|
||||
module = _load_check_llm_key_module()
|
||||
calls = {}
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, code):
|
||||
self.status_code = code
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self, timeout):
|
||||
calls["timeout"] = timeout
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def get(self, endpoint, headers):
|
||||
calls["endpoint"] = endpoint
|
||||
calls["headers"] = headers
|
||||
return FakeResponse(status_code)
|
||||
|
||||
monkeypatch.setattr(module.httpx, "Client", FakeClient)
|
||||
result = module.check_openrouter("test-key")
|
||||
return result, calls
|
||||
|
||||
|
||||
def _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
status_code: int,
|
||||
payload: dict | None = None,
|
||||
model: str = "openai/gpt-4o-mini",
|
||||
):
|
||||
module = _load_check_llm_key_module()
|
||||
calls = {}
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, code):
|
||||
self.status_code = code
|
||||
self._payload = payload
|
||||
self.text = ""
|
||||
|
||||
def json(self):
|
||||
if self._payload is None:
|
||||
raise ValueError("no json")
|
||||
return self._payload
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self, timeout):
|
||||
calls["timeout"] = timeout
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def get(self, endpoint, headers):
|
||||
calls["endpoint"] = endpoint
|
||||
calls["headers"] = headers
|
||||
return FakeResponse(status_code)
|
||||
|
||||
monkeypatch.setattr(module.httpx, "Client", FakeClient)
|
||||
result = module.check_openrouter_model("test-key", model)
|
||||
return result, calls
|
||||
|
||||
|
||||
def test_check_openrouter_200(monkeypatch):
|
||||
result, calls = _run_openrouter_check(monkeypatch, 200)
|
||||
assert result == {"valid": True, "message": "OpenRouter API key valid"}
|
||||
assert calls["endpoint"] == "https://openrouter.ai/api/v1/models"
|
||||
assert calls["headers"] == {"Authorization": "Bearer test-key"}
|
||||
|
||||
|
||||
def test_check_openrouter_401(monkeypatch):
|
||||
result, _ = _run_openrouter_check(monkeypatch, 401)
|
||||
assert result == {"valid": False, "message": "Invalid OpenRouter API key"}
|
||||
|
||||
|
||||
def test_check_openrouter_403(monkeypatch):
|
||||
result, _ = _run_openrouter_check(monkeypatch, 403)
|
||||
assert result == {"valid": False, "message": "OpenRouter API key lacks permissions"}
|
||||
|
||||
|
||||
def test_check_openrouter_429(monkeypatch):
|
||||
result, _ = _run_openrouter_check(monkeypatch, 429)
|
||||
assert result == {"valid": True, "message": "OpenRouter API key valid"}
|
||||
|
||||
|
||||
def test_check_openrouter_model_200(monkeypatch):
|
||||
result, calls = _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
200,
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"id": "openai/gpt-4o-mini",
|
||||
"canonical_slug": "openai/gpt-4o-mini",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
assert result == {
|
||||
"valid": True,
|
||||
"message": "OpenRouter model is available: openai/gpt-4o-mini",
|
||||
"model": "openai/gpt-4o-mini",
|
||||
}
|
||||
assert calls["endpoint"] == "https://openrouter.ai/api/v1/models/user"
|
||||
assert calls["headers"] == {"Authorization": "Bearer test-key"}
|
||||
|
||||
|
||||
def test_check_openrouter_model_200_matches_canonical_slug(monkeypatch):
|
||||
result, _ = _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
200,
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"id": "mistralai/mistral-small-4",
|
||||
"canonical_slug": "mistralai/mistral-small-2603",
|
||||
}
|
||||
]
|
||||
},
|
||||
model="mistralai/mistral-small-2603",
|
||||
)
|
||||
assert result == {
|
||||
"valid": True,
|
||||
"message": "OpenRouter model is available: mistralai/mistral-small-2603",
|
||||
"model": "mistralai/mistral-small-2603",
|
||||
}
|
||||
|
||||
|
||||
def test_check_openrouter_model_200_sanitizes_pasted_unicode(monkeypatch):
|
||||
result, _ = _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
200,
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"id": "z-ai/glm-5-turbo",
|
||||
"canonical_slug": "z-ai/glm-5-turbo",
|
||||
}
|
||||
]
|
||||
},
|
||||
model="openrouter/z-ai\u200b/glm\u20115\u2011turbo",
|
||||
)
|
||||
assert result == {
|
||||
"valid": True,
|
||||
"message": "OpenRouter model is available: z-ai/glm-5-turbo",
|
||||
"model": "z-ai/glm-5-turbo",
|
||||
}
|
||||
|
||||
|
||||
def test_check_openrouter_model_200_not_found_with_suggestions(monkeypatch):
|
||||
result, _ = _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
200,
|
||||
{
|
||||
"data": [
|
||||
{"id": "z-ai/glm-5-turbo"},
|
||||
{"id": "z-ai/glm-4.6v"},
|
||||
]
|
||||
},
|
||||
model="z-ai/glm-5-turb",
|
||||
)
|
||||
assert result == {
|
||||
"valid": False,
|
||||
"message": (
|
||||
"OpenRouter model is not available for this key/settings: z-ai/glm-5-turb. "
|
||||
"Closest matches: z-ai/glm-5-turbo"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def test_check_openrouter_model_404_with_error_message(monkeypatch):
|
||||
result, _ = _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
404,
|
||||
{"error": {"message": "No endpoints available for this model"}},
|
||||
)
|
||||
assert result == {
|
||||
"valid": False,
|
||||
"message": (
|
||||
"OpenRouter model is not available for this key/settings: openai/gpt-4o-mini. "
|
||||
"No endpoints available for this model"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def test_check_openrouter_model_429(monkeypatch):
|
||||
result, _ = _run_openrouter_model_check(monkeypatch, 429)
|
||||
assert result == {
|
||||
"valid": True,
|
||||
"message": "OpenRouter model check rate-limited; assuming model is reachable",
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import logging
|
||||
|
||||
from framework.config import get_hive_config
|
||||
from framework.config import get_api_base, get_hive_config, get_preferred_model
|
||||
|
||||
|
||||
class TestGetHiveConfig:
|
||||
@@ -21,3 +21,47 @@ class TestGetHiveConfig:
|
||||
assert result == {}
|
||||
assert "Failed to load Hive config" in caplog.text
|
||||
assert str(config_file) in caplog.text
|
||||
|
||||
|
||||
class TestOpenRouterConfig:
|
||||
"""OpenRouter config composition and fallback behavior."""
|
||||
|
||||
def test_get_preferred_model_for_openrouter(self, tmp_path, monkeypatch):
|
||||
config_file = tmp_path / "configuration.json"
|
||||
config_file.write_text(
|
||||
'{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta"}}',
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
|
||||
|
||||
assert get_preferred_model() == "openrouter/x-ai/grok-4.20-beta"
|
||||
|
||||
def test_get_preferred_model_normalizes_openrouter_prefixed_model(self, tmp_path, monkeypatch):
|
||||
config_file = tmp_path / "configuration.json"
|
||||
config_file.write_text(
|
||||
'{"llm":{"provider":"openrouter","model":"openrouter/x-ai/grok-4.20-beta"}}',
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
|
||||
|
||||
assert get_preferred_model() == "openrouter/x-ai/grok-4.20-beta"
|
||||
|
||||
def test_get_api_base_falls_back_to_openrouter_default(self, tmp_path, monkeypatch):
|
||||
config_file = tmp_path / "configuration.json"
|
||||
config_file.write_text(
|
||||
'{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta"}}',
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
|
||||
|
||||
assert get_api_base() == "https://openrouter.ai/api/v1"
|
||||
|
||||
def test_get_api_base_keeps_explicit_openrouter_api_base(self, tmp_path, monkeypatch):
|
||||
config_file = tmp_path / "configuration.json"
|
||||
config_file.write_text(
|
||||
'{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta","api_base":"https://proxy.example/v1"}}',
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
|
||||
|
||||
assert get_api_base() == "https://proxy.example/v1"
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
import os
|
||||
import sys
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
from framework.credentials import key_storage
|
||||
from framework.credentials.validation import ensure_credential_key_env
|
||||
|
||||
|
||||
def _install_fake_aden_modules(monkeypatch, check_fn, credential_specs):
|
||||
shell_config_module = ModuleType("aden_tools.credentials.shell_config")
|
||||
shell_config_module.check_env_var_in_shell_config = check_fn
|
||||
|
||||
credentials_module = ModuleType("aden_tools.credentials")
|
||||
credentials_module.CREDENTIAL_SPECS = credential_specs
|
||||
|
||||
monkeypatch.setitem(sys.modules, "aden_tools.credentials.shell_config", shell_config_module)
|
||||
monkeypatch.setitem(sys.modules, "aden_tools.credentials", credentials_module)
|
||||
|
||||
|
||||
def test_bootstrap_loads_configured_llm_env_var_from_shell_config(monkeypatch):
|
||||
monkeypatch.setattr(key_storage, "load_credential_key", lambda: None)
|
||||
monkeypatch.setattr(key_storage, "load_aden_api_key", lambda: None)
|
||||
monkeypatch.setattr(
|
||||
"framework.config.get_hive_config",
|
||||
lambda: {"llm": {"api_key_env_var": "OPENROUTER_API_KEY"}},
|
||||
)
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
|
||||
calls = []
|
||||
|
||||
def check_env(var_name):
|
||||
calls.append(var_name)
|
||||
if var_name == "OPENROUTER_API_KEY":
|
||||
return True, "or-key-123"
|
||||
return False, None
|
||||
|
||||
_install_fake_aden_modules(
|
||||
monkeypatch,
|
||||
check_env,
|
||||
{"anthropic": SimpleNamespace(env_var="ANTHROPIC_API_KEY")},
|
||||
)
|
||||
|
||||
ensure_credential_key_env()
|
||||
|
||||
assert os.environ.get("OPENROUTER_API_KEY") == "or-key-123"
|
||||
assert "OPENROUTER_API_KEY" in calls
|
||||
|
||||
|
||||
def test_bootstrap_does_not_override_existing_configured_llm_env_var(monkeypatch):
|
||||
monkeypatch.setattr(key_storage, "load_credential_key", lambda: None)
|
||||
monkeypatch.setattr(key_storage, "load_aden_api_key", lambda: None)
|
||||
monkeypatch.setattr(
|
||||
"framework.config.get_hive_config",
|
||||
lambda: {"llm": {"api_key_env_var": "OPENROUTER_API_KEY"}},
|
||||
)
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "already-set")
|
||||
|
||||
calls = []
|
||||
|
||||
def check_env(var_name):
|
||||
calls.append(var_name)
|
||||
return True, "new-value-should-not-apply"
|
||||
|
||||
_install_fake_aden_modules(monkeypatch, check_env, {})
|
||||
|
||||
ensure_credential_key_env()
|
||||
|
||||
assert os.environ.get("OPENROUTER_API_KEY") == "already-set"
|
||||
assert "OPENROUTER_API_KEY" not in calls
|
||||
@@ -1530,6 +1530,34 @@ class TestTransientErrorRetry:
|
||||
await node.execute(ctx)
|
||||
assert llm._call_index == 1 # only tried once
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_facing_non_transient_error_does_not_crash(
|
||||
self, runtime, node_spec, memory
|
||||
):
|
||||
"""Client-facing non-transient errors should wait for input, not crash on token vars."""
|
||||
node_spec.output_keys = []
|
||||
node_spec.client_facing = True
|
||||
llm = ErrorThenSuccessLLM(
|
||||
error=ValueError("bad request: blocked by policy"),
|
||||
fail_count=100, # always fails
|
||||
success_scenario=text_scenario("unreachable"),
|
||||
)
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(
|
||||
config=LoopConfig(
|
||||
max_iterations=1,
|
||||
max_stream_retries=0,
|
||||
stream_retry_backoff_base=0.01,
|
||||
),
|
||||
)
|
||||
node._await_user_input = AsyncMock(return_value=None)
|
||||
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is False
|
||||
assert "Max iterations" in (result.error or "")
|
||||
node._await_user_input.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transient_error_exhausts_retries(self, runtime, node_spec, memory):
|
||||
"""Transient errors that exhaust retries should raise."""
|
||||
|
||||
@@ -19,7 +19,11 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from framework.llm.anthropic import AnthropicProvider
|
||||
from framework.llm.litellm import LiteLLMProvider, _compute_retry_delay
|
||||
from framework.llm.litellm import (
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE,
|
||||
LiteLLMProvider,
|
||||
_compute_retry_delay,
|
||||
)
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
|
||||
|
||||
@@ -72,6 +76,20 @@ class TestLiteLLMProviderInit:
|
||||
)
|
||||
assert provider.api_base == "https://proxy.example/v1"
|
||||
|
||||
def test_init_openrouter_defaults_api_base(self):
|
||||
"""OpenRouter should default to the official OpenAI-compatible endpoint."""
|
||||
provider = LiteLLMProvider(model="openrouter/x-ai/grok-4.20-beta", api_key="my-key")
|
||||
assert provider.api_base == "https://openrouter.ai/api/v1"
|
||||
|
||||
def test_init_openrouter_keeps_custom_api_base(self):
|
||||
"""Explicit api_base should win over OpenRouter defaults."""
|
||||
provider = LiteLLMProvider(
|
||||
model="openrouter/x-ai/grok-4.20-beta",
|
||||
api_key="my-key",
|
||||
api_base="https://proxy.example/v1",
|
||||
)
|
||||
assert provider.api_base == "https://proxy.example/v1"
|
||||
|
||||
def test_init_ollama_no_key_needed(self):
|
||||
"""Test that Ollama models don't require API key."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
@@ -192,6 +210,34 @@ class TestToolConversion:
|
||||
assert result["function"]["parameters"]["properties"]["query"]["type"] == "string"
|
||||
assert result["function"]["parameters"]["required"] == ["query"]
|
||||
|
||||
def test_parse_tool_call_arguments_repairs_truncated_json(self):
|
||||
"""Truncated JSON fragments should be repaired into valid tool inputs."""
|
||||
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
||||
|
||||
parsed = provider._parse_tool_call_arguments(
|
||||
(
|
||||
'{"question":"What story structure should the agent use?",'
|
||||
'"options":["3-act structure","Beginning-Middle-End","Random paragraph"'
|
||||
),
|
||||
"ask_user",
|
||||
)
|
||||
|
||||
assert parsed == {
|
||||
"question": "What story structure should the agent use?",
|
||||
"options": [
|
||||
"3-act structure",
|
||||
"Beginning-Middle-End",
|
||||
"Random paragraph",
|
||||
],
|
||||
}
|
||||
|
||||
def test_parse_tool_call_arguments_raises_when_unrepairable(self):
|
||||
"""Completely invalid JSON should fail fast instead of producing _raw loops."""
|
||||
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to parse tool call arguments"):
|
||||
provider._parse_tool_call_arguments('{"question": foo', "ask_user")
|
||||
|
||||
|
||||
class TestAnthropicProviderBackwardCompatibility:
|
||||
"""Test AnthropicProvider backward compatibility with LiteLLM backend."""
|
||||
@@ -682,6 +728,315 @@ class TestMiniMaxStreamFallback:
|
||||
assert not LiteLLMProvider(model="gpt-4o-mini", api_key="x")._is_minimax_model()
|
||||
|
||||
|
||||
class TestOpenRouterToolCompatFallback:
|
||||
"""OpenRouter models should fall back when native tool use is unavailable."""
|
||||
|
||||
def teardown_method(self):
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE.clear()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.acompletion")
|
||||
async def test_stream_falls_back_to_json_tool_emulation(self, mock_acompletion):
|
||||
"""OpenRouter tool-use 404s should emit synthetic ToolCallEvents instead of errors."""
|
||||
from framework.llm.stream_events import FinishEvent, ToolCallEvent
|
||||
|
||||
provider = LiteLLMProvider(
|
||||
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
||||
api_key="test-key",
|
||||
)
|
||||
tools = [
|
||||
Tool(
|
||||
name="web_search",
|
||||
description="Search the web",
|
||||
parameters={
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"num_results": {"type": "integer"},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
compat_response = MagicMock()
|
||||
compat_response.choices = [MagicMock()]
|
||||
compat_response.choices[0].message.content = (
|
||||
'{"assistant_response":"","tool_calls":['
|
||||
'{"name":"web_search","arguments":'
|
||||
'{"query":"Python 3.13 release notes","num_results":3}}'
|
||||
"]}"
|
||||
)
|
||||
compat_response.choices[0].finish_reason = "stop"
|
||||
compat_response.model = provider.model
|
||||
compat_response.usage.prompt_tokens = 18
|
||||
compat_response.usage.completion_tokens = 9
|
||||
|
||||
async def side_effect(*args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
raise RuntimeError(
|
||||
'OpenrouterException - {"error":{"message":"No endpoints found '
|
||||
"that support tool use. To learn more about provider routing, "
|
||||
'visit: https://openrouter.ai/docs/guides/routing/provider-selection",'
|
||||
'"code":404}}'
|
||||
)
|
||||
return compat_response
|
||||
|
||||
mock_acompletion.side_effect = side_effect
|
||||
|
||||
events = []
|
||||
async for event in provider.stream(
|
||||
messages=[{"role": "user", "content": "Search for the Python 3.13 release notes."}],
|
||||
system="Use tools when needed.",
|
||||
tools=tools,
|
||||
max_tokens=256,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
tool_calls = [event for event in events if isinstance(event, ToolCallEvent)]
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].tool_name == "web_search"
|
||||
assert tool_calls[0].tool_input == {
|
||||
"query": "Python 3.13 release notes",
|
||||
"num_results": 3,
|
||||
}
|
||||
assert tool_calls[0].tool_use_id.startswith("openrouter_compat_")
|
||||
|
||||
finish_events = [event for event in events if isinstance(event, FinishEvent)]
|
||||
assert len(finish_events) == 1
|
||||
assert finish_events[0].stop_reason == "tool_calls"
|
||||
assert finish_events[0].input_tokens == 18
|
||||
assert finish_events[0].output_tokens == 9
|
||||
|
||||
assert mock_acompletion.call_count == 2
|
||||
first_call = mock_acompletion.call_args_list[0].kwargs
|
||||
assert first_call["stream"] is True
|
||||
assert "tools" in first_call
|
||||
|
||||
second_call = mock_acompletion.call_args_list[1].kwargs
|
||||
assert "tools" not in second_call
|
||||
assert "Tool compatibility mode is active" in second_call["messages"][0]["content"]
|
||||
assert provider.model in OPENROUTER_TOOL_COMPAT_MODEL_CACHE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.acompletion")
|
||||
async def test_stream_tool_compat_parses_textual_tool_calls_and_uses_cache(
|
||||
self,
|
||||
mock_acompletion,
|
||||
):
|
||||
"""Textual tool-call markers should become ToolCallEvents and skip repeat probing."""
|
||||
from framework.llm.stream_events import ToolCallEvent
|
||||
|
||||
provider = LiteLLMProvider(
|
||||
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
||||
api_key="test-key",
|
||||
)
|
||||
tools = [
|
||||
Tool(
|
||||
name="ask_user_multiple",
|
||||
description="Ask the user a multiple-choice question",
|
||||
parameters={
|
||||
"properties": {
|
||||
"options": {"type": "array"},
|
||||
"question": {"type": "string"},
|
||||
"prompt": {"type": "string"},
|
||||
},
|
||||
"required": ["options", "question", "prompt"],
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
compat_response = MagicMock()
|
||||
compat_response.choices = [MagicMock()]
|
||||
compat_response.choices[0].message.content = (
|
||||
"<|tool_call_start|>"
|
||||
"[ask_user_multiple(options=['Quartet Collaborator', 'Project Advisor'], "
|
||||
"question='Who are you?', prompt='Who are you?')]"
|
||||
"<|tool_call_end|>"
|
||||
)
|
||||
compat_response.choices[0].finish_reason = "stop"
|
||||
compat_response.model = provider.model
|
||||
compat_response.usage.prompt_tokens = 10
|
||||
compat_response.usage.completion_tokens = 5
|
||||
|
||||
call_state = {"count": 0}
|
||||
|
||||
async def side_effect(*args, **kwargs):
|
||||
call_state["count"] += 1
|
||||
if kwargs.get("stream"):
|
||||
raise RuntimeError(
|
||||
'OpenrouterException - {"error":{"message":"No endpoints found '
|
||||
'that support tool use.","code":404}}'
|
||||
)
|
||||
return compat_response
|
||||
|
||||
mock_acompletion.side_effect = side_effect
|
||||
|
||||
first_events = []
|
||||
async for event in provider.stream(
|
||||
messages=[{"role": "user", "content": "Who are you?"}],
|
||||
system="Use tools when needed.",
|
||||
tools=tools,
|
||||
max_tokens=128,
|
||||
):
|
||||
first_events.append(event)
|
||||
|
||||
tool_calls = [event for event in first_events if isinstance(event, ToolCallEvent)]
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].tool_name == "ask_user_multiple"
|
||||
assert tool_calls[0].tool_input == {
|
||||
"options": ["Quartet Collaborator", "Project Advisor"],
|
||||
"question": "Who are you?",
|
||||
"prompt": "Who are you?",
|
||||
}
|
||||
|
||||
second_events = []
|
||||
async for event in provider.stream(
|
||||
messages=[{"role": "user", "content": "Who are you?"}],
|
||||
system="Use tools when needed.",
|
||||
tools=tools,
|
||||
max_tokens=128,
|
||||
):
|
||||
second_events.append(event)
|
||||
|
||||
second_tool_calls = [event for event in second_events if isinstance(event, ToolCallEvent)]
|
||||
assert len(second_tool_calls) == 1
|
||||
assert mock_acompletion.call_count == 3
|
||||
assert mock_acompletion.call_args_list[0].kwargs["stream"] is True
|
||||
assert "stream" not in mock_acompletion.call_args_list[1].kwargs
|
||||
assert "stream" not in mock_acompletion.call_args_list[2].kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.acompletion")
|
||||
async def test_stream_tool_compat_parses_plain_text_tool_call_lines(
|
||||
self,
|
||||
mock_acompletion,
|
||||
):
|
||||
"""Plain textual tool-call lines should execute as tools, not user-visible text."""
|
||||
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
|
||||
|
||||
provider = LiteLLMProvider(
|
||||
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
||||
api_key="test-key",
|
||||
)
|
||||
tools = [
|
||||
Tool(
|
||||
name="ask_user",
|
||||
description="Ask the user a single multiple-choice question",
|
||||
parameters={
|
||||
"properties": {
|
||||
"question": {"type": "string"},
|
||||
"options": {"type": "array"},
|
||||
},
|
||||
"required": ["question", "options"],
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
compat_response = MagicMock()
|
||||
compat_response.choices = [MagicMock()]
|
||||
compat_response.choices[0].message.content = (
|
||||
"Queen has been loaded. It's ready to assist with your planning needs.\n\n"
|
||||
"ask_user('What would you like to do?', ['Define a new agent', "
|
||||
"'Diagnose an existing agent', 'Explore tools'])"
|
||||
)
|
||||
compat_response.choices[0].finish_reason = "stop"
|
||||
compat_response.model = provider.model
|
||||
compat_response.usage.prompt_tokens = 11
|
||||
compat_response.usage.completion_tokens = 7
|
||||
|
||||
async def side_effect(*args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
raise RuntimeError(
|
||||
'OpenrouterException - {"error":{"message":"No endpoints found '
|
||||
'that support tool use.","code":404}}'
|
||||
)
|
||||
return compat_response
|
||||
|
||||
mock_acompletion.side_effect = side_effect
|
||||
|
||||
events = []
|
||||
async for event in provider.stream(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
system="Use tools when needed.",
|
||||
tools=tools,
|
||||
max_tokens=128,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
tool_calls = [event for event in events if isinstance(event, ToolCallEvent)]
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].tool_name == "ask_user"
|
||||
assert tool_calls[0].tool_input == {
|
||||
"question": "What would you like to do?",
|
||||
"options": ["Define a new agent", "Diagnose an existing agent", "Explore tools"],
|
||||
}
|
||||
|
||||
text_events = [event for event in events if isinstance(event, TextDeltaEvent)]
|
||||
assert len(text_events) == 1
|
||||
assert "ask_user(" not in text_events[0].snapshot
|
||||
assert text_events[0].snapshot == (
|
||||
"Queen has been loaded. It's ready to assist with your planning needs."
|
||||
)
|
||||
|
||||
finish_events = [event for event in events if isinstance(event, FinishEvent)]
|
||||
assert len(finish_events) == 1
|
||||
assert finish_events[0].stop_reason == "tool_calls"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.acompletion")
|
||||
async def test_stream_tool_compat_treats_non_json_as_plain_text(self, mock_acompletion):
|
||||
"""If fallback output is not valid JSON, preserve it as assistant text."""
|
||||
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
|
||||
|
||||
provider = LiteLLMProvider(
|
||||
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
||||
api_key="test-key",
|
||||
)
|
||||
tools = [
|
||||
Tool(
|
||||
name="web_search",
|
||||
description="Search the web",
|
||||
parameters={"properties": {"query": {"type": "string"}}, "required": ["query"]},
|
||||
)
|
||||
]
|
||||
|
||||
compat_response = MagicMock()
|
||||
compat_response.choices = [MagicMock()]
|
||||
compat_response.choices[0].message.content = "I can answer directly without tools."
|
||||
compat_response.choices[0].finish_reason = "stop"
|
||||
compat_response.model = provider.model
|
||||
compat_response.usage.prompt_tokens = 12
|
||||
compat_response.usage.completion_tokens = 6
|
||||
|
||||
async def side_effect(*args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
raise RuntimeError(
|
||||
'OpenrouterException - {"error":{"message":"No endpoints found '
|
||||
'that support tool use.","code":404}}'
|
||||
)
|
||||
return compat_response
|
||||
|
||||
mock_acompletion.side_effect = side_effect
|
||||
|
||||
events = []
|
||||
async for event in provider.stream(
|
||||
messages=[{"role": "user", "content": "Say hello."}],
|
||||
system="Be concise.",
|
||||
tools=tools,
|
||||
max_tokens=128,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
text_events = [event for event in events if isinstance(event, TextDeltaEvent)]
|
||||
assert len(text_events) == 1
|
||||
assert text_events[0].snapshot == "I can answer directly without tools."
|
||||
assert not any(isinstance(event, ToolCallEvent) for event in events)
|
||||
|
||||
finish_events = [event for event in events if isinstance(event, FinishEvent)]
|
||||
assert len(finish_events) == 1
|
||||
assert finish_events[0].stop_reason == "stop"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AgentRunner._is_local_model — parameterized tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
"""Unit tests for MCP client transport and reconnect behavior."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from framework.runner import mcp_client as mcp_client_module
|
||||
from framework.runner.mcp_client import MCPClient, MCPServerConfig, MCPTool
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, payload=None):
|
||||
self._payload = payload or {}
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
"""Pretend the request succeeded."""
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class _FakeHttpClient:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
self.get_calls: list[str] = []
|
||||
self.closed = False
|
||||
|
||||
def get(self, path: str) -> _FakeResponse:
|
||||
self.get_calls.append(path)
|
||||
return _FakeResponse()
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
|
||||
def test_connect_unix_transport_uses_socket_path(monkeypatch):
|
||||
created = {}
|
||||
|
||||
class FakeHTTPTransport:
|
||||
def __init__(self, *, uds: str):
|
||||
created["uds"] = uds
|
||||
self.uds = uds
|
||||
|
||||
def fake_client_factory(**kwargs):
|
||||
client = _FakeHttpClient(**kwargs)
|
||||
created["client"] = client
|
||||
return client
|
||||
|
||||
monkeypatch.setattr(mcp_client_module.httpx, "HTTPTransport", FakeHTTPTransport)
|
||||
monkeypatch.setattr(mcp_client_module.httpx, "Client", fake_client_factory)
|
||||
monkeypatch.setattr(MCPClient, "_discover_tools", lambda self: None)
|
||||
|
||||
client = MCPClient(
|
||||
MCPServerConfig(
|
||||
name="unix-server",
|
||||
transport="unix",
|
||||
url="http://localhost",
|
||||
socket_path="/tmp/test.sock",
|
||||
)
|
||||
)
|
||||
|
||||
client.connect()
|
||||
|
||||
assert created["uds"] == "/tmp/test.sock"
|
||||
assert client._http_client is created["client"] # noqa: SLF001 - direct unit test
|
||||
assert created["client"].kwargs["base_url"] == "http://localhost"
|
||||
assert created["client"].get_calls == ["/health"]
|
||||
|
||||
client.disconnect()
|
||||
assert created["client"].closed is True
|
||||
|
||||
|
||||
def test_connect_sse_and_list_tools(monkeypatch):
|
||||
pytest.importorskip("mcp")
|
||||
sse_module = pytest.importorskip("mcp.client.sse")
|
||||
import mcp
|
||||
|
||||
contexts = []
|
||||
|
||||
class FakeSSEContext:
|
||||
def __init__(self, url: str, headers: dict[str, str] | None, timeout: float):
|
||||
self.url = url
|
||||
self.headers = headers
|
||||
self.timeout = timeout
|
||||
self.exited = False
|
||||
|
||||
async def __aenter__(self):
|
||||
return "read-stream", "write-stream"
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
self.exited = True
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, read_stream, write_stream):
|
||||
self.read_stream = read_stream
|
||||
self.write_stream = write_stream
|
||||
self.closed = False
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
self.closed = True
|
||||
|
||||
async def initialize(self):
|
||||
"""Pretend session initialization succeeded."""
|
||||
|
||||
async def list_tools(self):
|
||||
return SimpleNamespace(
|
||||
tools=[
|
||||
SimpleNamespace(
|
||||
name="search",
|
||||
description="Search docs",
|
||||
inputSchema={"type": "object"},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def fake_sse_client(url: str, headers=None, timeout=5, **_kwargs):
|
||||
context = FakeSSEContext(url=url, headers=headers, timeout=timeout)
|
||||
contexts.append(context)
|
||||
return context
|
||||
|
||||
monkeypatch.setattr(sse_module, "sse_client", fake_sse_client)
|
||||
monkeypatch.setattr(mcp, "ClientSession", FakeSession)
|
||||
|
||||
client = MCPClient(
|
||||
MCPServerConfig(
|
||||
name="sse-server",
|
||||
transport="sse",
|
||||
url="http://localhost/sse",
|
||||
headers={"Authorization": "Bearer token"},
|
||||
)
|
||||
)
|
||||
|
||||
client.connect()
|
||||
tools = client.list_tools()
|
||||
|
||||
assert [tool.name for tool in tools] == ["search"]
|
||||
assert tools[0].description == "Search docs"
|
||||
assert contexts[0].url == "http://localhost/sse"
|
||||
assert contexts[0].headers == {"Authorization": "Bearer token"}
|
||||
assert contexts[0].timeout == 30.0
|
||||
|
||||
client.disconnect()
|
||||
assert contexts[0].exited is True
|
||||
|
||||
|
||||
def test_call_tool_retries_once_on_connect_error_for_unix(monkeypatch):
|
||||
client = MCPClient(MCPServerConfig(name="unix-server", transport="unix"))
|
||||
client._connected = True # noqa: SLF001 - direct unit test
|
||||
client._tools = { # noqa: SLF001 - direct unit test
|
||||
"ping": MCPTool("ping", "Ping tool", {}, "unix-server")
|
||||
}
|
||||
|
||||
first_error = httpx.ConnectError("first failure")
|
||||
calls = {"count": 0}
|
||||
reconnects = []
|
||||
|
||||
def fake_call_tool_http(tool_name, arguments):
|
||||
calls["count"] += 1
|
||||
if calls["count"] == 1:
|
||||
raise first_error
|
||||
return [{"type": "text", "text": f"{tool_name}:{arguments['value']}"}]
|
||||
|
||||
monkeypatch.setattr(client, "_call_tool_http", fake_call_tool_http)
|
||||
monkeypatch.setattr(client, "_reconnect", lambda: reconnects.append("reconnected"))
|
||||
|
||||
result = client.call_tool("ping", {"value": "ok"})
|
||||
|
||||
assert result == [{"type": "text", "text": "ping:ok"}]
|
||||
assert calls["count"] == 2
|
||||
assert reconnects == ["reconnected"]
|
||||
|
||||
|
||||
def test_call_tool_retry_exhausted_raises_original_error_for_unix(monkeypatch):
|
||||
client = MCPClient(MCPServerConfig(name="unix-server", transport="unix"))
|
||||
client._connected = True # noqa: SLF001 - direct unit test
|
||||
client._tools = { # noqa: SLF001 - direct unit test
|
||||
"ping": MCPTool("ping", "Ping tool", {}, "unix-server")
|
||||
}
|
||||
|
||||
first_error = httpx.ConnectError("first failure")
|
||||
second_error = httpx.ConnectError("second failure")
|
||||
calls = {"count": 0}
|
||||
reconnects = []
|
||||
|
||||
def fake_call_tool_http(_tool_name, _arguments):
|
||||
calls["count"] += 1
|
||||
if calls["count"] == 1:
|
||||
raise first_error
|
||||
raise second_error
|
||||
|
||||
monkeypatch.setattr(client, "_call_tool_http", fake_call_tool_http)
|
||||
monkeypatch.setattr(client, "_reconnect", lambda: reconnects.append("reconnected"))
|
||||
|
||||
with pytest.raises(httpx.ConnectError) as exc_info:
|
||||
client.call_tool("ping", {"value": "ok"})
|
||||
|
||||
assert exc_info.value is first_error
|
||||
assert calls["count"] == 2
|
||||
assert reconnects == ["reconnected"]
|
||||
|
||||
|
||||
def test_call_tool_http_preserves_runtime_error_wrapping(monkeypatch):
|
||||
client = MCPClient(MCPServerConfig(name="http-server", transport="http"))
|
||||
client._connected = True # noqa: SLF001 - direct unit test
|
||||
client._tools = { # noqa: SLF001 - direct unit test
|
||||
"ping": MCPTool("ping", "Ping tool", {}, "http-server")
|
||||
}
|
||||
|
||||
connect_error = httpx.ConnectError("first failure")
|
||||
|
||||
class FailingHttpClient:
|
||||
def post(self, _path, json):
|
||||
raise connect_error
|
||||
|
||||
client._http_client = FailingHttpClient() # noqa: SLF001 - direct unit test
|
||||
reconnects = []
|
||||
monkeypatch.setattr(client, "_reconnect", lambda: reconnects.append("reconnected"))
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
client.call_tool("ping", {"value": "ok"})
|
||||
|
||||
assert "Failed to call tool via HTTP" in str(exc_info.value)
|
||||
assert exc_info.value.__cause__ is connect_error
|
||||
assert reconnects == []
|
||||
@@ -0,0 +1,172 @@
|
||||
"""Tests for the shared MCP connection manager."""
|
||||
|
||||
import threading
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from framework.runner.mcp_client import MCPServerConfig, MCPTool
|
||||
from framework.runner.mcp_connection_manager import MCPConnectionManager
|
||||
|
||||
|
||||
class FakeMCPClient:
|
||||
"""Minimal fake MCP client for connection manager tests."""
|
||||
|
||||
instances: list["FakeMCPClient"] = []
|
||||
|
||||
def __init__(self, config: MCPServerConfig):
|
||||
self.config = config
|
||||
self._connected = False
|
||||
self.connect_calls = 0
|
||||
self.disconnect_calls = 0
|
||||
self.list_tools_calls = 0
|
||||
self.list_tools_error: Exception | None = None
|
||||
FakeMCPClient.instances.append(self)
|
||||
|
||||
def connect(self) -> None:
|
||||
self.connect_calls += 1
|
||||
self._connected = True
|
||||
|
||||
def disconnect(self) -> None:
|
||||
self.disconnect_calls += 1
|
||||
self._connected = False
|
||||
|
||||
def list_tools(self) -> list[MCPTool]:
|
||||
self.list_tools_calls += 1
|
||||
if self.list_tools_error is not None:
|
||||
raise self.list_tools_error
|
||||
return [MCPTool("ping", "Ping", {"type": "object"}, self.config.name)]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager(monkeypatch):
|
||||
monkeypatch.setattr("framework.runner.mcp_connection_manager.MCPClient", FakeMCPClient)
|
||||
monkeypatch.setattr(MCPConnectionManager, "_instance", None)
|
||||
FakeMCPClient.instances.clear()
|
||||
manager = MCPConnectionManager.get_instance()
|
||||
yield manager
|
||||
manager.cleanup_all()
|
||||
monkeypatch.setattr(MCPConnectionManager, "_instance", None)
|
||||
FakeMCPClient.instances.clear()
|
||||
|
||||
|
||||
def test_acquire_returns_same_client_for_same_server_name(manager):
|
||||
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
|
||||
|
||||
client_one = manager.acquire(config)
|
||||
client_two = manager.acquire(config)
|
||||
|
||||
assert client_one is client_two
|
||||
assert manager._refcounts["shared"] == 2 # noqa: SLF001 - state assertion for unit test
|
||||
assert len(FakeMCPClient.instances) == 1
|
||||
|
||||
|
||||
def test_release_with_refcount_above_one_keeps_connection_open(manager):
|
||||
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
|
||||
client = manager.acquire(config)
|
||||
manager.acquire(config)
|
||||
|
||||
manager.release("shared")
|
||||
|
||||
assert client.disconnect_calls == 0
|
||||
assert manager._pool["shared"] is client # noqa: SLF001 - state assertion for unit test
|
||||
assert manager._refcounts["shared"] == 1 # noqa: SLF001 - state assertion for unit test
|
||||
|
||||
|
||||
def test_release_last_reference_disconnects_and_removes_from_pool(manager):
|
||||
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
|
||||
client = manager.acquire(config)
|
||||
|
||||
manager.release("shared")
|
||||
|
||||
assert client.disconnect_calls == 1
|
||||
assert "shared" not in manager._pool # noqa: SLF001 - state assertion for unit test
|
||||
assert "shared" not in manager._refcounts # noqa: SLF001 - state assertion for unit test
|
||||
|
||||
|
||||
def test_concurrent_acquire_and_release_keeps_state_consistent(manager):
|
||||
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
|
||||
worker_count = 8
|
||||
acquire_barrier = threading.Barrier(worker_count + 1)
|
||||
release_barrier = threading.Barrier(worker_count)
|
||||
acquired_clients: list[FakeMCPClient] = []
|
||||
acquired_lock = threading.Lock()
|
||||
|
||||
def worker() -> None:
|
||||
acquire_barrier.wait()
|
||||
client = manager.acquire(config)
|
||||
with acquired_lock:
|
||||
acquired_clients.append(client)
|
||||
release_barrier.wait()
|
||||
manager.release("shared")
|
||||
|
||||
threads = [threading.Thread(target=worker) for _ in range(worker_count)]
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
|
||||
acquire_barrier.wait()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert len({id(client) for client in acquired_clients}) == 1
|
||||
assert len(FakeMCPClient.instances) == 1
|
||||
assert FakeMCPClient.instances[0].disconnect_calls == 1
|
||||
assert manager._pool == {} # noqa: SLF001 - state assertion for unit test
|
||||
assert manager._refcounts == {} # noqa: SLF001 - state assertion for unit test
|
||||
|
||||
|
||||
def test_cleanup_all_disconnects_every_pooled_client(manager):
|
||||
manager.acquire(MCPServerConfig(name="one", transport="stdio", command="echo"))
|
||||
manager.acquire(MCPServerConfig(name="two", transport="stdio", command="echo"))
|
||||
|
||||
manager.cleanup_all()
|
||||
|
||||
assert len(FakeMCPClient.instances) == 2
|
||||
assert all(client.disconnect_calls == 1 for client in FakeMCPClient.instances)
|
||||
assert manager._pool == {} # noqa: SLF001 - state assertion for unit test
|
||||
assert manager._refcounts == {} # noqa: SLF001 - state assertion for unit test
|
||||
assert manager._configs == {} # noqa: SLF001 - state assertion for unit test
|
||||
|
||||
|
||||
def test_reconnect_replaces_client_even_with_existing_refcount(manager):
|
||||
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
|
||||
original_client = manager.acquire(config)
|
||||
manager.acquire(config)
|
||||
|
||||
replacement = manager.reconnect("shared")
|
||||
|
||||
assert replacement is not original_client
|
||||
assert original_client.disconnect_calls == 1
|
||||
assert manager._pool["shared"] is replacement # noqa: SLF001 - state assertion for unit test
|
||||
assert manager._refcounts["shared"] == 2 # noqa: SLF001 - state assertion for unit test
|
||||
|
||||
|
||||
def test_health_check_returns_false_when_server_is_unreachable(manager, monkeypatch):
|
||||
config = MCPServerConfig(name="shared", transport="http", url="http://localhost:9")
|
||||
manager.acquire(config)
|
||||
|
||||
class FailingHttpClient:
|
||||
def __init__(self, **_kwargs):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def get(self, _path: str):
|
||||
raise httpx.ConnectError("unreachable")
|
||||
|
||||
monkeypatch.setattr("framework.runner.mcp_connection_manager.httpx.Client", FailingHttpClient)
|
||||
|
||||
assert manager.health_check("shared") is False
|
||||
|
||||
|
||||
def test_health_check_for_stdio_returns_false_on_tools_list_error(manager):
|
||||
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
|
||||
client = manager.acquire(config)
|
||||
client.list_tools_error = RuntimeError("broken")
|
||||
|
||||
assert manager.health_check("shared") is False
|
||||
@@ -21,3 +21,8 @@ def test_minimax_provider_prefix_maps_to_minimax_api_key():
|
||||
def test_minimax_model_name_prefix_maps_to_minimax_api_key():
|
||||
runner = _runner_for_unit_test()
|
||||
assert runner._get_api_key_env_var("minimax-chat") == "MINIMAX_API_KEY"
|
||||
|
||||
|
||||
def test_openrouter_provider_prefix_maps_to_openrouter_api_key():
|
||||
runner = _runner_for_unit_test()
|
||||
assert runner._get_api_key_env_var("openrouter/x-ai/grok-4.20-beta") == "OPENROUTER_API_KEY"
|
||||
|
||||
@@ -0,0 +1,520 @@
|
||||
"""Tests for safe_eval — the sandboxed expression evaluator used by edge conditions.
|
||||
|
||||
Covers: literals, data structures, arithmetic, comparisons, boolean logic
|
||||
(including short-circuit semantics), variable lookup, subscript/attribute
|
||||
access, whitelisted function calls, method calls, ternary expressions,
|
||||
chained comparisons, and security boundaries (private attrs, disallowed
|
||||
AST nodes, disallowed function calls).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.safe_eval import safe_eval
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Literals and constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLiterals:
|
||||
def test_integer(self):
|
||||
assert safe_eval("42") == 42
|
||||
|
||||
def test_negative_integer(self):
|
||||
assert safe_eval("-1") == -1
|
||||
|
||||
def test_float(self):
|
||||
assert safe_eval("3.14") == pytest.approx(3.14)
|
||||
|
||||
def test_string(self):
|
||||
assert safe_eval("'hello'") == "hello"
|
||||
|
||||
def test_double_quoted_string(self):
|
||||
assert safe_eval('"world"') == "world"
|
||||
|
||||
def test_boolean_true(self):
|
||||
assert safe_eval("True") is True
|
||||
|
||||
def test_boolean_false(self):
|
||||
assert safe_eval("False") is False
|
||||
|
||||
def test_none(self):
|
||||
assert safe_eval("None") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data structures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDataStructures:
|
||||
def test_list(self):
|
||||
assert safe_eval("[1, 2, 3]") == [1, 2, 3]
|
||||
|
||||
def test_empty_list(self):
|
||||
assert safe_eval("[]") == []
|
||||
|
||||
def test_nested_list(self):
|
||||
assert safe_eval("[[1, 2], [3, 4]]") == [[1, 2], [3, 4]]
|
||||
|
||||
def test_tuple(self):
|
||||
assert safe_eval("(1, 2, 3)") == (1, 2, 3)
|
||||
|
||||
def test_dict(self):
|
||||
assert safe_eval("{'a': 1, 'b': 2}") == {"a": 1, "b": 2}
|
||||
|
||||
def test_empty_dict(self):
|
||||
assert safe_eval("{}") == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Arithmetic and binary operators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestArithmetic:
|
||||
def test_addition(self):
|
||||
assert safe_eval("2 + 3") == 5
|
||||
|
||||
def test_subtraction(self):
|
||||
assert safe_eval("10 - 4") == 6
|
||||
|
||||
def test_multiplication(self):
|
||||
assert safe_eval("3 * 7") == 21
|
||||
|
||||
def test_division(self):
|
||||
assert safe_eval("10 / 4") == 2.5
|
||||
|
||||
def test_floor_division(self):
|
||||
assert safe_eval("10 // 3") == 3
|
||||
|
||||
def test_modulo(self):
|
||||
assert safe_eval("10 % 3") == 1
|
||||
|
||||
def test_power(self):
|
||||
assert safe_eval("2 ** 10") == 1024
|
||||
|
||||
def test_complex_expression(self):
|
||||
assert safe_eval("(2 + 3) * 4 - 1") == 19
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unary operators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUnaryOps:
|
||||
def test_negation(self):
|
||||
assert safe_eval("-5") == -5
|
||||
|
||||
def test_positive(self):
|
||||
assert safe_eval("+5") == 5
|
||||
|
||||
def test_not_true(self):
|
||||
assert safe_eval("not True") is False
|
||||
|
||||
def test_not_false(self):
|
||||
assert safe_eval("not False") is True
|
||||
|
||||
def test_bitwise_invert(self):
|
||||
assert safe_eval("~0") == -1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Comparisons
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComparisons:
|
||||
def test_equal(self):
|
||||
assert safe_eval("1 == 1") is True
|
||||
|
||||
def test_not_equal(self):
|
||||
assert safe_eval("1 != 2") is True
|
||||
|
||||
def test_less_than(self):
|
||||
assert safe_eval("1 < 2") is True
|
||||
|
||||
def test_greater_than(self):
|
||||
assert safe_eval("2 > 1") is True
|
||||
|
||||
def test_less_equal(self):
|
||||
assert safe_eval("2 <= 2") is True
|
||||
|
||||
def test_greater_equal(self):
|
||||
assert safe_eval("3 >= 2") is True
|
||||
|
||||
def test_is_none(self):
|
||||
assert safe_eval("x is None", {"x": None}) is True
|
||||
|
||||
def test_is_not_none(self):
|
||||
assert safe_eval("x is not None", {"x": 42}) is True
|
||||
|
||||
def test_in_list(self):
|
||||
assert safe_eval("'a' in x", {"x": ["a", "b", "c"]}) is True
|
||||
|
||||
def test_not_in_list(self):
|
||||
assert safe_eval("'z' not in x", {"x": ["a", "b"]}) is True
|
||||
|
||||
def test_chained_comparison(self):
|
||||
"""Chained comparisons like 1 < x < 10 should work."""
|
||||
assert safe_eval("1 < x < 10", {"x": 5}) is True
|
||||
|
||||
def test_chained_comparison_false(self):
|
||||
assert safe_eval("1 < x < 3", {"x": 5}) is False
|
||||
|
||||
def test_chained_three_way(self):
|
||||
assert safe_eval("0 <= x <= 100", {"x": 50}) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Boolean operators (with short-circuit semantics)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBooleanOps:
|
||||
def test_and_true(self):
|
||||
assert safe_eval("True and True") is True
|
||||
|
||||
def test_and_false(self):
|
||||
assert safe_eval("True and False") is False
|
||||
|
||||
def test_or_true(self):
|
||||
assert safe_eval("False or True") is True
|
||||
|
||||
def test_or_false(self):
|
||||
assert safe_eval("False or False") is False
|
||||
|
||||
def test_and_returns_last_truthy(self):
|
||||
"""Python `and` returns the last value if all truthy."""
|
||||
assert safe_eval("1 and 2 and 3") == 3
|
||||
|
||||
def test_and_returns_first_falsy(self):
|
||||
"""Python `and` returns the first falsy value."""
|
||||
assert safe_eval("1 and 0 and 3") == 0
|
||||
|
||||
def test_or_returns_first_truthy(self):
|
||||
"""Python `or` returns the first truthy value."""
|
||||
assert safe_eval("0 or '' or 42") == 42
|
||||
|
||||
def test_or_returns_last_falsy(self):
|
||||
"""Python `or` returns the last value if all falsy."""
|
||||
assert safe_eval("0 or '' or None") is None
|
||||
|
||||
def test_and_short_circuits(self):
|
||||
"""and should NOT evaluate the right side if left is falsy.
|
||||
|
||||
This is the bug we fixed — previously this would crash with
|
||||
TypeError because all operands were eagerly evaluated.
|
||||
"""
|
||||
# x is None, so `x.get("key")` would crash if evaluated
|
||||
assert safe_eval("x is not None and x.get('key')", {"x": None}) is False
|
||||
|
||||
def test_or_short_circuits(self):
|
||||
"""or should NOT evaluate the right side if left is truthy."""
|
||||
# x is truthy, so the crash-prone right side should never run
|
||||
assert safe_eval("x or y.get('missing')", {"x": "found", "y": {}}) == "found"
|
||||
|
||||
def test_and_guard_pattern_truthy(self):
|
||||
"""Guard pattern: check not None, then access — when value exists."""
|
||||
ctx = {"x": {"key": "value"}}
|
||||
assert safe_eval("x is not None and x.get('key')", ctx) == "value"
|
||||
|
||||
def test_multi_and(self):
|
||||
assert safe_eval("True and True and True") is True
|
||||
|
||||
def test_multi_or(self):
|
||||
assert safe_eval("False or False or True") is True
|
||||
|
||||
def test_mixed_and_or(self):
|
||||
assert safe_eval("True or False and False") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ternary (if/else) expressions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTernary:
|
||||
def test_ternary_true_branch(self):
|
||||
assert safe_eval("'yes' if True else 'no'") == "yes"
|
||||
|
||||
def test_ternary_false_branch(self):
|
||||
assert safe_eval("'yes' if False else 'no'") == "no"
|
||||
|
||||
def test_ternary_with_context(self):
|
||||
assert safe_eval("x * 2 if x > 0 else -x", {"x": 5}) == 10
|
||||
|
||||
def test_ternary_false_with_context(self):
|
||||
assert safe_eval("x * 2 if x > 0 else -x", {"x": -3}) == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Variable lookup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVariables:
|
||||
def test_simple_variable(self):
|
||||
assert safe_eval("x", {"x": 42}) == 42
|
||||
|
||||
def test_string_variable(self):
|
||||
assert safe_eval("name", {"name": "Alice"}) == "Alice"
|
||||
|
||||
def test_dict_variable(self):
|
||||
ctx = {"output": {"status": "ok"}}
|
||||
assert safe_eval("output", ctx) == {"status": "ok"}
|
||||
|
||||
def test_undefined_variable_raises(self):
|
||||
with pytest.raises(NameError, match="not defined"):
|
||||
safe_eval("undefined_var")
|
||||
|
||||
def test_multiple_variables(self):
|
||||
assert safe_eval("x + y", {"x": 10, "y": 20}) == 30
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subscript access (indexing)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubscript:
|
||||
def test_dict_subscript(self):
|
||||
assert safe_eval("d['key']", {"d": {"key": "value"}}) == "value"
|
||||
|
||||
def test_list_subscript(self):
|
||||
assert safe_eval("items[0]", {"items": [10, 20, 30]}) == 10
|
||||
|
||||
def test_nested_subscript(self):
|
||||
ctx = {"data": {"users": [{"name": "Alice"}]}}
|
||||
assert safe_eval("data['users'][0]['name']", ctx) == "Alice"
|
||||
|
||||
def test_missing_key_raises(self):
|
||||
with pytest.raises(KeyError):
|
||||
safe_eval("d['missing']", {"d": {}})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attribute access
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAttributeAccess:
|
||||
def test_private_attr_blocked(self):
|
||||
"""Attributes starting with _ must be blocked for security."""
|
||||
with pytest.raises(ValueError, match="private attribute"):
|
||||
safe_eval("x.__class__", {"x": 42})
|
||||
|
||||
def test_dunder_blocked(self):
|
||||
with pytest.raises(ValueError, match="private attribute"):
|
||||
safe_eval("x.__dict__", {"x": {}})
|
||||
|
||||
def test_single_underscore_blocked(self):
|
||||
with pytest.raises(ValueError, match="private attribute"):
|
||||
safe_eval("x._internal", {"x": {}})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Whitelisted function calls
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFunctionCalls:
|
||||
def test_len(self):
|
||||
assert safe_eval("len(x)", {"x": [1, 2, 3]}) == 3
|
||||
|
||||
def test_int_conversion(self):
|
||||
assert safe_eval("int('42')") == 42
|
||||
|
||||
def test_float_conversion(self):
|
||||
assert safe_eval("float('3.14')") == pytest.approx(3.14)
|
||||
|
||||
def test_str_conversion(self):
|
||||
assert safe_eval("str(42)") == "42"
|
||||
|
||||
def test_bool_conversion(self):
|
||||
assert safe_eval("bool(1)") is True
|
||||
|
||||
def test_abs(self):
|
||||
assert safe_eval("abs(-5)") == 5
|
||||
|
||||
def test_min(self):
|
||||
assert safe_eval("min(3, 1, 2)") == 1
|
||||
|
||||
def test_max(self):
|
||||
assert safe_eval("max(3, 1, 2)") == 3
|
||||
|
||||
def test_sum(self):
|
||||
assert safe_eval("sum(x)", {"x": [1, 2, 3]}) == 6
|
||||
|
||||
def test_round(self):
|
||||
assert safe_eval("round(3.7)") == 4
|
||||
|
||||
def test_all(self):
|
||||
assert safe_eval("all([True, True, True])") is True
|
||||
|
||||
def test_any(self):
|
||||
assert safe_eval("any([False, False, True])") is True
|
||||
|
||||
def test_list_constructor(self):
|
||||
assert safe_eval("list(x)", {"x": (1, 2, 3)}) == [1, 2, 3]
|
||||
|
||||
def test_dict_constructor(self):
|
||||
assert safe_eval("dict(a=1, b=2)") == {"a": 1, "b": 2}
|
||||
|
||||
def test_tuple_constructor(self):
|
||||
assert safe_eval("tuple(x)", {"x": [1, 2]}) == (1, 2)
|
||||
|
||||
def test_set_constructor(self):
|
||||
assert safe_eval("set(x)", {"x": [1, 2, 2, 3]}) == {1, 2, 3}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Whitelisted method calls
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMethodCalls:
|
||||
def test_dict_get(self):
|
||||
assert safe_eval("d.get('key', 'default')", {"d": {"key": "val"}}) == "val"
|
||||
|
||||
def test_dict_get_missing(self):
|
||||
assert safe_eval("d.get('missing', 'default')", {"d": {}}) == "default"
|
||||
|
||||
def test_dict_keys(self):
|
||||
result = safe_eval("list(d.keys())", {"d": {"a": 1, "b": 2}})
|
||||
assert sorted(result) == ["a", "b"]
|
||||
|
||||
def test_dict_values(self):
|
||||
result = safe_eval("list(d.values())", {"d": {"a": 1, "b": 2}})
|
||||
assert sorted(result) == [1, 2]
|
||||
|
||||
def test_string_lower(self):
|
||||
assert safe_eval("s.lower()", {"s": "HELLO"}) == "hello"
|
||||
|
||||
def test_string_upper(self):
|
||||
assert safe_eval("s.upper()", {"s": "hello"}) == "HELLO"
|
||||
|
||||
def test_string_strip(self):
|
||||
assert safe_eval("s.strip()", {"s": " hi "}) == "hi"
|
||||
|
||||
def test_string_split(self):
|
||||
assert safe_eval("s.split(',')", {"s": "a,b,c"}) == ["a", "b", "c"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Security: disallowed operations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSecurity:
|
||||
def test_import_blocked(self):
|
||||
"""__import__ is not in context, so NameError is raised."""
|
||||
with pytest.raises(NameError, match="not defined"):
|
||||
safe_eval("__import__('os')")
|
||||
|
||||
def test_lambda_blocked(self):
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
safe_eval("(lambda: 1)()")
|
||||
|
||||
def test_comprehension_blocked(self):
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
safe_eval("[x for x in range(10)]")
|
||||
|
||||
def test_assignment_blocked(self):
|
||||
"""Assignment expressions should not parse in eval mode."""
|
||||
with pytest.raises(SyntaxError):
|
||||
safe_eval("x = 5")
|
||||
|
||||
def test_disallowed_function_blocked(self):
|
||||
"""eval is not in safe functions, so NameError is raised."""
|
||||
with pytest.raises(NameError, match="not defined"):
|
||||
safe_eval("eval('1+1')")
|
||||
|
||||
def test_exec_blocked(self):
|
||||
"""exec is not in safe functions, so NameError is raised."""
|
||||
with pytest.raises(NameError, match="not defined"):
|
||||
safe_eval("exec('x=1')")
|
||||
|
||||
def test_type_call_blocked(self):
|
||||
"""type is not in safe functions, so NameError is raised."""
|
||||
with pytest.raises(NameError, match="not defined"):
|
||||
safe_eval("type(42)")
|
||||
|
||||
def test_getattr_builtin_blocked(self):
|
||||
"""getattr is not in safe functions, so NameError is raised."""
|
||||
with pytest.raises(NameError, match="not defined"):
|
||||
safe_eval("getattr(x, '__class__')", {"x": 42})
|
||||
|
||||
def test_empty_expression_raises(self):
|
||||
with pytest.raises(SyntaxError):
|
||||
safe_eval("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Real-world edge condition patterns (from graph executor usage)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEdgeConditionPatterns:
|
||||
"""Patterns commonly used in EdgeSpec.condition_expr."""
|
||||
|
||||
def test_output_key_exists_and_not_none(self):
|
||||
ctx = {"output": {"approved_contacts": ["alice@example.com"]}}
|
||||
assert safe_eval("output.get('approved_contacts') is not None", ctx) is True
|
||||
|
||||
def test_output_key_missing(self):
|
||||
ctx = {"output": {}}
|
||||
assert safe_eval("output.get('approved_contacts') is not None", ctx) is False
|
||||
|
||||
def test_output_key_check_with_fallback(self):
|
||||
ctx = {"output": {"redo_extraction": True}}
|
||||
assert safe_eval("output.get('redo_extraction') is not None", ctx) is True
|
||||
|
||||
def test_guard_then_length_check(self):
|
||||
"""Guard pattern: check key exists, then check length."""
|
||||
ctx = {"output": {"results": [1, 2, 3]}}
|
||||
assert (
|
||||
safe_eval(
|
||||
"output.get('results') is not None and len(output['results']) > 0",
|
||||
ctx,
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_guard_short_circuits_on_none(self):
|
||||
"""Guard pattern: short-circuit prevents crash on None."""
|
||||
ctx = {"output": {}}
|
||||
assert (
|
||||
safe_eval(
|
||||
"output.get('results') is not None and len(output['results']) > 0",
|
||||
ctx,
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_success_flag_check(self):
|
||||
ctx = {"output": {"success": True}, "memory": {"attempts": 2}}
|
||||
assert safe_eval("output.get('success') == True", ctx) is True
|
||||
|
||||
def test_memory_threshold(self):
|
||||
ctx = {"memory": {"score": 0.85}}
|
||||
assert safe_eval("memory.get('score', 0) >= 0.8", ctx) is True
|
||||
|
||||
def test_string_contains_check(self):
|
||||
ctx = {"output": {"status": "completed_with_warnings"}}
|
||||
assert safe_eval("'completed' in output.get('status', '')", ctx) is True
|
||||
|
||||
def test_fallback_chain(self):
|
||||
"""or-chain for fallback values."""
|
||||
ctx = {"output": {}}
|
||||
result = safe_eval(
|
||||
"output.get('primary') or output.get('secondary') or 'default'",
|
||||
ctx,
|
||||
)
|
||||
assert result == "default"
|
||||
|
||||
def test_no_context_needed(self):
|
||||
"""Some edges use constant expressions."""
|
||||
assert safe_eval("True") is True
|
||||
assert safe_eval("1 == 1") is True
|
||||
@@ -0,0 +1,142 @@
|
||||
"""Tests for AS-9: Skill directory allowlisting in file-read tool interception."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.llm.provider import ToolResult
|
||||
|
||||
|
||||
def _make_tool_call_event(tool_name: str, path: str):
|
||||
"""Build a minimal ToolCallEvent-like object."""
|
||||
tc = MagicMock()
|
||||
tc.tool_use_id = "tc-1"
|
||||
tc.tool_name = tool_name
|
||||
tc.tool_input = {"path": path}
|
||||
return tc
|
||||
|
||||
|
||||
def _make_node(skill_dirs: list[str]):
|
||||
"""Build a minimal EventLoopNode with skill_dirs set."""
|
||||
from framework.graph.event_loop_node import EventLoopNode
|
||||
|
||||
mock_result = ToolResult(tool_use_id="tc-1", content="from-executor")
|
||||
node = EventLoopNode(tool_executor=MagicMock(return_value=mock_result))
|
||||
node._skill_dirs = skill_dirs
|
||||
return node
|
||||
|
||||
|
||||
class TestSkillFileReadInterception:
|
||||
@pytest.mark.asyncio
|
||||
async def test_reads_file_in_skill_dir(self, tmp_path):
|
||||
"""File under a skill dir is read directly, bypassing the executor."""
|
||||
skill_dir = tmp_path / "my-skill"
|
||||
skill_dir.mkdir()
|
||||
script = skill_dir / "scripts" / "run.py"
|
||||
script.parent.mkdir()
|
||||
script.write_text("print('hello')")
|
||||
|
||||
node = _make_node([str(skill_dir)])
|
||||
tc = _make_tool_call_event("view_file", str(script))
|
||||
|
||||
result = await node._execute_tool(tc)
|
||||
|
||||
assert result.content == "print('hello')"
|
||||
assert not result.is_error
|
||||
node._tool_executor.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skill_md_read_marked_as_skill_content(self, tmp_path):
|
||||
"""Reading SKILL.md sets is_skill_content=True for AS-10 protection."""
|
||||
skill_dir = tmp_path / "my-skill"
|
||||
skill_dir.mkdir()
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
skill_md.write_text("---\nname: my-skill\ndescription: Test\n---\nInstructions.")
|
||||
|
||||
node = _make_node([str(skill_dir)])
|
||||
tc = _make_tool_call_event("view_file", str(skill_md))
|
||||
|
||||
result = await node._execute_tool(tc)
|
||||
|
||||
assert result.is_skill_content is True
|
||||
assert not result.is_error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_skill_md_resource_not_marked(self, tmp_path):
|
||||
"""Bundled resource (not SKILL.md) is NOT marked as skill_content."""
|
||||
skill_dir = tmp_path / "my-skill"
|
||||
skill_dir.mkdir()
|
||||
ref = skill_dir / "references" / "api.md"
|
||||
ref.parent.mkdir()
|
||||
ref.write_text("# API Reference")
|
||||
|
||||
node = _make_node([str(skill_dir)])
|
||||
tc = _make_tool_call_event("load_data", str(ref))
|
||||
|
||||
result = await node._execute_tool(tc)
|
||||
|
||||
assert result.is_skill_content is False
|
||||
assert not result.is_error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_path_outside_skill_dir_goes_to_executor(self, tmp_path):
|
||||
"""Path outside skill dirs is passed through to the executor unchanged."""
|
||||
skill_dir = tmp_path / "my-skill"
|
||||
skill_dir.mkdir()
|
||||
other_file = tmp_path / "other" / "file.txt"
|
||||
other_file.parent.mkdir()
|
||||
other_file.write_text("other content")
|
||||
|
||||
node = _make_node([str(skill_dir)])
|
||||
tc = _make_tool_call_event("view_file", str(other_file))
|
||||
|
||||
result = await node._execute_tool(tc)
|
||||
|
||||
assert result.content == "from-executor"
|
||||
node._tool_executor.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_skill_dirs_goes_to_executor(self, tmp_path):
|
||||
"""When skill_dirs is empty, all tool calls go to executor."""
|
||||
skill_dir = tmp_path / "my-skill"
|
||||
skill_dir.mkdir()
|
||||
script = skill_dir / "scripts" / "run.py"
|
||||
script.parent.mkdir()
|
||||
script.write_text("print('hello')")
|
||||
|
||||
node = _make_node([])
|
||||
tc = _make_tool_call_event("view_file", str(script))
|
||||
|
||||
result = await node._execute_tool(tc)
|
||||
|
||||
assert result.content == "from-executor"
|
||||
node._tool_executor.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_file_returns_error(self, tmp_path):
|
||||
"""Non-existent file under skill dir returns is_error=True."""
|
||||
skill_dir = tmp_path / "my-skill"
|
||||
skill_dir.mkdir()
|
||||
missing = skill_dir / "scripts" / "missing.py"
|
||||
|
||||
node = _make_node([str(skill_dir)])
|
||||
tc = _make_tool_call_event("view_file", str(missing))
|
||||
|
||||
result = await node._execute_tool(tc)
|
||||
|
||||
assert result.is_error is True
|
||||
assert "Could not read skill resource" in result.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_file_read_tool_goes_to_executor(self, tmp_path):
|
||||
"""Non file-read tools (e.g. web_search) bypass the interceptor."""
|
||||
skill_dir = tmp_path / "my-skill"
|
||||
skill_dir.mkdir()
|
||||
|
||||
node = _make_node([str(skill_dir)])
|
||||
tc = _make_tool_call_event("web_search", str(skill_dir / "SKILL.md"))
|
||||
|
||||
result = await node._execute_tool(tc)
|
||||
|
||||
assert result.content == "from-executor"
|
||||
node._tool_executor.assert_called_once()
|
||||
@@ -69,7 +69,13 @@ class TestSkillCatalog:
|
||||
|
||||
def test_to_prompt_xml_generation(self):
|
||||
skills = [
|
||||
_make_skill("alpha", "Alpha skill", "project", location="/p/alpha/SKILL.md"),
|
||||
_make_skill(
|
||||
"alpha",
|
||||
"Alpha skill",
|
||||
"project",
|
||||
location="/p/alpha/SKILL.md",
|
||||
base_dir="/p/alpha",
|
||||
),
|
||||
_make_skill("beta", "Beta skill", "user", location="/u/beta/SKILL.md"),
|
||||
]
|
||||
catalog = SkillCatalog(skills)
|
||||
@@ -81,6 +87,7 @@ class TestSkillCatalog:
|
||||
assert "<name>beta</name>" in prompt
|
||||
assert "<description>Alpha skill</description>" in prompt
|
||||
assert "<location>/p/alpha/SKILL.md</location>" in prompt
|
||||
assert "<base_dir>/p/alpha</base_dir>" in prompt
|
||||
|
||||
def test_to_prompt_sorted_by_name(self):
|
||||
skills = [
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
"""Tests for AS-10: Activated skill content protected from context pruning."""
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.conversation import Message, NodeConversation
|
||||
|
||||
|
||||
def _make_conversation() -> NodeConversation:
|
||||
conv = NodeConversation.__new__(NodeConversation)
|
||||
conv._messages = []
|
||||
conv._next_seq = 0
|
||||
conv._current_phase = None
|
||||
conv._store = None
|
||||
return conv
|
||||
|
||||
|
||||
async def _add_tool_msg(conv: NodeConversation, content: str, **kwargs) -> Message:
|
||||
return await conv.add_tool_result(
|
||||
tool_use_id=f"tc-{conv._next_seq}",
|
||||
content=content,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class TestSkillContentProtection:
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_skill_content_flag_persists(self):
|
||||
"""Message created with is_skill_content=True retains the flag."""
|
||||
conv = _make_conversation()
|
||||
msg = await _add_tool_msg(conv, "skill instructions", is_skill_content=True)
|
||||
assert msg.is_skill_content is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_message_not_marked(self):
|
||||
"""Normal tool result messages are not marked as skill content."""
|
||||
conv = _make_conversation()
|
||||
msg = await _add_tool_msg(conv, "some tool output")
|
||||
assert msg.is_skill_content is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skill_content_survives_prune(self):
|
||||
"""Skill content messages are skipped by prune_old_tool_results."""
|
||||
conv = _make_conversation()
|
||||
|
||||
# Add many regular tool results to push over prune threshold
|
||||
for _ in range(30):
|
||||
await _add_tool_msg(conv, "x" * 500) # ~125 tokens each
|
||||
|
||||
# Add a skill content message
|
||||
skill_msg = await _add_tool_msg(
|
||||
conv,
|
||||
"## Deep Research\n" + "instructions " * 200,
|
||||
is_skill_content=True,
|
||||
)
|
||||
|
||||
pruned = await conv.prune_old_tool_results(protect_tokens=500, min_prune_tokens=100)
|
||||
|
||||
assert pruned > 0, "Expected some messages to be pruned"
|
||||
# Find the skill message — it must not be pruned
|
||||
matching = [m for m in conv._messages if m.seq == skill_msg.seq]
|
||||
assert matching, "Skill content message was removed"
|
||||
assert not matching[0].content.startswith("[Pruned tool result")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_content_can_be_pruned(self):
|
||||
"""Regular tool results are still pruned when over threshold."""
|
||||
conv = _make_conversation()
|
||||
|
||||
for _ in range(20):
|
||||
await _add_tool_msg(conv, "regular tool output " * 50)
|
||||
|
||||
pruned = await conv.prune_old_tool_results(protect_tokens=500, min_prune_tokens=100)
|
||||
|
||||
assert pruned > 0, "Expected regular messages to be pruned"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_messages_also_protected(self):
|
||||
"""Existing is_error protection still works alongside is_skill_content."""
|
||||
conv = _make_conversation()
|
||||
|
||||
for _ in range(20):
|
||||
await _add_tool_msg(conv, "output " * 100)
|
||||
|
||||
err_msg = await _add_tool_msg(conv, "tool failed", is_error=True)
|
||||
|
||||
await conv.prune_old_tool_results(protect_tokens=200, min_prune_tokens=50)
|
||||
|
||||
matching = [m for m in conv._messages if m.seq == err_msg.seq]
|
||||
assert matching
|
||||
assert not matching[0].content.startswith("[Pruned tool result")
|
||||
@@ -0,0 +1,151 @@
|
||||
"""Tests for skill system structured error codes and diagnostics."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from framework.skills.skill_errors import (
|
||||
SkillError,
|
||||
SkillErrorCode,
|
||||
log_skill_error,
|
||||
)
|
||||
|
||||
|
||||
class TestSkillErrorCode:
|
||||
def test_all_codes_defined(self):
|
||||
codes = {e.value for e in SkillErrorCode}
|
||||
assert "SKILL_NOT_FOUND" in codes
|
||||
assert "SKILL_PARSE_ERROR" in codes
|
||||
assert "SKILL_ACTIVATION_FAILED" in codes
|
||||
assert "SKILL_MISSING_DESCRIPTION" in codes
|
||||
assert "SKILL_YAML_FIXUP" in codes
|
||||
assert "SKILL_NAME_MISMATCH" in codes
|
||||
assert "SKILL_COLLISION" in codes
|
||||
|
||||
|
||||
class TestSkillError:
|
||||
def test_code_stored(self):
|
||||
err = SkillError(
|
||||
code=SkillErrorCode.SKILL_NOT_FOUND,
|
||||
what="Skill 'my-skill' not found",
|
||||
why="Not in catalog",
|
||||
fix="Check discovery paths",
|
||||
)
|
||||
assert err.code == SkillErrorCode.SKILL_NOT_FOUND
|
||||
|
||||
def test_message_format(self):
|
||||
err = SkillError(
|
||||
code=SkillErrorCode.SKILL_MISSING_DESCRIPTION,
|
||||
what="Missing description in '/path/SKILL.md'",
|
||||
why="The description field is absent",
|
||||
fix="Add a description field to the frontmatter",
|
||||
)
|
||||
expected = (
|
||||
"[SKILL_MISSING_DESCRIPTION]\n"
|
||||
"What failed: Missing description in '/path/SKILL.md'\n"
|
||||
"Why: The description field is absent\n"
|
||||
"Fix: Add a description field to the frontmatter"
|
||||
)
|
||||
assert str(err) == expected
|
||||
|
||||
def test_is_exception(self):
|
||||
err = SkillError(
|
||||
code=SkillErrorCode.SKILL_PARSE_ERROR,
|
||||
what="Parse failed",
|
||||
why="Invalid YAML",
|
||||
fix="Fix the YAML",
|
||||
)
|
||||
assert isinstance(err, Exception)
|
||||
|
||||
def test_what_why_fix_attributes(self):
|
||||
err = SkillError(
|
||||
code=SkillErrorCode.SKILL_COLLISION,
|
||||
what="Name collision",
|
||||
why="Two skills share the same name",
|
||||
fix="Rename one skill directory",
|
||||
)
|
||||
assert err.what == "Name collision"
|
||||
assert err.why == "Two skills share the same name"
|
||||
assert err.fix == "Rename one skill directory"
|
||||
|
||||
|
||||
class TestLogSkillError:
|
||||
def test_emits_log(self, caplog):
|
||||
test_logger = logging.getLogger("test_skill")
|
||||
with caplog.at_level(logging.ERROR, logger="test_skill"):
|
||||
log_skill_error(
|
||||
test_logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_PARSE_ERROR,
|
||||
what="Invalid SKILL.md at '/path'",
|
||||
why="Empty file",
|
||||
fix="Add content",
|
||||
)
|
||||
assert "SKILL_PARSE_ERROR" in caplog.text
|
||||
|
||||
def test_warning_level(self, caplog):
|
||||
test_logger = logging.getLogger("test_skill_warn")
|
||||
with caplog.at_level(logging.WARNING, logger="test_skill_warn"):
|
||||
log_skill_error(
|
||||
test_logger,
|
||||
"warning",
|
||||
SkillErrorCode.SKILL_YAML_FIXUP,
|
||||
what="Auto-fixed YAML",
|
||||
why="Unquoted colons",
|
||||
fix="Quote values",
|
||||
)
|
||||
assert "SKILL_YAML_FIXUP" in caplog.text
|
||||
|
||||
def test_message_contains_all_parts(self, caplog):
|
||||
test_logger = logging.getLogger("test_skill_parts")
|
||||
with caplog.at_level(logging.ERROR, logger="test_skill_parts"):
|
||||
log_skill_error(
|
||||
test_logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_NOT_FOUND,
|
||||
what="Skill not found",
|
||||
why="Not discovered",
|
||||
fix="Check paths",
|
||||
)
|
||||
assert "Skill not found" in caplog.text
|
||||
assert "Not discovered" in caplog.text
|
||||
assert "Check paths" in caplog.text
|
||||
|
||||
|
||||
class TestSkillErrorInParser:
|
||||
def test_missing_description_returns_none(self, tmp_path):
|
||||
from framework.skills.parser import parse_skill_md
|
||||
|
||||
skill_dir = tmp_path / "no-desc"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text("---\nname: no-desc\n---\nBody.\n", encoding="utf-8")
|
||||
result = parse_skill_md(skill_dir / "SKILL.md")
|
||||
assert result is None
|
||||
|
||||
def test_empty_file_returns_none(self, tmp_path):
|
||||
from framework.skills.parser import parse_skill_md
|
||||
|
||||
skill_dir = tmp_path / "empty"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text("", encoding="utf-8")
|
||||
result = parse_skill_md(skill_dir / "SKILL.md")
|
||||
assert result is None
|
||||
|
||||
def test_nonexistent_returns_none(self, tmp_path):
|
||||
from framework.skills.parser import parse_skill_md
|
||||
|
||||
result = parse_skill_md(tmp_path / "ghost" / "SKILL.md")
|
||||
assert result is None
|
||||
|
||||
def test_yaml_fixup_still_parses(self, tmp_path):
|
||||
from framework.skills.parser import parse_skill_md
|
||||
|
||||
skill_dir = tmp_path / "colon-test"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\nname: colon-test\ndescription: Use for: research\n---\nBody.\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
result = parse_skill_md(skill_dir / "SKILL.md")
|
||||
assert result is not None
|
||||
assert "research" in result.description
|
||||
@@ -0,0 +1,92 @@
|
||||
"""Tests for AS-6 skill resource loading support.
|
||||
|
||||
Covers:
|
||||
- <base_dir> element in catalog XML
|
||||
- allowlisted_dirs property reflects trusted skill base directories
|
||||
- skill_dirs propagation to NodeContext
|
||||
"""
|
||||
|
||||
from framework.skills.catalog import SkillCatalog
|
||||
from framework.skills.parser import ParsedSkill
|
||||
|
||||
|
||||
def _make_skill(
|
||||
name: str,
|
||||
base_dir: str,
|
||||
source_scope: str = "project",
|
||||
) -> ParsedSkill:
|
||||
return ParsedSkill(
|
||||
name=name,
|
||||
description=f"Skill {name}",
|
||||
location=f"{base_dir}/SKILL.md",
|
||||
base_dir=base_dir,
|
||||
source_scope=source_scope,
|
||||
body="Instructions.",
|
||||
)
|
||||
|
||||
|
||||
class TestSkillResourceBaseDir:
|
||||
def test_base_dir_in_xml(self):
|
||||
"""Each community skill entry should expose its base_dir in the catalog XML."""
|
||||
skill = _make_skill("deploy", "/project/.hive/skills/deploy")
|
||||
catalog = SkillCatalog([skill])
|
||||
prompt = catalog.to_prompt()
|
||||
|
||||
assert "<base_dir>/project/.hive/skills/deploy</base_dir>" in prompt
|
||||
|
||||
def test_base_dir_xml_escaped(self):
|
||||
"""base_dir with XML-special chars should be escaped."""
|
||||
skill = _make_skill("s", "/path/with <&> chars")
|
||||
catalog = SkillCatalog([skill])
|
||||
prompt = catalog.to_prompt()
|
||||
|
||||
assert "<base_dir>/path/with <&> chars</base_dir>" in prompt
|
||||
|
||||
def test_base_dir_absent_for_framework_skills(self):
|
||||
"""Framework-scope skills are filtered from the catalog, so no base_dir either."""
|
||||
skill = _make_skill("fw", "/hive/_default_skills/fw", source_scope="framework")
|
||||
catalog = SkillCatalog([skill])
|
||||
assert catalog.to_prompt() == ""
|
||||
|
||||
def test_allowlisted_dirs_matches_skills(self):
|
||||
"""allowlisted_dirs returns all skill base_dirs including framework ones."""
|
||||
skills = [
|
||||
_make_skill("a", "/skills/a", "project"),
|
||||
_make_skill("b", "/skills/b", "user"),
|
||||
_make_skill("c", "/skills/c", "framework"),
|
||||
]
|
||||
catalog = SkillCatalog(skills)
|
||||
dirs = catalog.allowlisted_dirs
|
||||
|
||||
assert "/skills/a" in dirs
|
||||
assert "/skills/b" in dirs
|
||||
assert "/skills/c" in dirs
|
||||
|
||||
def test_allowlisted_dirs_empty_catalog(self):
|
||||
assert SkillCatalog().allowlisted_dirs == []
|
||||
|
||||
|
||||
class TestSkillDirsPropagation:
|
||||
def _make_ctx(self, **kwargs):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from framework.graph.node import NodeContext
|
||||
|
||||
return NodeContext(
|
||||
runtime=MagicMock(),
|
||||
node_id="n",
|
||||
node_spec=MagicMock(),
|
||||
memory={},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def test_node_context_skill_dirs_default(self):
|
||||
"""NodeContext.skill_dirs defaults to empty list."""
|
||||
ctx = self._make_ctx()
|
||||
assert ctx.skill_dirs == []
|
||||
|
||||
def test_node_context_skill_dirs_set(self):
|
||||
"""NodeContext.skill_dirs can be populated."""
|
||||
dirs = ["/skills/a", "/skills/b"]
|
||||
ctx = self._make_ctx(skill_dirs=dirs)
|
||||
assert ctx.skill_dirs == dirs
|
||||
@@ -0,0 +1,471 @@
|
||||
"""Tests for skill trust gating (AS-13)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from framework.skills.parser import ParsedSkill
|
||||
from framework.skills.trust import (
|
||||
ProjectTrustClassification,
|
||||
ProjectTrustDetector,
|
||||
TrustedRepoStore,
|
||||
TrustGate,
|
||||
_is_localhost_remote,
|
||||
_normalize_remote_url,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_skill(name: str = "test-skill", scope: str = "project") -> ParsedSkill:
|
||||
return ParsedSkill(
|
||||
name=name,
|
||||
description="Test skill",
|
||||
location=f"/fake/{name}/SKILL.md",
|
||||
base_dir=f"/fake/{name}",
|
||||
source_scope=scope,
|
||||
body="Test skill instructions.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _normalize_remote_url
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNormalizeRemoteUrl:
|
||||
def test_ssh_scp_format(self):
|
||||
assert _normalize_remote_url("git@github.com:org/repo.git") == "github.com/org/repo"
|
||||
|
||||
def test_https_format(self):
|
||||
assert _normalize_remote_url("https://github.com/org/repo.git") == "github.com/org/repo"
|
||||
|
||||
def test_https_no_dot_git(self):
|
||||
assert _normalize_remote_url("https://github.com/org/repo") == "github.com/org/repo"
|
||||
|
||||
def test_ssh_url_format(self):
|
||||
assert _normalize_remote_url("ssh://git@github.com/org/repo.git") == "github.com/org/repo"
|
||||
|
||||
def test_lowercased(self):
|
||||
assert _normalize_remote_url("git@GitHub.COM:Org/Repo.git") == "github.com/org/repo"
|
||||
|
||||
def test_trailing_slash_stripped(self):
|
||||
assert _normalize_remote_url("https://github.com/org/repo/") == "github.com/org/repo"
|
||||
|
||||
def test_gitlab(self):
|
||||
assert _normalize_remote_url("git@gitlab.com:team/project.git") == "gitlab.com/team/project"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_localhost_remote
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsLocalhostRemote:
|
||||
def test_localhost_https(self):
|
||||
assert _is_localhost_remote("http://localhost/org/repo")
|
||||
|
||||
def test_127_0_0_1(self):
|
||||
assert _is_localhost_remote("https://127.0.0.1/repo")
|
||||
|
||||
def test_github_not_local(self):
|
||||
assert not _is_localhost_remote("https://github.com/org/repo")
|
||||
|
||||
def test_scp_localhost(self):
|
||||
assert _is_localhost_remote("git@localhost:org/repo")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TrustedRepoStore
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTrustedRepoStore:
|
||||
def test_empty_store_is_not_trusted(self, tmp_path):
|
||||
store = TrustedRepoStore(tmp_path / "trusted.json")
|
||||
assert not store.is_trusted("github.com/org/repo")
|
||||
|
||||
def test_trust_and_lookup(self, tmp_path):
|
||||
store = TrustedRepoStore(tmp_path / "trusted.json")
|
||||
store.trust("github.com/org/repo", project_path="/some/path")
|
||||
assert store.is_trusted("github.com/org/repo")
|
||||
|
||||
def test_revoke(self, tmp_path):
|
||||
store = TrustedRepoStore(tmp_path / "trusted.json")
|
||||
store.trust("github.com/org/repo")
|
||||
assert store.revoke("github.com/org/repo")
|
||||
assert not store.is_trusted("github.com/org/repo")
|
||||
|
||||
def test_revoke_nonexistent_returns_false(self, tmp_path):
|
||||
store = TrustedRepoStore(tmp_path / "trusted.json")
|
||||
assert not store.revoke("github.com/nobody/nowhere")
|
||||
|
||||
def test_persists_across_instances(self, tmp_path):
|
||||
path = tmp_path / "trusted.json"
|
||||
store1 = TrustedRepoStore(path)
|
||||
store1.trust("github.com/org/repo")
|
||||
|
||||
store2 = TrustedRepoStore(path)
|
||||
assert store2.is_trusted("github.com/org/repo")
|
||||
|
||||
def test_atomic_write(self, tmp_path):
|
||||
"""Save must not leave a .tmp file behind."""
|
||||
path = tmp_path / "trusted.json"
|
||||
store = TrustedRepoStore(path)
|
||||
store.trust("github.com/org/repo")
|
||||
assert not (tmp_path / "trusted.tmp").exists()
|
||||
assert path.exists()
|
||||
|
||||
def test_corrupted_json_recovers_gracefully(self, tmp_path):
|
||||
path = tmp_path / "trusted.json"
|
||||
path.write_text("{not valid json{{", encoding="utf-8")
|
||||
store = TrustedRepoStore(path)
|
||||
assert not store.is_trusted("github.com/any/repo") # no crash
|
||||
|
||||
def test_json_schema(self, tmp_path):
|
||||
path = tmp_path / "trusted.json"
|
||||
store = TrustedRepoStore(path)
|
||||
store.trust("github.com/org/repo", project_path="/work/repo")
|
||||
data = json.loads(path.read_text())
|
||||
assert data["version"] == 1
|
||||
assert data["entries"][0]["repo_key"] == "github.com/org/repo"
|
||||
assert "added_at" in data["entries"][0]
|
||||
|
||||
def test_list_entries(self, tmp_path):
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
store.trust("github.com/a/b")
|
||||
store.trust("github.com/c/d")
|
||||
entries = store.list_entries()
|
||||
assert len(entries) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ProjectTrustDetector
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProjectTrustDetector:
|
||||
def test_none_project_dir_always_trusted(self, tmp_path):
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
det = ProjectTrustDetector(store)
|
||||
cls, _ = det.classify(None)
|
||||
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
|
||||
|
||||
def test_nonexistent_dir_always_trusted(self, tmp_path):
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
det = ProjectTrustDetector(store)
|
||||
cls, _ = det.classify(tmp_path / "nonexistent")
|
||||
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
|
||||
|
||||
def test_no_git_dir_always_trusted(self, tmp_path):
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
det = ProjectTrustDetector(store)
|
||||
cls, _ = det.classify(tmp_path)
|
||||
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
|
||||
|
||||
def test_no_remote_always_trusted(self, tmp_path):
|
||||
(tmp_path / ".git").mkdir()
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
det = ProjectTrustDetector(store)
|
||||
# git command returns non-zero (no remote)
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=1, stdout="")
|
||||
cls, _ = det.classify(tmp_path)
|
||||
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
|
||||
|
||||
def test_localhost_remote_always_trusted(self, tmp_path):
|
||||
(tmp_path / ".git").mkdir()
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
det = ProjectTrustDetector(store)
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0, stdout="http://localhost/org/repo.git\n"
|
||||
)
|
||||
cls, _ = det.classify(tmp_path)
|
||||
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
|
||||
|
||||
def test_trusted_by_store(self, tmp_path):
|
||||
(tmp_path / ".git").mkdir()
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
store.trust("github.com/trusted/repo")
|
||||
det = ProjectTrustDetector(store)
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0, stdout="git@github.com:trusted/repo.git\n"
|
||||
)
|
||||
cls, key = det.classify(tmp_path)
|
||||
assert cls == ProjectTrustClassification.TRUSTED_BY_USER
|
||||
assert key == "github.com/trusted/repo"
|
||||
|
||||
def test_unknown_remote_untrusted(self, tmp_path):
|
||||
(tmp_path / ".git").mkdir()
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
det = ProjectTrustDetector(store)
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0, stdout="https://github.com/stranger/repo.git\n"
|
||||
)
|
||||
cls, key = det.classify(tmp_path)
|
||||
assert cls == ProjectTrustClassification.UNTRUSTED
|
||||
assert key == "github.com/stranger/repo"
|
||||
|
||||
def test_own_remotes_env_var(self, tmp_path, monkeypatch):
|
||||
(tmp_path / ".git").mkdir()
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
monkeypatch.setenv("HIVE_OWN_REMOTES", "github.com/myorg/*")
|
||||
det = ProjectTrustDetector(store)
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0, stdout="git@github.com:myorg/myrepo.git\n"
|
||||
)
|
||||
cls, _ = det.classify(tmp_path)
|
||||
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
|
||||
|
||||
def test_git_timeout_treated_as_trusted(self, tmp_path):
|
||||
import subprocess
|
||||
|
||||
(tmp_path / ".git").mkdir()
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
det = ProjectTrustDetector(store)
|
||||
with patch("subprocess.run", side_effect=subprocess.TimeoutExpired("git", 3)):
|
||||
cls, _ = det.classify(tmp_path)
|
||||
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
|
||||
|
||||
def test_git_not_found_treated_as_trusted(self, tmp_path):
|
||||
(tmp_path / ".git").mkdir()
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
det = ProjectTrustDetector(store)
|
||||
with patch("subprocess.run", side_effect=FileNotFoundError("git not found")):
|
||||
cls, _ = det.classify(tmp_path)
|
||||
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TrustGate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTrustGate:
|
||||
def test_framework_scope_always_passes(self, tmp_path):
|
||||
skill = make_skill("fw-skill", "framework")
|
||||
gate = TrustGate(store=TrustedRepoStore(tmp_path / "t.json"), interactive=False)
|
||||
result = gate.filter_and_gate([skill], project_dir=None)
|
||||
assert any(s.name == "fw-skill" for s in result)
|
||||
|
||||
def test_user_scope_always_passes(self, tmp_path):
|
||||
skill = make_skill("user-skill", "user")
|
||||
gate = TrustGate(store=TrustedRepoStore(tmp_path / "t.json"), interactive=False)
|
||||
result = gate.filter_and_gate([skill], project_dir=None)
|
||||
assert any(s.name == "user-skill" for s in result)
|
||||
|
||||
def test_no_project_skills_returns_early(self, tmp_path):
|
||||
"""When there are no project-scope skills, trust detection is skipped."""
|
||||
fw = make_skill("fw", "framework")
|
||||
gate = TrustGate(store=TrustedRepoStore(tmp_path / "t.json"), interactive=False)
|
||||
result = gate.filter_and_gate([fw], project_dir=tmp_path)
|
||||
assert result == [fw]
|
||||
|
||||
def test_trusted_project_skills_pass(self, tmp_path):
|
||||
"""Project skills from a trusted repo pass through."""
|
||||
(tmp_path / ".git").mkdir()
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
store.trust("github.com/trusted/repo")
|
||||
skill = make_skill("proj-skill", "project")
|
||||
gate = TrustGate(store=store, interactive=False)
|
||||
with patch("subprocess.run") as m:
|
||||
m.return_value = MagicMock(returncode=0, stdout="git@github.com:trusted/repo.git\n")
|
||||
result = gate.filter_and_gate([skill], project_dir=tmp_path)
|
||||
assert any(s.name == "proj-skill" for s in result)
|
||||
|
||||
def test_untrusted_headless_skips_and_logs(self, tmp_path, caplog):
|
||||
"""In non-interactive mode, untrusted project skills are skipped."""
|
||||
import logging
|
||||
|
||||
(tmp_path / ".git").mkdir()
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
skill = make_skill("evil-skill", "project")
|
||||
gate = TrustGate(store=store, interactive=False)
|
||||
with patch("subprocess.run") as m:
|
||||
m.return_value = MagicMock(
|
||||
returncode=0, stdout="https://github.com/stranger/evil.git\n"
|
||||
)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
result = gate.filter_and_gate([skill], project_dir=tmp_path)
|
||||
assert not any(s.name == "evil-skill" for s in result)
|
||||
assert "untrusted" in caplog.text.lower() or "skipping" in caplog.text.lower()
|
||||
|
||||
def test_interactive_consent_session_only(self, tmp_path):
|
||||
"""Option 1 (session only) includes skills without writing to store."""
|
||||
(tmp_path / ".git").mkdir()
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
skill = make_skill("session-skill", "project")
|
||||
outputs = []
|
||||
gate = TrustGate(
|
||||
store=store,
|
||||
interactive=True,
|
||||
print_fn=outputs.append,
|
||||
input_fn=lambda _: "1", # trust this session
|
||||
)
|
||||
with (
|
||||
patch("sys.stdin.isatty", return_value=True),
|
||||
patch("sys.stdout.isatty", return_value=True),
|
||||
patch("subprocess.run") as m,
|
||||
):
|
||||
m.return_value = MagicMock(
|
||||
returncode=0, stdout="https://github.com/stranger/repo.git\n"
|
||||
)
|
||||
result = gate.filter_and_gate([skill], project_dir=tmp_path)
|
||||
assert any(s.name == "session-skill" for s in result)
|
||||
# Must NOT persist to trusted store
|
||||
assert not store.is_trusted("github.com/stranger/repo")
|
||||
|
||||
def test_interactive_consent_permanent(self, tmp_path):
|
||||
"""Option 2 (permanent) includes skills and persists to trusted store."""
|
||||
(tmp_path / ".git").mkdir()
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
skill = make_skill("perm-skill", "project")
|
||||
gate = TrustGate(
|
||||
store=store,
|
||||
interactive=True,
|
||||
print_fn=lambda _: None,
|
||||
input_fn=lambda _: "2", # trust permanently
|
||||
)
|
||||
with (
|
||||
patch("sys.stdin.isatty", return_value=True),
|
||||
patch("sys.stdout.isatty", return_value=True),
|
||||
patch("subprocess.run") as m,
|
||||
):
|
||||
m.return_value = MagicMock(
|
||||
returncode=0, stdout="https://github.com/stranger/repo.git\n"
|
||||
)
|
||||
result = gate.filter_and_gate([skill], project_dir=tmp_path)
|
||||
assert any(s.name == "perm-skill" for s in result)
|
||||
assert store.is_trusted("github.com/stranger/repo")
|
||||
|
||||
def test_interactive_consent_deny(self, tmp_path):
|
||||
"""Option 3 (deny) excludes project skills."""
|
||||
(tmp_path / ".git").mkdir()
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
skill = make_skill("bad-skill", "project")
|
||||
gate = TrustGate(
|
||||
store=store,
|
||||
interactive=True,
|
||||
print_fn=lambda _: None,
|
||||
input_fn=lambda _: "3", # deny
|
||||
)
|
||||
with (
|
||||
patch("sys.stdin.isatty", return_value=True),
|
||||
patch("sys.stdout.isatty", return_value=True),
|
||||
patch("subprocess.run") as m,
|
||||
):
|
||||
m.return_value = MagicMock(
|
||||
returncode=0, stdout="https://github.com/stranger/repo.git\n"
|
||||
)
|
||||
result = gate.filter_and_gate([skill], project_dir=tmp_path)
|
||||
assert not any(s.name == "bad-skill" for s in result)
|
||||
|
||||
def test_env_var_override_trusts_all(self, tmp_path, monkeypatch):
|
||||
"""HIVE_TRUST_PROJECT_SKILLS=1 bypasses gating entirely."""
|
||||
monkeypatch.setenv("HIVE_TRUST_PROJECT_SKILLS", "1")
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
skill = make_skill("env-skill", "project")
|
||||
gate = TrustGate(store=store, interactive=False)
|
||||
result = gate.filter_and_gate([skill], project_dir=tmp_path)
|
||||
assert any(s.name == "env-skill" for s in result)
|
||||
|
||||
def test_keyboard_interrupt_treated_as_deny(self, tmp_path):
|
||||
"""Ctrl-C during consent prompt should deny cleanly."""
|
||||
(tmp_path / ".git").mkdir()
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
skill = make_skill("interrupted-skill", "project")
|
||||
gate = TrustGate(
|
||||
store=store,
|
||||
interactive=True,
|
||||
print_fn=lambda _: None,
|
||||
input_fn=lambda _: (_ for _ in ()).throw(KeyboardInterrupt()),
|
||||
)
|
||||
with (
|
||||
patch("sys.stdin.isatty", return_value=True),
|
||||
patch("sys.stdout.isatty", return_value=True),
|
||||
patch("subprocess.run") as m,
|
||||
):
|
||||
m.return_value = MagicMock(
|
||||
returncode=0, stdout="https://github.com/stranger/repo.git\n"
|
||||
)
|
||||
result = gate.filter_and_gate([skill], project_dir=tmp_path)
|
||||
assert not any(s.name == "interrupted-skill" for s in result)
|
||||
|
||||
def test_security_notice_shown_once(self, tmp_path, monkeypatch):
|
||||
"""Security notice (NFR-5) should be shown the first time only."""
|
||||
# Use a temp sentinel path
|
||||
sentinel = tmp_path / ".skill_trust_notice_shown"
|
||||
monkeypatch.setattr("framework.skills.trust._NOTICE_SENTINEL_PATH", sentinel)
|
||||
assert not sentinel.exists()
|
||||
|
||||
(tmp_path / ".git").mkdir()
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
skill = make_skill("notice-skill", "project")
|
||||
output_lines: list[str] = []
|
||||
gate = TrustGate(
|
||||
store=store,
|
||||
interactive=True,
|
||||
print_fn=output_lines.append,
|
||||
input_fn=lambda _: "3",
|
||||
)
|
||||
with (
|
||||
patch("sys.stdin.isatty", return_value=True),
|
||||
patch("sys.stdout.isatty", return_value=True),
|
||||
patch("subprocess.run") as m,
|
||||
):
|
||||
m.return_value = MagicMock(
|
||||
returncode=0, stdout="https://github.com/stranger/repo.git\n"
|
||||
)
|
||||
gate.filter_and_gate([skill], project_dir=tmp_path)
|
||||
|
||||
assert sentinel.exists()
|
||||
assert any("Security notice" in line for line in output_lines)
|
||||
|
||||
# Second run should NOT show the notice again
|
||||
output_lines.clear()
|
||||
skill2 = make_skill("notice-skill-2", "project")
|
||||
with (
|
||||
patch("sys.stdin.isatty", return_value=True),
|
||||
patch("sys.stdout.isatty", return_value=True),
|
||||
patch("subprocess.run") as m,
|
||||
):
|
||||
m.return_value = MagicMock(
|
||||
returncode=0, stdout="https://github.com/stranger/repo.git\n"
|
||||
)
|
||||
gate.filter_and_gate([skill2], project_dir=tmp_path)
|
||||
|
||||
assert not any("Security notice" in line for line in output_lines)
|
||||
|
||||
def test_mixed_scopes_only_project_gated(self, tmp_path, monkeypatch):
|
||||
"""Framework and user skills should pass through even if project skills are denied."""
|
||||
(tmp_path / ".git").mkdir()
|
||||
store = TrustedRepoStore(tmp_path / "t.json")
|
||||
fw_skill = make_skill("fw", "framework")
|
||||
user_skill = make_skill("usr", "user")
|
||||
proj_skill = make_skill("proj", "project")
|
||||
gate = TrustGate(
|
||||
store=store,
|
||||
interactive=True,
|
||||
print_fn=lambda _: None,
|
||||
input_fn=lambda _: "3", # deny project skills
|
||||
)
|
||||
with (
|
||||
patch("sys.stdin.isatty", return_value=True),
|
||||
patch("sys.stdout.isatty", return_value=True),
|
||||
patch("subprocess.run") as m,
|
||||
):
|
||||
m.return_value = MagicMock(
|
||||
returncode=0, stdout="https://github.com/stranger/repo.git\n"
|
||||
)
|
||||
result = gate.filter_and_gate([fw_skill, user_skill, proj_skill], project_dir=tmp_path)
|
||||
names = {s.name for s in result}
|
||||
assert "fw" in names
|
||||
assert "usr" in names
|
||||
assert "proj" not in names
|
||||
@@ -8,6 +8,7 @@ could cause a json.JSONDecodeError and crash execution.
|
||||
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
from framework.runner.tool_registry import ToolRegistry
|
||||
|
||||
@@ -91,3 +92,117 @@ def test_discover_from_module_handles_empty_content(tmp_path):
|
||||
result = registered.executor({})
|
||||
assert isinstance(result, dict)
|
||||
assert result == {}
|
||||
|
||||
|
||||
class _RegistryFakeClient:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.connect_calls = 0
|
||||
self.disconnect_calls = 0
|
||||
|
||||
def connect(self) -> None:
|
||||
self.connect_calls += 1
|
||||
|
||||
def disconnect(self) -> None:
|
||||
self.disconnect_calls += 1
|
||||
|
||||
def list_tools(self):
|
||||
return [
|
||||
SimpleNamespace(
|
||||
name="pooled_tool",
|
||||
description="Tool from MCP",
|
||||
input_schema={"type": "object", "properties": {}, "required": []},
|
||||
)
|
||||
]
|
||||
|
||||
def call_tool(self, tool_name, arguments):
|
||||
return [{"text": f"{tool_name}:{arguments}"}]
|
||||
|
||||
|
||||
def test_register_mcp_server_uses_connection_manager_when_enabled(monkeypatch):
|
||||
registry = ToolRegistry()
|
||||
client = _RegistryFakeClient(SimpleNamespace(name="shared"))
|
||||
manager_calls: list[tuple[str, str]] = []
|
||||
|
||||
class FakeManager:
|
||||
def acquire(self, config):
|
||||
manager_calls.append(("acquire", config.name))
|
||||
client.config = config
|
||||
return client
|
||||
|
||||
def release(self, server_name: str) -> None:
|
||||
manager_calls.append(("release", server_name))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"framework.runner.mcp_connection_manager.MCPConnectionManager.get_instance",
|
||||
lambda: FakeManager(),
|
||||
)
|
||||
|
||||
count = registry.register_mcp_server(
|
||||
{"name": "shared", "transport": "stdio", "command": "echo"},
|
||||
use_connection_manager=True,
|
||||
)
|
||||
|
||||
assert count == 1
|
||||
assert manager_calls == [("acquire", "shared")]
|
||||
|
||||
registry.cleanup()
|
||||
|
||||
assert manager_calls == [("acquire", "shared"), ("release", "shared")]
|
||||
assert client.disconnect_calls == 0
|
||||
|
||||
|
||||
def test_register_mcp_server_defaults_to_connection_manager(monkeypatch):
|
||||
"""Default behavior uses the connection manager (reuse enabled by default)."""
|
||||
registry = ToolRegistry()
|
||||
created_clients: list[_RegistryFakeClient] = []
|
||||
|
||||
def fake_client_factory(config):
|
||||
client = _RegistryFakeClient(config)
|
||||
created_clients.append(client)
|
||||
return client
|
||||
|
||||
class FakeManager:
|
||||
def acquire(self, config):
|
||||
return fake_client_factory(config)
|
||||
|
||||
def release(self, server_name):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(
|
||||
"framework.runner.mcp_connection_manager.MCPConnectionManager.get_instance",
|
||||
lambda: FakeManager(),
|
||||
)
|
||||
|
||||
count = registry.register_mcp_server(
|
||||
{"name": "direct", "transport": "stdio", "command": "echo"},
|
||||
)
|
||||
|
||||
assert count == 1
|
||||
assert len(created_clients) == 1
|
||||
|
||||
|
||||
def test_register_mcp_server_direct_client_when_manager_disabled(monkeypatch):
|
||||
"""When use_connection_manager=False, a direct MCPClient is created."""
|
||||
registry = ToolRegistry()
|
||||
created_clients: list[_RegistryFakeClient] = []
|
||||
|
||||
def fake_client_factory(config):
|
||||
client = _RegistryFakeClient(config)
|
||||
created_clients.append(client)
|
||||
return client
|
||||
|
||||
monkeypatch.setattr("framework.runner.mcp_client.MCPClient", fake_client_factory)
|
||||
|
||||
count = registry.register_mcp_server(
|
||||
{"name": "direct", "transport": "stdio", "command": "echo"},
|
||||
use_connection_manager=False,
|
||||
)
|
||||
|
||||
assert count == 1
|
||||
assert len(created_clients) == 1
|
||||
assert created_clients[0].connect_calls == 1
|
||||
|
||||
registry.cleanup()
|
||||
|
||||
assert created_clients[0].disconnect_calls == 1
|
||||
|
||||
@@ -157,7 +157,7 @@ All bounty types open in parallel. Contributors self-select. Daily progress upda
|
||||
PR merged with bounty:* label
|
||||
→ GitHub Action runs bounty-tracker.ts
|
||||
→ Calculates points from label
|
||||
→ Resolves GitHub → Discord ID via contributors.yml
|
||||
→ Resolves GitHub → Discord ID via MongoDB (hive.contributors)
|
||||
→ Pushes XP to Lurkr API
|
||||
→ Posts notification to #integrations-announcements
|
||||
```
|
||||
@@ -166,7 +166,7 @@ See the [Setup Guide](setup-guide.md) for full configuration (Lurkr, webhooks, s
|
||||
|
||||
### Identity Linking
|
||||
|
||||
Contributors link GitHub ↔ Discord by opening a [Link Discord Account](https://github.com/aden-hive/hive/issues/new?template=link-discord.yml) issue. A GitHub Action auto-adds them to `contributors.yml` and closes the issue.
|
||||
Contributors link GitHub ↔ Discord by running `/link-github` in Discord. The bot verifies ownership via a public gist, then stores the mapping in MongoDB.
|
||||
|
||||
Without this link, bounties are still tracked but Lurkr can't push XP to your Discord account.
|
||||
|
||||
@@ -181,7 +181,7 @@ Without this link, bounties are still tracked but Lurkr can't push XP to your Di
|
||||
| Agent Builder role | Lurkr bot | Auto-assigned at level 5 |
|
||||
| OSS Contributor role | Lurkr bot | Auto-assigned at level 15 |
|
||||
| Core Contributor role | Maintainer | Manual (involves money) |
|
||||
| Identity linking | contributors.yml | PR-based, reviewed by maintainers |
|
||||
| Identity linking | Discord bot → MongoDB | `/link-github` command with gist verification |
|
||||
|
||||
## Guides
|
||||
|
||||
@@ -203,4 +203,4 @@ Without this link, bounties are still tracked but Lurkr can't push XP to your Di
|
||||
- `.github/workflows/weekly-leaderboard.yml` — Monday leaderboard post
|
||||
- `scripts/bounty-tracker.ts` — Point calculation, Lurkr API, Discord formatting
|
||||
- `scripts/setup-bounty-labels.sh` — One-time label setup
|
||||
- `contributors.yml` — GitHub ↔ Discord identity mapping
|
||||
- MongoDB `hive.contributors` — GitHub ↔ Discord identity mapping (managed by Discord bot)
|
||||
|
||||
@@ -6,9 +6,7 @@ Earn XP, Discord roles, and eventually real money by contributing to the Aden ag
|
||||
|
||||
### 1. Link your GitHub and Discord
|
||||
|
||||
Open a [Link Discord Account](https://github.com/aden-hive/hive/issues/new?template=link-discord.yml) issue — just paste your Discord ID and submit. A GitHub Action will automatically add you to `contributors.yml` and close the issue.
|
||||
|
||||
To find your Discord ID: Discord Settings > Advanced > Enable **Developer Mode**, then right-click your name > **Copy User ID**.
|
||||
Run `/link-github your-github-username` in Discord. The bot will give you a verification code — create a public gist with that code, then run `/verify`. Done.
|
||||
|
||||
Without this link, Lurkr can't push XP to your Discord account.
|
||||
|
||||
@@ -154,7 +152,7 @@ A: Yes. Most services have free tiers. The bounty issue links to where you get t
|
||||
A: Contribute consistently across different bounty types for 4+ weeks. Maintainers will nominate you.
|
||||
|
||||
**Q: What if I haven't linked my Discord yet?**
|
||||
A: You'll still get credit in GitHub, but no Lurkr XP or Discord roles. Add yourself to `contributors.yml`.
|
||||
A: You'll still get credit in GitHub, but no Lurkr XP or Discord roles. Run `/link-github` in Discord.
|
||||
|
||||
## Quick Reference
|
||||
|
||||
|
||||
@@ -104,6 +104,8 @@ Repo Settings > Secrets and variables > Actions:
|
||||
| `DISCORD_BOUNTY_WEBHOOK_URL` | Webhook URL from Step 5 |
|
||||
| `LURKR_API_KEY` | Lurkr API key from Step 4f |
|
||||
| `LURKR_GUILD_ID` | Your Discord server ID\* |
|
||||
| `BOT_API_URL` | Discord bot API URL |
|
||||
| `BOT_API_KEY` | Discord bot API key |
|
||||
|
||||
\*Enable Developer Mode in Discord, right-click server name > Copy Server ID.
|
||||
|
||||
@@ -146,12 +148,12 @@ powerbi, redis
|
||||
- [ ] All 3 GitHub secrets added
|
||||
- [ ] Both workflows enabled (`bounty-completed.yml`, `weekly-leaderboard.yml`)
|
||||
- [ ] Test PR + merge triggers Discord notification
|
||||
- [ ] `contributors.yml` exists at repo root
|
||||
- [ ] MongoDB `hive.contributors` collection accessible
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**No Discord message:** Check `DISCORD_BOUNTY_WEBHOOK_URL` secret and Action logs.
|
||||
|
||||
**Lurkr XP not awarded:** Confirm API key is Read/Write, contributor is in `contributors.yml`, check Action logs for `Lurkr XP push failed`.
|
||||
**Lurkr XP not awarded:** Confirm API key is Read/Write, contributor has run `/link-github` in Discord, check Action logs for `Lurkr XP push failed`.
|
||||
|
||||
**Role not assigned:** Verify role rewards in the Lurkr dashboard or via `/config set`. Lurkr's role must be above the roles it assigns in server hierarchy.
|
||||
|
||||
@@ -0,0 +1,290 @@
|
||||
# Agent Skills User Guide
|
||||
|
||||
This guide covers how to use, create, and manage Agent Skills in the Hive framework. Agent Skills follow the open [Agent Skills standard](https://agentskills.io) — skills written for Claude Code, Cursor, or other compatible agents work in Hive unchanged.
|
||||
|
||||
## What are skills?
|
||||
|
||||
Skills are folders containing a `SKILL.md` file that teaches an agent how to perform a specific task. They can also bundle scripts, templates, and reference materials. Skills are loaded on demand — the agent sees a lightweight catalog at startup and pulls in full instructions only when relevant.
|
||||
|
||||
## Quick start
|
||||
|
||||
### Install a skill
|
||||
|
||||
Drop a skill folder into one of the discovery directories:
|
||||
|
||||
```bash
|
||||
# Project-level (shared with the repo)
|
||||
mkdir -p .hive/skills/my-skill
|
||||
cat > .hive/skills/my-skill/SKILL.md << 'EOF'
|
||||
---
|
||||
name: my-skill
|
||||
description: Does X when the user asks about Y.
|
||||
---
|
||||
|
||||
# My Skill
|
||||
|
||||
Step-by-step instructions for the agent...
|
||||
EOF
|
||||
```
|
||||
|
||||
The agent will discover it automatically on the next session.
|
||||
|
||||
### List discovered skills
|
||||
|
||||
```bash
|
||||
hive skill list
|
||||
```
|
||||
|
||||
Output groups skills by scope:
|
||||
|
||||
```
|
||||
PROJECT SKILLS
|
||||
────────────────────────────────────
|
||||
• my-skill
|
||||
Does X when the user asks about Y.
|
||||
/home/user/project/.hive/skills/my-skill/SKILL.md
|
||||
|
||||
USER SKILLS
|
||||
────────────────────────────────────
|
||||
• deep-research
|
||||
Multi-step web research with source verification.
|
||||
/home/user/.hive/skills/deep-research/SKILL.md
|
||||
```
|
||||
|
||||
## Where to put skills
|
||||
|
||||
Hive scans five directories at startup, in this precedence order:
|
||||
|
||||
| Scope | Path | Use case |
|
||||
|-------|------|----------|
|
||||
| Project (Hive) | `<project>/.hive/skills/` | Skills specific to this repo |
|
||||
| Project (cross-client) | `<project>/.agents/skills/` | Skills shared across Claude Code, Cursor, etc. |
|
||||
| User (Hive) | `~/.hive/skills/` | Personal skills available in all projects |
|
||||
| User (cross-client) | `~/.agents/skills/` | Personal cross-client skills |
|
||||
| Framework | *(built-in)* | Default operational skills shipped with Hive |
|
||||
|
||||
**Precedence**: If two skills share the same name, the higher-precedence location wins. A project-level `code-review` skill overrides a user-level one with the same name.
|
||||
|
||||
**Cross-client paths**: The `.agents/skills/` directories are a convention shared across compatible agents. A skill installed at `~/.agents/skills/pdf-processing/` is visible to Hive, Claude Code, Cursor, and other compatible tools simultaneously.
|
||||
|
||||
## Creating a skill
|
||||
|
||||
### Directory structure
|
||||
|
||||
```
|
||||
my-skill/
|
||||
├── SKILL.md # Required — metadata + instructions
|
||||
├── scripts/ # Optional — executable code
|
||||
│ └── run.py
|
||||
├── references/ # Optional — supplementary docs
|
||||
│ └── api-reference.md
|
||||
└── assets/ # Optional — templates, data files
|
||||
└── template.json
|
||||
```
|
||||
|
||||
### SKILL.md format
|
||||
|
||||
Every skill needs a `SKILL.md` with YAML frontmatter and a markdown body:
|
||||
|
||||
```markdown
|
||||
---
|
||||
name: my-skill
|
||||
description: Extract and summarize PDF documents. Use when the user mentions PDFs or document extraction.
|
||||
---
|
||||
|
||||
# PDF Processing
|
||||
|
||||
## When to use
|
||||
Use this skill when the user needs to extract text from PDFs or merge documents.
|
||||
|
||||
## Steps
|
||||
1. Check if pdfplumber is available...
|
||||
2. Extract text using...
|
||||
|
||||
## Edge cases
|
||||
- Scanned PDFs need OCR first...
|
||||
```
|
||||
|
||||
### Frontmatter fields
|
||||
|
||||
| Field | Required | Description |
|
||||
|-------|----------|-------------|
|
||||
| `name` | Yes | Lowercase letters, numbers, hyphens. Must match the parent directory name. Max 64 chars. |
|
||||
| `description` | Yes | What the skill does and when to use it. Max 1024 chars. Include keywords that help the agent match tasks. |
|
||||
| `license` | No | License name or reference to a bundled LICENSE file. |
|
||||
| `compatibility` | No | Environment requirements (e.g., "Requires git, docker"). |
|
||||
| `metadata` | No | Arbitrary key-value pairs (author, version, etc.). |
|
||||
| `allowed-tools` | No | Space-delimited list of pre-approved tools. |
|
||||
|
||||
### Writing good descriptions
|
||||
|
||||
The description is critical — it's what the agent uses to decide whether to activate a skill. Be specific:
|
||||
|
||||
```yaml
|
||||
# Good — tells the agent what and when
|
||||
description: Extract text and tables from PDF files, fill PDF forms, and merge multiple PDFs. Use when working with PDF documents or when the user mentions PDFs, forms, or document extraction.
|
||||
|
||||
# Bad — too vague for the agent to match
|
||||
description: Helps with PDFs.
|
||||
```
|
||||
|
||||
### Writing good instructions
|
||||
|
||||
The markdown body is loaded into the agent's context when the skill is activated. Tips:
|
||||
|
||||
- **Be procedural**: Step-by-step instructions work better than abstract descriptions.
|
||||
- **Keep it focused**: Stay under 500 lines / 5000 tokens. Move detailed reference material to `references/`.
|
||||
- **Use relative paths**: Reference bundled files with relative paths (`scripts/run.py`, `references/guide.md`).
|
||||
- **Include examples**: Show sample inputs and expected outputs.
|
||||
- **Cover edge cases**: Tell the agent what to do when things go wrong.
|
||||
|
||||
## How skills are activated
|
||||
|
||||
Skills use **progressive disclosure** — three tiers that keep context usage efficient:
|
||||
|
||||
### Tier 1: Catalog (always loaded)
|
||||
|
||||
At session start, the agent sees a compact catalog of all available skills (name + description only, ~50-100 tokens each). This is how it knows what skills exist.
|
||||
|
||||
### Tier 2: Instructions (on demand)
|
||||
|
||||
When the agent determines a skill is relevant to the current task, it reads the full `SKILL.md` body into context. This happens automatically — the agent matches the task against skill descriptions and activates the best fit.
|
||||
|
||||
### Tier 3: Resources (on demand)
|
||||
|
||||
When skill instructions reference supporting files (`scripts/extract.py`, `references/api-docs.md`), the agent reads those individually as needed.
|
||||
|
||||
### Pre-activated skills
|
||||
|
||||
Some agents are configured to load specific skills at session start (skipping the catalog phase). This is set in the agent's configuration:
|
||||
|
||||
```python
|
||||
# In agent definition
|
||||
skills = ["code-review", "deep-research"]
|
||||
```
|
||||
|
||||
Pre-activated skills have their full instructions loaded from the start, without waiting for the agent to decide they're relevant.
|
||||
|
||||
## Trust and security
|
||||
|
||||
### Why trust gating exists
|
||||
|
||||
Project-level skills come from the repository being worked on. If you clone an untrusted repo that contains a `.hive/skills/` directory, those skills could inject instructions into the agent's system prompt. Trust gating prevents this.
|
||||
|
||||
**User-level and framework skills are always trusted.** Only project-scope skills go through trust gating.
|
||||
|
||||
### What happens with untrusted project skills
|
||||
|
||||
When Hive encounters project-level skills from a repo you haven't trusted before, it shows a consent prompt:
|
||||
|
||||
```
|
||||
============================================================
|
||||
SKILL TRUST REQUIRED
|
||||
============================================================
|
||||
|
||||
The project at /home/user/new-project wants to load 2 skill(s)
|
||||
that will inject instructions into the agent's system prompt.
|
||||
Source: github.com/org/new-project
|
||||
|
||||
Skills requesting access:
|
||||
• deploy-pipeline
|
||||
"Automated deployment workflow for this project."
|
||||
/home/user/new-project/.hive/skills/deploy-pipeline/SKILL.md
|
||||
• code-standards
|
||||
"Project-specific coding standards and review checklist."
|
||||
/home/user/new-project/.hive/skills/code-standards/SKILL.md
|
||||
|
||||
Options:
|
||||
1) Trust this session only
|
||||
2) Trust permanently — remember for future runs
|
||||
3) Deny — skip all project-scope skills from this repo
|
||||
────────────────────────────────────────────────────────────
|
||||
Select option (1-3):
|
||||
```
|
||||
|
||||
### Trust a repo via CLI
|
||||
|
||||
To trust a repo permanently without the interactive prompt:
|
||||
|
||||
```bash
|
||||
hive skill trust /path/to/project
|
||||
```
|
||||
|
||||
This stores the trust decision in `~/.hive/trusted_repos.json`, keyed by the normalized git remote URL (e.g., `github.com/org/repo`).
|
||||
|
||||
### Automatic trust
|
||||
|
||||
Some repos are trusted automatically:
|
||||
|
||||
- **No git repo**: Directories without `.git/` are always trusted.
|
||||
- **No remote**: Local-only git repos (no `origin` remote) are always trusted.
|
||||
- **Localhost remotes**: Repos with `localhost`/`127.0.0.1` remotes are always trusted.
|
||||
- **Own-remote patterns**: Repos matching patterns in `~/.hive/own_remotes` or the `HIVE_OWN_REMOTES` env var are always trusted.
|
||||
|
||||
### Configure own-remote patterns
|
||||
|
||||
If you trust all repos from your organization:
|
||||
|
||||
```bash
|
||||
# Via file (one pattern per line)
|
||||
echo "github.com/my-org/*" >> ~/.hive/own_remotes
|
||||
echo "gitlab.com/my-team/*" >> ~/.hive/own_remotes
|
||||
|
||||
# Via environment variable (comma-separated)
|
||||
export HIVE_OWN_REMOTES="github.com/my-org/*,github.com/my-corp/*"
|
||||
```
|
||||
|
||||
### CI / headless environments
|
||||
|
||||
In non-interactive environments, untrusted project skills are silently skipped. To trust them explicitly:
|
||||
|
||||
```bash
|
||||
export HIVE_TRUST_PROJECT_SKILLS=1
|
||||
hive run my-agent
|
||||
```
|
||||
|
||||
## Default skills
|
||||
|
||||
Hive ships with six built-in operational skills that provide runtime resilience. These are always loaded (unless disabled) and appear as "Operational Protocols" in the agent's system prompt.
|
||||
|
||||
| Skill | Purpose |
|
||||
|-------|---------|
|
||||
| `hive.note-taking` | Structured working notes in shared memory |
|
||||
| `hive.batch-ledger` | Track per-item status in batch operations |
|
||||
| `hive.context-preservation` | Save context before context window pruning |
|
||||
| `hive.quality-monitor` | Self-assess output quality periodically |
|
||||
| `hive.error-recovery` | Structured error classification and recovery |
|
||||
| `hive.task-decomposition` | Break complex tasks into subtasks |
|
||||
|
||||
### Disable default skills
|
||||
|
||||
In your agent configuration:
|
||||
|
||||
```python
|
||||
# Disable a specific default skill
|
||||
default_skills = {
|
||||
"hive.quality-monitor": {"enabled": False},
|
||||
}
|
||||
|
||||
# Disable all default skills
|
||||
default_skills = {
|
||||
"_all": {"enabled": False},
|
||||
}
|
||||
```
|
||||
|
||||
## Environment variables
|
||||
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `HIVE_TRUST_PROJECT_SKILLS=1` | Bypass trust gating for all project-level skills (CI override) |
|
||||
| `HIVE_OWN_REMOTES` | Comma-separated glob patterns for auto-trusted remotes (e.g., `github.com/myorg/*`) |
|
||||
|
||||
## Compatibility with other agents
|
||||
|
||||
Skills written for any Agent Skills-compatible agent work in Hive:
|
||||
|
||||
- Place them in `.agents/skills/` (cross-client) or `.hive/skills/` (Hive-specific).
|
||||
- The `SKILL.md` format is identical across Claude Code, Cursor, Gemini CLI, and others.
|
||||
- Skills installed at `~/.agents/skills/` are visible to all compatible agents on your machine.
|
||||
|
||||
See the [Agent Skills specification](https://agentskills.io/specification) for the full format reference.
|
||||
@@ -0,0 +1,136 @@
|
||||
# SDR Agent
|
||||
|
||||
An AI-powered sales development outreach automation template for [Hive](https://github.com/aden-hive/hive).
|
||||
|
||||
Score contacts by priority, filter suspicious profiles, generate personalized messages, and create Gmail drafts — all with human review before anything is sent.
|
||||
|
||||
## Overview
|
||||
|
||||
The SDR Agent automates the full outreach pipeline:
|
||||
|
||||
```
|
||||
Intake → Score Contacts → Filter Contacts → Personalize → Send Outreach → Report
|
||||
```
|
||||
|
||||
1. **Intake** — Accept a contact list and outreach goal; confirm strategy with user
|
||||
2. **Score Contacts** — Rank contacts 0–100 using priority factors (alumni, degree, domain, etc.)
|
||||
3. **Filter Contacts** — Detect and skip suspicious/fake profiles (risk score ≥ 7)
|
||||
4. **Personalize** — Generate an 80–120 word personalized message per contact
|
||||
5. **Send Outreach** — Create Gmail drafts for human review (never sends automatically)
|
||||
6. **Report** — Summarize campaign: contacts scored, filtered, drafted
|
||||
|
||||
## Quickstart
|
||||
|
||||
```bash
|
||||
cd examples/templates/sdr_agent
|
||||
|
||||
# Run interactively via TUI
|
||||
python -m sdr_agent tui
|
||||
|
||||
# Run via CLI with a contacts JSON string
|
||||
python -m sdr_agent run \
|
||||
--contacts '[{"name":"Jane Doe","company":"Acme","title":"Engineer","connection_degree":"2nd","is_alumni":true}]' \
|
||||
--goal "coffee chat" \
|
||||
--background "Learning Technologist at UWO" \
|
||||
--max-contacts 20
|
||||
|
||||
# Validate agent structure
|
||||
python -m sdr_agent validate
|
||||
```
|
||||
|
||||
## Contact Schema
|
||||
|
||||
Each contact in your list supports the following fields:
|
||||
|
||||
| Field | Type | Required | Description |
|
||||
|-------|------|----------|-------------|
|
||||
| `name` | string | ✅ | Contact's full name |
|
||||
| `email` | string | ❌ | Email address (draft placeholder if missing) |
|
||||
| `company` | string | ✅ | Current company |
|
||||
| `title` | string | ✅ | Job title |
|
||||
| `linkedin_url` | string | ❌ | LinkedIn profile URL |
|
||||
| `connection_degree` | string | ❌ | `"1st"`, `"2nd"`, or `"3rd"` |
|
||||
| `is_alumni` | boolean | ❌ | Shares school with user |
|
||||
| `school_name` | string | ❌ | School name for alumni messaging |
|
||||
| `connections_count` | integer | ❌ | Number of LinkedIn connections |
|
||||
| `mutual_connections` | integer | ❌ | Count of mutual connections |
|
||||
| `has_photo` | boolean | ❌ | Has a profile photo |
|
||||
|
||||
## Scoring Model
|
||||
|
||||
The `score-contacts` node ranks each contact 0–100:
|
||||
|
||||
| Factor | Points |
|
||||
|--------|--------|
|
||||
| Alumni | +30 |
|
||||
| 1st degree | +25 |
|
||||
| 2nd degree | +20 |
|
||||
| 3rd degree | +10 |
|
||||
| Domain verified | +10 |
|
||||
| Mutual connections (×1, max 10) | +10 |
|
||||
| Active job posting | +10 |
|
||||
| Has profile photo | +5 |
|
||||
| 500+ connections | +5 |
|
||||
|
||||
## Scam Detection
|
||||
|
||||
The `filter-contacts` node calculates a risk score and excludes contacts with risk ≥ 7:
|
||||
|
||||
| Red Flag | Risk |
|
||||
|----------|------|
|
||||
| Fewer than 50 connections | +3 |
|
||||
| No profile photo | +2 |
|
||||
| Fewer than 2 work positions | +2 |
|
||||
| Generic title + few connections | +2 |
|
||||
| Unverifiable company | +2 |
|
||||
| AI-generated-looking profile | +2 |
|
||||
| 5000+ connections, 0 mutual | +1 |
|
||||
|
||||
## Pipeline Output Files
|
||||
|
||||
Each run writes to `~/.hive/agents/sdr_agent/data/`:
|
||||
|
||||
| File | Contents |
|
||||
|------|----------|
|
||||
| `contacts.jsonl` | Raw contact list |
|
||||
| `scored_contacts.jsonl` | Contacts with `priority_score` |
|
||||
| `safe_contacts.jsonl` | Contacts passing scam filter |
|
||||
| `personalized_contacts.jsonl` | Contacts with `outreach_message` |
|
||||
| `drafts.jsonl` | Draft creation records |
|
||||
|
||||
## Safety Constraints
|
||||
|
||||
- **Never sends emails** — only `gmail_create_draft` is called; human must review and send
|
||||
- **Batch limit** — processes at most `max_contacts` per run (default: 20)
|
||||
- **Skip suspicious** — contacts with `risk_score ≥ 7` are always excluded
|
||||
|
||||
## Tools Required
|
||||
|
||||
- `gmail_create_draft` — create Gmail draft for each contact
|
||||
- `load_data` — read JSONL data files
|
||||
- `append_data` — write to JSONL data files
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌──────────────────────────────────────────────────────────────┐
|
||||
│ SDR Agent │
|
||||
│ │
|
||||
│ ┌────────┐ ┌───────────────┐ ┌────────────────┐ │
|
||||
│ │ Intake │──▶│ Score Contacts│──▶│ Filter Contacts│ │
|
||||
│ └────────┘ └───────────────┘ └────────────────┘ │
|
||||
│ ▲ │ │
|
||||
│ │ ▼ │
|
||||
│ ┌────────┐ ┌───────────────┐ ┌─────────────┐ │
|
||||
│ │ Report │◀──│ Send Outreach │◀──│ Personalize │ │
|
||||
│ └────────┘ └───────────────┘ └─────────────┘ │
|
||||
│ │
|
||||
│ ● client_facing nodes: intake, report │
|
||||
│ ● automated nodes: score-contacts, filter-contacts, │
|
||||
│ personalize, send-outreach │
|
||||
└──────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Inspiration
|
||||
|
||||
This template is inspired by real-world SDR automation patterns, including contact ranking, scam detection, and two-step personalization (hook extraction → message generation) — demonstrating how job-search and sales outreach workflows can be modeled as AI agent pipelines in Hive.
|
||||
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
SDR Agent — Automated sales development outreach pipeline.
|
||||
|
||||
Score contacts by priority, filter suspicious profiles, generate personalized
|
||||
outreach messages, and create Gmail drafts for human review before sending.
|
||||
"""
|
||||
|
||||
from .agent import (
|
||||
SDRAgent,
|
||||
default_agent,
|
||||
goal,
|
||||
nodes,
|
||||
edges,
|
||||
loop_config,
|
||||
async_entry_points,
|
||||
entry_node,
|
||||
entry_points,
|
||||
pause_nodes,
|
||||
terminal_nodes,
|
||||
conversation_mode,
|
||||
identity_prompt,
|
||||
)
|
||||
from .config import RuntimeConfig, AgentMetadata, default_config, metadata
|
||||
|
||||
__version__ = "1.0.0"
|
||||
|
||||
__all__ = [
|
||||
"SDRAgent",
|
||||
"default_agent",
|
||||
"goal",
|
||||
"nodes",
|
||||
"edges",
|
||||
"loop_config",
|
||||
"async_entry_points",
|
||||
"entry_node",
|
||||
"entry_points",
|
||||
"pause_nodes",
|
||||
"terminal_nodes",
|
||||
"conversation_mode",
|
||||
"identity_prompt",
|
||||
"RuntimeConfig",
|
||||
"AgentMetadata",
|
||||
"default_config",
|
||||
"metadata",
|
||||
]
|
||||
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
CLI entry point for SDR Agent.
|
||||
|
||||
Automates sales development outreach: score contacts, filter suspicious
|
||||
profiles, generate personalized messages, and create Gmail drafts.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import click
|
||||
|
||||
from .agent import default_agent, SDRAgent
|
||||
|
||||
|
||||
def setup_logging(verbose=False, debug=False):
|
||||
"""Configure logging for execution visibility."""
|
||||
if debug:
|
||||
level, fmt = logging.DEBUG, "%(asctime)s %(name)s: %(message)s"
|
||||
elif verbose:
|
||||
level, fmt = logging.INFO, "%(message)s"
|
||||
else:
|
||||
level, fmt = logging.WARNING, "%(levelname)s: %(message)s"
|
||||
logging.basicConfig(level=level, format=fmt, stream=sys.stderr)
|
||||
logging.getLogger("framework").setLevel(level)
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(version="1.0.0")
|
||||
def cli():
|
||||
"""SDR Agent - Automated outreach with contact scoring and personalization."""
|
||||
pass
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option(
|
||||
"--contacts",
|
||||
"-c",
|
||||
type=str,
|
||||
required=True,
|
||||
help="JSON string or file path of contacts list",
|
||||
)
|
||||
@click.option(
|
||||
"--goal",
|
||||
"-g",
|
||||
type=str,
|
||||
default="coffee chat",
|
||||
help="Outreach goal (e.g. 'coffee chat', 'sales pitch')",
|
||||
)
|
||||
@click.option(
|
||||
"--background",
|
||||
"-b",
|
||||
type=str,
|
||||
default="",
|
||||
help="Your background/role for personalization",
|
||||
)
|
||||
@click.option(
|
||||
"--max-contacts",
|
||||
"-m",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Max contacts to process per batch (default: 20)",
|
||||
)
|
||||
@click.option(
|
||||
"--mock", is_flag=True, help="Run in mock mode without LLM or Gmail calls"
|
||||
)
|
||||
@click.option("--quiet", "-q", is_flag=True, help="Only output result JSON")
|
||||
@click.option("--verbose", "-v", is_flag=True, help="Show execution details")
|
||||
@click.option("--debug", is_flag=True, help="Show debug logging")
|
||||
def run(contacts, goal, background, max_contacts, mock, quiet, verbose, debug):
|
||||
"""Execute an SDR outreach campaign for the given contacts."""
|
||||
if not quiet:
|
||||
setup_logging(verbose=verbose, debug=debug)
|
||||
|
||||
context = {
|
||||
"contacts": contacts,
|
||||
"outreach_goal": goal,
|
||||
"user_background": background,
|
||||
"max_contacts": str(max_contacts),
|
||||
}
|
||||
|
||||
result = asyncio.run(default_agent.run(context, mock_mode=mock))
|
||||
|
||||
output_data = {
|
||||
"success": result.success,
|
||||
"steps_executed": result.steps_executed,
|
||||
"output": result.output,
|
||||
}
|
||||
if result.error:
|
||||
output_data["error"] = result.error
|
||||
|
||||
click.echo(json.dumps(output_data, indent=2, default=str))
|
||||
sys.exit(0 if result.success else 1)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--mock", is_flag=True, help="Run in mock mode")
|
||||
@click.option("--verbose", "-v", is_flag=True, help="Show execution details")
|
||||
@click.option("--debug", is_flag=True, help="Show debug logging")
|
||||
def tui(mock, verbose, debug):
|
||||
"""Launch the TUI dashboard for interactive SDR outreach."""
|
||||
setup_logging(verbose=verbose, debug=debug)
|
||||
|
||||
try:
|
||||
from framework.tui.app import AdenTUI
|
||||
except ImportError:
|
||||
click.echo(
|
||||
"TUI requires the 'textual' package. Install with: pip install textual"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
async def run_with_tui():
|
||||
agent = SDRAgent()
|
||||
await agent.start(mock_mode=mock)
|
||||
|
||||
try:
|
||||
app = AdenTUI(agent._agent_runtime)
|
||||
await app.run_async()
|
||||
finally:
|
||||
await agent.stop()
|
||||
|
||||
asyncio.run(run_with_tui())
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--json", "output_json", is_flag=True)
|
||||
def info(output_json):
|
||||
"""Show agent information."""
|
||||
info_data = default_agent.info()
|
||||
if output_json:
|
||||
click.echo(json.dumps(info_data, indent=2))
|
||||
else:
|
||||
click.echo(f"Agent: {info_data['name']}")
|
||||
click.echo(f"Version: {info_data['version']}")
|
||||
click.echo(f"Description: {info_data['description']}")
|
||||
click.echo(f"\nNodes: {', '.join(info_data['nodes'])}")
|
||||
click.echo(f"Client-facing: {', '.join(info_data['client_facing_nodes'])}")
|
||||
click.echo(f"Entry: {info_data['entry_node']}")
|
||||
click.echo(f"Terminal: {', '.join(info_data['terminal_nodes'])}")
|
||||
|
||||
|
||||
@cli.command()
|
||||
def validate():
|
||||
"""Validate agent structure."""
|
||||
validation = default_agent.validate()
|
||||
if validation["valid"]:
|
||||
click.echo("Agent is valid")
|
||||
if validation["warnings"]:
|
||||
for warning in validation["warnings"]:
|
||||
click.echo(f" WARNING: {warning}")
|
||||
else:
|
||||
click.echo("Agent has errors:")
|
||||
for error in validation["errors"]:
|
||||
click.echo(f" ERROR: {error}")
|
||||
sys.exit(0 if validation["valid"] else 1)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--verbose", "-v", is_flag=True)
|
||||
def shell(verbose):
|
||||
"""Interactive SDR outreach session (CLI, no TUI)."""
|
||||
asyncio.run(_interactive_shell(verbose))
|
||||
|
||||
|
||||
async def _interactive_shell(verbose=False):
|
||||
"""Async interactive shell."""
|
||||
setup_logging(verbose=verbose)
|
||||
|
||||
click.echo("=== SDR Agent ===")
|
||||
click.echo("Automated contact scoring, filtering, and outreach personalization\n")
|
||||
|
||||
agent = SDRAgent()
|
||||
await agent.start()
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
goal = await asyncio.get_event_loop().run_in_executor(
|
||||
None, input, "Outreach goal (e.g. 'coffee chat')> "
|
||||
)
|
||||
if goal.lower() in ["quit", "exit", "q"]:
|
||||
click.echo("Goodbye!")
|
||||
break
|
||||
|
||||
contacts = await asyncio.get_event_loop().run_in_executor(
|
||||
None, input, "Contacts (JSON)> "
|
||||
)
|
||||
background = await asyncio.get_event_loop().run_in_executor(
|
||||
None, input, "Your background/role> "
|
||||
)
|
||||
|
||||
if not contacts.strip():
|
||||
continue
|
||||
|
||||
click.echo("\nRunning SDR campaign...\n")
|
||||
|
||||
result = await agent.trigger_and_wait(
|
||||
"start",
|
||||
{
|
||||
"contacts": contacts,
|
||||
"outreach_goal": goal,
|
||||
"user_background": background,
|
||||
"max_contacts": "20",
|
||||
},
|
||||
)
|
||||
|
||||
if result is None:
|
||||
click.echo("\n[Execution timed out]\n")
|
||||
continue
|
||||
|
||||
if result.success:
|
||||
output = result.output
|
||||
if "summary_report" in output:
|
||||
click.echo("\n--- Campaign Report ---\n")
|
||||
click.echo(output["summary_report"])
|
||||
click.echo("\n")
|
||||
else:
|
||||
click.echo(f"\nCampaign failed: {result.error}\n")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
click.echo("\nGoodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {e}", err=True)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
await agent.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
@@ -0,0 +1,378 @@
|
||||
{
|
||||
"agent": {
|
||||
"id": "sdr_agent",
|
||||
"name": "SDR Agent",
|
||||
"version": "1.0.0",
|
||||
"description": "Automate sales development outreach using AI-powered contact scoring, scam detection, and personalized message generation. Score contacts by priority, filter suspicious profiles, generate personalized outreach messages, and create Gmail drafts for review — all without sending emails automatically."
|
||||
},
|
||||
"graph": {
|
||||
"id": "sdr-agent-graph",
|
||||
"goal_id": "sdr-agent",
|
||||
"version": "1.0.0",
|
||||
"entry_node": "intake",
|
||||
"entry_points": {
|
||||
"start": "intake"
|
||||
},
|
||||
"pause_nodes": [],
|
||||
"terminal_nodes": ["complete"],
|
||||
"conversation_mode": "continuous",
|
||||
"identity_prompt": "You are an SDR (Sales Development Representative) assistant. You help users automate their outreach by scoring contacts, filtering suspicious profiles, generating personalized messages, and creating Gmail drafts — all with human review before anything is sent.",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "intake",
|
||||
"name": "Intake",
|
||||
"description": "Receive the contact list and outreach goal from the user. Confirm the strategy and batch size before proceeding.",
|
||||
"node_type": "event_loop",
|
||||
"input_keys": [
|
||||
"contacts",
|
||||
"outreach_goal",
|
||||
"max_contacts",
|
||||
"user_background"
|
||||
],
|
||||
"output_keys": [
|
||||
"contacts",
|
||||
"outreach_goal",
|
||||
"max_contacts",
|
||||
"user_background"
|
||||
],
|
||||
"nullable_output_keys": [],
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"system_prompt": "You are an SDR (Sales Development Representative) assistant helping automate outreach.\n\n**STEP 1 — Respond to the user (text only, NO tool calls):**\n\nRead the user's input from context. Confirm your understanding of:\n- The contact list they provided (or ask them to provide one)\n- Their outreach goal (e.g. \"coffee chat\", \"sales pitch\", \"networking\")\n- Their background/role (used to personalize messages)\n- The batch size (max_contacts). Default to 20 if not specified.\n\nPresent a summary like:\n\"Here's what I'll do:\n1. Score and rank your contacts by priority (alumni status, connection degree, etc.)\n2. Filter out suspicious or low-quality profiles (risk score ≥ 7)\n3. Generate a personalized outreach message for each contact\n4. Create Gmail draft emails for your review — I never send automatically\n\nReady to proceed with [N] contacts for [goal]?\"\n\n**STEP 2 — After the user confirms, call set_output:**\n\n- set_output(\"contacts\", <the contact list as a JSON string>)\n- set_output(\"outreach_goal\", <the confirmed goal, e.g. \"coffee chat\">)\n- set_output(\"max_contacts\", <the confirmed batch size as a string, e.g. \"20\">)\n- set_output(\"user_background\", <user's background/role, e.g. \"Learning Technologist at UWO\">)",
|
||||
"tools": [],
|
||||
"model": null,
|
||||
"function": null,
|
||||
"routes": {},
|
||||
"max_retries": 3,
|
||||
"retry_on": [],
|
||||
"max_node_visits": 0,
|
||||
"output_model": null,
|
||||
"max_validation_retries": 2,
|
||||
"client_facing": true,
|
||||
"success_criteria": null
|
||||
},
|
||||
{
|
||||
"id": "score-contacts",
|
||||
"name": "Score Contacts",
|
||||
"description": "Score and rank each contact from 0 to 100 based on priority factors: alumni status, connection degree, domain verification, mutual connections, and active job postings.",
|
||||
"node_type": "event_loop",
|
||||
"input_keys": [
|
||||
"contacts",
|
||||
"outreach_goal"
|
||||
],
|
||||
"output_keys": [
|
||||
"scored_contacts"
|
||||
],
|
||||
"nullable_output_keys": [],
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"system_prompt": "You are a contact prioritization engine. Score each contact from 0 to 100.\n\n**SCORING RULES (additive):**\n- Alumni of the user's school: +30 points\n- 1st degree connection: +25 points\n- 2nd degree connection: +20 points\n- 3rd degree connection: +10 points\n- Domain verified (company email matches LinkedIn company): +10 points\n- Has mutual connections (1 point each, max 10): up to +10 points\n- Active job posting at their company: +10 points\n- Has a profile photo: +5 points\n- Over 500 connections: +5 points\n\nCap the final score at 100.\n\n**STEP 1 — Load the contacts:**\nCall load_data(filename=\"contacts.jsonl\") to read the contact list.\nIf \"contacts\" in context is a JSON string (not a filename), write it first:\n- For each contact in the list, call append_data(filename=\"contacts.jsonl\", data=<JSON contact object>)\nThen read it back.\n\n**STEP 2 — Score each contact:**\nFor each contact, calculate the priority score using the rules above.\nAdd a \"priority_score\" field to each contact object.\n\n**STEP 3 — Write scored contacts and set output:**\n- Call append_data(filename=\"scored_contacts.jsonl\", data=<JSON contact with priority_score>) for each contact.\n- Sort contacts by priority_score (highest first) in your final output.\n- Call set_output(\"scored_contacts\", \"scored_contacts.jsonl\")",
|
||||
"tools": [
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"model": null,
|
||||
"function": null,
|
||||
"routes": {},
|
||||
"max_retries": 3,
|
||||
"retry_on": [],
|
||||
"max_node_visits": 0,
|
||||
"output_model": null,
|
||||
"max_validation_retries": 2,
|
||||
"client_facing": false,
|
||||
"success_criteria": null
|
||||
},
|
||||
{
|
||||
"id": "filter-contacts",
|
||||
"name": "Filter Contacts",
|
||||
"description": "Analyze each contact for authenticity and filter out suspicious profiles. Any contact with a risk score of 7 or higher is skipped.",
|
||||
"node_type": "event_loop",
|
||||
"input_keys": [
|
||||
"scored_contacts"
|
||||
],
|
||||
"output_keys": [
|
||||
"safe_contacts",
|
||||
"filtered_count"
|
||||
],
|
||||
"nullable_output_keys": [],
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"system_prompt": "You are a profile authenticity analyzer. Your job is to detect suspicious or fake LinkedIn profiles.\n\n**RISK SCORING RULES (additive):**\n- Fewer than 50 connections: +3 points\n- No profile photo: +2 points\n- Fewer than 2 positions in work history: +2 points\n- Generic title (e.g. \"entrepreneur\", \"CEO\", \"consultant\") AND fewer than 100 connections: +2 points\n- Company name appears generic or unverifiable: +2 points\n- Profile text seems auto-generated or overly promotional: +2 points\n- Connection count over 5000 with no mutual connections: +1 point\n\n**DECISION RULE:**\n- risk_score < 4: SAFE — include in outreach\n- risk_score 4–6: CAUTION — include but flag\n- risk_score ≥ 7: SKIP — exclude from outreach\n\n**STEP 1 — Load scored contacts:**\nCall load_data(filename=<the \"scored_contacts\" value from context>).\nProcess contacts chunk by chunk if has_more=true.\n\n**STEP 2 — Analyze each contact:**\nFor each contact, calculate a risk_score using the rules above.\nDetermine: is_safe (risk_score < 7), recommendation (safe/caution/skip), flags (list of triggered rules).\n\n**STEP 3 — Write safe contacts and set output:**\n- For each contact where risk_score < 7: call append_data(filename=\"safe_contacts.jsonl\", data=<contact JSON with risk_score and flags added>)\n- Track how many contacts were filtered (risk_score ≥ 7)\n- Call set_output(\"safe_contacts\", \"safe_contacts.jsonl\")\n- Call set_output(\"filtered_count\", <number of skipped contacts as string>)",
|
||||
"tools": [
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"model": null,
|
||||
"function": null,
|
||||
"routes": {},
|
||||
"max_retries": 3,
|
||||
"retry_on": [],
|
||||
"max_node_visits": 0,
|
||||
"output_model": null,
|
||||
"max_validation_retries": 2,
|
||||
"client_facing": false,
|
||||
"success_criteria": null
|
||||
},
|
||||
{
|
||||
"id": "personalize",
|
||||
"name": "Personalize",
|
||||
"description": "Generate a personalized outreach message for each contact based on their profile, shared background, and the user's outreach goal.",
|
||||
"node_type": "event_loop",
|
||||
"input_keys": [
|
||||
"safe_contacts",
|
||||
"outreach_goal",
|
||||
"user_background"
|
||||
],
|
||||
"output_keys": [
|
||||
"personalized_contacts"
|
||||
],
|
||||
"nullable_output_keys": [],
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"system_prompt": "You are a professional outreach message writer. Generate personalized messages for each contact.\n\n**TWO-STEP PERSONALIZATION:**\n\nFor each contact, follow this two-step approach:\n\nSTEP A — Extract hooks (analyze the profile):\nLook for 2-3 specific talking points from the contact's profile:\n- Shared alumni connection\n- Specific role, company, or career transition worth mentioning\n- Any mutual interests aligned with the user's background\n\nSTEP B — Generate the message:\nWrite a warm, professional outreach message using the hooks.\n\n**MESSAGE REQUIREMENTS:**\n- 80-120 words (LinkedIn message length)\n- Start with a specific observation (\"I noticed you...\" or \"Fellow [school] alum here...\")\n- Mention the shared connection or interest naturally\n- State the outreach goal clearly but softly (e.g. \"Open to a brief 15-min chat?\")\n- Professional but warm tone — NOT templated or AI-sounding\n- Do NOT mention job postings directly unless the goal is job-related\n- Do NOT use generic openers like \"I hope this finds you well\"\n- End with a low-pressure ask\n\n**STEP 1 — Load safe contacts:**\nCall load_data(filename=<the \"safe_contacts\" value from context>).\n\n**STEP 2 — Generate message for each contact:**\nFor each contact: generate the personalized message using the two-step approach above.\nAdd \"outreach_message\" field to each contact object.\n\n**STEP 3 — Write output and set:**\n- Call append_data(filename=\"personalized_contacts.jsonl\", data=<contact JSON with outreach_message>) for each.\n- Call set_output(\"personalized_contacts\", \"personalized_contacts.jsonl\")",
|
||||
"tools": [
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"model": null,
|
||||
"function": null,
|
||||
"routes": {},
|
||||
"max_retries": 3,
|
||||
"retry_on": [],
|
||||
"max_node_visits": 0,
|
||||
"output_model": null,
|
||||
"max_validation_retries": 2,
|
||||
"client_facing": false,
|
||||
"success_criteria": null
|
||||
},
|
||||
{
|
||||
"id": "send-outreach",
|
||||
"name": "Send Outreach",
|
||||
"description": "Create Gmail draft emails for each contact using their personalized message. Drafts are created for human review — emails are never sent automatically.",
|
||||
"node_type": "event_loop",
|
||||
"input_keys": [
|
||||
"personalized_contacts",
|
||||
"outreach_goal"
|
||||
],
|
||||
"output_keys": [
|
||||
"drafts_created"
|
||||
],
|
||||
"nullable_output_keys": [],
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"system_prompt": "You are an outreach execution assistant. Create Gmail draft emails for each contact.\n\n**CRITICAL RULE: NEVER send emails automatically. Only create drafts.**\n\n**STEP 1 — Load personalized contacts:**\nCall load_data(filename=<the \"personalized_contacts\" value from context>).\nProcess chunk by chunk if has_more=true.\n\n**STEP 2 — Create Gmail draft for each contact:**\nFor each contact with an \"outreach_message\":\n- subject: \"Coffee Chat Request\" (or appropriate subject based on outreach_goal)\n- to: contact's email address (use LinkedIn profile URL if email not available — note this in body)\n- body: the \"outreach_message\" from the contact object\n\nCall gmail_create_draft(\n to=<contact email or linkedin_url as placeholder>,\n subject=<appropriate subject line>,\n body=<outreach_message>\n)\n\nRecord each draft: call append_data(\n filename=\"drafts.jsonl\",\n data=<JSON: {contact_name, contact_email, subject, status: \"draft_created\"}>\n)\n\n**STEP 3 — Set output:**\n- Call set_output(\"drafts_created\", \"drafts.jsonl\")\n\n**IMPORTANT:** If a contact has no email address, create the draft with their LinkedIn URL as a placeholder and add a note in the body: \"Note: Please find the recipient's email before sending.\"",
|
||||
"tools": [
|
||||
"gmail_create_draft",
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"model": null,
|
||||
"function": null,
|
||||
"routes": {},
|
||||
"max_retries": 3,
|
||||
"retry_on": [],
|
||||
"max_node_visits": 0,
|
||||
"output_model": null,
|
||||
"max_validation_retries": 2,
|
||||
"client_facing": false,
|
||||
"success_criteria": null
|
||||
},
|
||||
{
|
||||
"id": "report",
|
||||
"name": "Report",
|
||||
"description": "Generate a summary report of the outreach campaign: contacts scored, filtered, messaged, and drafts created. Present to user for review.",
|
||||
"node_type": "event_loop",
|
||||
"input_keys": [
|
||||
"drafts_created",
|
||||
"filtered_count",
|
||||
"outreach_goal"
|
||||
],
|
||||
"output_keys": [
|
||||
"summary_report"
|
||||
],
|
||||
"nullable_output_keys": [],
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"system_prompt": "You are an SDR assistant. Generate a clear campaign summary report and present it to the user.\n\n**STEP 1 — Load draft records:**\nCall load_data(filename=<the \"drafts_created\" value from context>) to read the draft records.\nIf has_more=true, load additional chunks until all records are loaded.\n\n**STEP 2 — Present the report (text only, NO tool calls):**\n\nPresent a clean summary:\n\n📊 **SDR Campaign Summary — [outreach_goal]**\n\n**Overview:**\n- Total contacts processed: [N]\n- Contacts filtered (suspicious profiles): [filtered_count]\n- Safe contacts messaged: [N - filtered_count]\n- Gmail drafts created: [N]\n\n**Drafts Created:**\nList each draft: Contact Name | Company | Subject\n\n**Next Steps:**\n\"Your Gmail drafts are ready for review. Please:\n1. Open Gmail and review each draft\n2. Personalize further if needed\n3. Send when ready\n\nCampaign complete!\"\n\n**STEP 3 — After the user responds, call set_output:**\n- set_output(\"summary_report\", <the formatted report text>)",
|
||||
"tools": [
|
||||
"load_data"
|
||||
],
|
||||
"model": null,
|
||||
"function": null,
|
||||
"routes": {},
|
||||
"max_retries": 3,
|
||||
"retry_on": [],
|
||||
"max_node_visits": 0,
|
||||
"output_model": null,
|
||||
"max_validation_retries": 2,
|
||||
"client_facing": true,
|
||||
"success_criteria": null
|
||||
},
|
||||
{
|
||||
"id": "complete",
|
||||
"name": "Complete",
|
||||
"description": "Terminal node - campaign complete.",
|
||||
"node_type": "event_loop",
|
||||
"input_keys": [
|
||||
"summary_report"
|
||||
],
|
||||
"output_keys": [
|
||||
"final_report"
|
||||
],
|
||||
"nullable_output_keys": [],
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"system_prompt": "Campaign is complete. Set the final output.\n\nCall set_output(\"final_report\", <summary_report value from context>)",
|
||||
"tools": [],
|
||||
"model": null,
|
||||
"function": null,
|
||||
"routes": {},
|
||||
"max_retries": 3,
|
||||
"retry_on": [],
|
||||
"max_node_visits": 1,
|
||||
"output_model": null,
|
||||
"max_validation_retries": 2,
|
||||
"client_facing": false,
|
||||
"success_criteria": null
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "intake-to-score",
|
||||
"source": "intake",
|
||||
"target": "score-contacts",
|
||||
"condition": "on_success",
|
||||
"condition_expr": null,
|
||||
"priority": 1,
|
||||
"input_mapping": {}
|
||||
},
|
||||
{
|
||||
"id": "score-to-filter",
|
||||
"source": "score-contacts",
|
||||
"target": "filter-contacts",
|
||||
"condition": "on_success",
|
||||
"condition_expr": null,
|
||||
"priority": 1,
|
||||
"input_mapping": {}
|
||||
},
|
||||
{
|
||||
"id": "filter-to-personalize",
|
||||
"source": "filter-contacts",
|
||||
"target": "personalize",
|
||||
"condition": "on_success",
|
||||
"condition_expr": null,
|
||||
"priority": 1,
|
||||
"input_mapping": {}
|
||||
},
|
||||
{
|
||||
"id": "personalize-to-send",
|
||||
"source": "personalize",
|
||||
"target": "send-outreach",
|
||||
"condition": "on_success",
|
||||
"condition_expr": null,
|
||||
"priority": 1,
|
||||
"input_mapping": {}
|
||||
},
|
||||
{
|
||||
"id": "send-to-report",
|
||||
"source": "send-outreach",
|
||||
"target": "report",
|
||||
"condition": "on_success",
|
||||
"condition_expr": null,
|
||||
"priority": 1,
|
||||
"input_mapping": {}
|
||||
},
|
||||
{
|
||||
"id": "report-to-complete",
|
||||
"source": "report",
|
||||
"target": "complete",
|
||||
"condition": "on_success",
|
||||
"condition_expr": null,
|
||||
"priority": 1,
|
||||
"input_mapping": {}
|
||||
}
|
||||
],
|
||||
"max_steps": 100,
|
||||
"max_retries_per_node": 3,
|
||||
"description": "Automated SDR outreach pipeline: score contacts by priority, filter suspicious profiles, generate personalized messages, and create Gmail drafts for human review."
|
||||
},
|
||||
"goal": {
|
||||
"id": "sdr-agent",
|
||||
"name": "SDR Agent",
|
||||
"description": "Automate sales development outreach: score contacts by priority, filter suspicious profiles, generate personalized messages, and create Gmail drafts for human review.",
|
||||
"status": "draft",
|
||||
"success_criteria": [
|
||||
{
|
||||
"id": "contact-scoring-accuracy",
|
||||
"description": "Contacts are correctly scored and ranked by priority factors (alumni status, connection degree, domain verification)",
|
||||
"metric": "scoring_accuracy",
|
||||
"target": ">=90%",
|
||||
"weight": 0.30,
|
||||
"met": false
|
||||
},
|
||||
{
|
||||
"id": "scam-filter-effectiveness",
|
||||
"description": "Suspicious profiles (risk_score >= 7) are correctly identified and excluded from outreach",
|
||||
"metric": "filter_precision",
|
||||
"target": ">=95%",
|
||||
"weight": 0.25,
|
||||
"met": false
|
||||
},
|
||||
{
|
||||
"id": "message-personalization",
|
||||
"description": "Generated messages reference specific profile details (alumni connection, role, company) and match the outreach goal",
|
||||
"metric": "personalization_score",
|
||||
"target": ">=80%",
|
||||
"weight": 0.30,
|
||||
"met": false
|
||||
},
|
||||
{
|
||||
"id": "draft-creation",
|
||||
"description": "Gmail drafts are created for all safe contacts without errors",
|
||||
"metric": "draft_success_rate",
|
||||
"target": "100%",
|
||||
"weight": 0.15,
|
||||
"met": false
|
||||
}
|
||||
],
|
||||
"constraints": [
|
||||
{
|
||||
"id": "draft-not-send",
|
||||
"description": "Agent creates Gmail drafts but NEVER sends emails automatically",
|
||||
"constraint_type": "hard",
|
||||
"category": "safety",
|
||||
"check": ""
|
||||
},
|
||||
{
|
||||
"id": "respect-batch-limit",
|
||||
"description": "Must not process more contacts than the configured max_contacts parameter",
|
||||
"constraint_type": "hard",
|
||||
"category": "operational",
|
||||
"check": ""
|
||||
},
|
||||
{
|
||||
"id": "skip-suspicious",
|
||||
"description": "Contacts with risk_score >= 7 must be excluded from outreach",
|
||||
"constraint_type": "hard",
|
||||
"category": "safety",
|
||||
"check": ""
|
||||
}
|
||||
],
|
||||
"context": {},
|
||||
"required_capabilities": [],
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"version": "1.0.0",
|
||||
"parent_version": null,
|
||||
"evolution_reason": null
|
||||
},
|
||||
"required_tools": [
|
||||
"gmail_create_draft",
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"metadata": {
|
||||
"node_count": 7,
|
||||
"edge_count": 6
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,375 @@
|
||||
"""Agent graph construction for SDR Agent."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from framework.graph import EdgeSpec, EdgeCondition, Goal, SuccessCriterion, Constraint
|
||||
from framework.graph.checkpoint_config import CheckpointConfig
|
||||
from framework.graph.edge import AsyncEntryPointSpec, GraphSpec
|
||||
from framework.graph.executor import ExecutionResult
|
||||
from framework.llm import LiteLLMProvider
|
||||
from framework.runner.tool_registry import ToolRegistry
|
||||
from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime
|
||||
from framework.runtime.execution_stream import EntryPointSpec
|
||||
|
||||
from .config import default_config, metadata
|
||||
from .nodes import (
|
||||
intake_node,
|
||||
score_contacts_node,
|
||||
filter_contacts_node,
|
||||
personalize_node,
|
||||
send_outreach_node,
|
||||
report_node,
|
||||
)
|
||||
|
||||
# Goal definition
|
||||
goal = Goal(
|
||||
id="sdr-agent",
|
||||
name="SDR Agent",
|
||||
description=(
|
||||
"Automate sales development outreach: score contacts by priority, "
|
||||
"filter suspicious profiles, generate personalized messages, "
|
||||
"and create Gmail drafts for human review."
|
||||
),
|
||||
success_criteria=[
|
||||
SuccessCriterion(
|
||||
id="contact-scoring-accuracy",
|
||||
description=(
|
||||
"Contacts are correctly scored and ranked by priority factors "
|
||||
"(alumni status, connection degree, domain verification)"
|
||||
),
|
||||
metric="scoring_accuracy",
|
||||
target=">=90%",
|
||||
weight=0.30,
|
||||
),
|
||||
SuccessCriterion(
|
||||
id="scam-filter-effectiveness",
|
||||
description=(
|
||||
"Suspicious profiles (risk_score >= 7) are correctly identified "
|
||||
"and excluded from outreach"
|
||||
),
|
||||
metric="filter_precision",
|
||||
target=">=95%",
|
||||
weight=0.25,
|
||||
),
|
||||
SuccessCriterion(
|
||||
id="message-personalization",
|
||||
description=(
|
||||
"Generated messages reference specific profile details "
|
||||
"(alumni connection, role, company) and match the outreach goal"
|
||||
),
|
||||
metric="personalization_score",
|
||||
target=">=80%",
|
||||
weight=0.30,
|
||||
),
|
||||
SuccessCriterion(
|
||||
id="draft-creation",
|
||||
description="Gmail drafts are created for all safe contacts without errors",
|
||||
metric="draft_success_rate",
|
||||
target="100%",
|
||||
weight=0.15,
|
||||
),
|
||||
],
|
||||
constraints=[
|
||||
Constraint(
|
||||
id="draft-not-send",
|
||||
description="Agent creates Gmail drafts but NEVER sends emails automatically",
|
||||
constraint_type="hard",
|
||||
category="safety",
|
||||
),
|
||||
Constraint(
|
||||
id="respect-batch-limit",
|
||||
description="Must not process more contacts than the configured max_contacts parameter",
|
||||
constraint_type="hard",
|
||||
category="operational",
|
||||
),
|
||||
Constraint(
|
||||
id="skip-suspicious",
|
||||
description="Contacts with risk_score >= 7 must be excluded from outreach",
|
||||
constraint_type="hard",
|
||||
category="safety",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Node list
|
||||
nodes = [
|
||||
intake_node,
|
||||
score_contacts_node,
|
||||
filter_contacts_node,
|
||||
personalize_node,
|
||||
send_outreach_node,
|
||||
report_node,
|
||||
]
|
||||
|
||||
# Edge definitions
|
||||
edges = [
|
||||
EdgeSpec(
|
||||
id="intake-to-score",
|
||||
source="intake",
|
||||
target="score-contacts",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
priority=1,
|
||||
),
|
||||
EdgeSpec(
|
||||
id="score-to-filter",
|
||||
source="score-contacts",
|
||||
target="filter-contacts",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
priority=1,
|
||||
),
|
||||
EdgeSpec(
|
||||
id="filter-to-personalize",
|
||||
source="filter-contacts",
|
||||
target="personalize",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
priority=1,
|
||||
),
|
||||
EdgeSpec(
|
||||
id="personalize-to-send",
|
||||
source="personalize",
|
||||
target="send-outreach",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
priority=1,
|
||||
),
|
||||
EdgeSpec(
|
||||
id="send-to-report",
|
||||
source="send-outreach",
|
||||
target="report",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
priority=1,
|
||||
),
|
||||
EdgeSpec(
|
||||
id="report-to-intake",
|
||||
source="report",
|
||||
target="intake",
|
||||
condition=EdgeCondition.ON_SUCCESS,
|
||||
priority=1,
|
||||
),
|
||||
]
|
||||
|
||||
# Graph configuration
|
||||
entry_node = "intake"
|
||||
entry_points = {"start": "intake"}
|
||||
async_entry_points: list[AsyncEntryPointSpec] = [] # SDR Agent is manually triggered
|
||||
pause_nodes = []
|
||||
terminal_nodes = []
|
||||
loop_config = {
|
||||
"max_iterations": 100,
|
||||
"max_tool_calls_per_turn": 30,
|
||||
"max_tool_result_chars": 8000,
|
||||
"max_history_tokens": 32000,
|
||||
}
|
||||
conversation_mode = "continuous"
|
||||
identity_prompt = (
|
||||
"You are an SDR (Sales Development Representative) assistant. "
|
||||
"You help users automate their outreach by scoring contacts, filtering "
|
||||
"suspicious profiles, generating personalized messages, and creating "
|
||||
"Gmail drafts — all with human review before anything is sent."
|
||||
)
|
||||
|
||||
|
||||
class SDRAgent:
|
||||
"""
|
||||
SDR Agent — 6-node pipeline for automated outreach.
|
||||
|
||||
Flow: intake -> score-contacts -> filter-contacts -> personalize
|
||||
-> send-outreach -> report -> intake (loop)
|
||||
|
||||
Pipeline:
|
||||
1. intake: Receive contact list and outreach goal
|
||||
2. score-contacts: Rank contacts 0-100 by priority factors
|
||||
3. filter-contacts: Remove suspicious profiles (risk >= 7)
|
||||
4. personalize: Generate personalized messages for each contact
|
||||
5. send-outreach: Create Gmail drafts (never sends automatically)
|
||||
6. report: Summarize campaign results and present to user
|
||||
"""
|
||||
|
||||
def __init__(self, config=None):
|
||||
self.config = config or default_config
|
||||
self.goal = goal
|
||||
self.nodes = nodes
|
||||
self.edges = edges
|
||||
self.entry_node = entry_node
|
||||
self.entry_points = entry_points
|
||||
self.pause_nodes = pause_nodes
|
||||
self.terminal_nodes = terminal_nodes
|
||||
self._agent_runtime: AgentRuntime | None = None
|
||||
self._graph: GraphSpec | None = None
|
||||
self._tool_registry: ToolRegistry | None = None
|
||||
|
||||
def _build_graph(self) -> GraphSpec:
|
||||
"""Build the GraphSpec."""
|
||||
return GraphSpec(
|
||||
id="sdr-agent-graph",
|
||||
goal_id=self.goal.id,
|
||||
version="1.0.0",
|
||||
entry_node=self.entry_node,
|
||||
entry_points=self.entry_points,
|
||||
terminal_nodes=self.terminal_nodes,
|
||||
pause_nodes=self.pause_nodes,
|
||||
nodes=self.nodes,
|
||||
edges=self.edges,
|
||||
default_model=self.config.model,
|
||||
max_tokens=self.config.max_tokens,
|
||||
loop_config=loop_config,
|
||||
conversation_mode=conversation_mode,
|
||||
identity_prompt=identity_prompt,
|
||||
)
|
||||
|
||||
def _setup(self, mock_mode=False) -> None:
|
||||
"""Set up the agent runtime with sessions, checkpoints, and logging."""
|
||||
self._storage_path = Path.home() / ".hive" / "agents" / "sdr_agent"
|
||||
self._storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._tool_registry = ToolRegistry()
|
||||
|
||||
mcp_config_path = Path(__file__).parent / "mcp_servers.json"
|
||||
if mcp_config_path.exists():
|
||||
self._tool_registry.load_mcp_config(mcp_config_path)
|
||||
|
||||
tools_path = Path(__file__).parent / "tools.py"
|
||||
if tools_path.exists():
|
||||
self._tool_registry.discover_from_module(tools_path)
|
||||
|
||||
if mock_mode:
|
||||
from framework.llm.mock import MockLLMProvider
|
||||
|
||||
llm = MockLLMProvider()
|
||||
else:
|
||||
llm = LiteLLMProvider(
|
||||
model=self.config.model,
|
||||
api_key=self.config.api_key,
|
||||
api_base=self.config.api_base,
|
||||
)
|
||||
|
||||
tool_executor = self._tool_registry.get_executor()
|
||||
tools = list(self._tool_registry.get_tools().values())
|
||||
|
||||
self._graph = self._build_graph()
|
||||
|
||||
checkpoint_config = CheckpointConfig(
|
||||
enabled=True,
|
||||
checkpoint_on_node_start=False,
|
||||
checkpoint_on_node_complete=True,
|
||||
checkpoint_max_age_days=7,
|
||||
async_checkpoint=True,
|
||||
)
|
||||
|
||||
entry_point_specs = [
|
||||
EntryPointSpec(
|
||||
id="default",
|
||||
name="Default",
|
||||
entry_node=self.entry_node,
|
||||
trigger_type="manual",
|
||||
isolation_level="shared",
|
||||
),
|
||||
]
|
||||
|
||||
self._agent_runtime = create_agent_runtime(
|
||||
graph=self._graph,
|
||||
goal=self.goal,
|
||||
storage_path=self._storage_path,
|
||||
entry_points=entry_point_specs,
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
tool_executor=tool_executor,
|
||||
checkpoint_config=checkpoint_config,
|
||||
)
|
||||
|
||||
async def start(self, mock_mode=False) -> None:
|
||||
"""Set up and start the agent runtime."""
|
||||
if self._agent_runtime is None:
|
||||
self._setup(mock_mode=mock_mode)
|
||||
if not self._agent_runtime.is_running:
|
||||
await self._agent_runtime.start()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the agent runtime and clean up."""
|
||||
if self._agent_runtime and self._agent_runtime.is_running:
|
||||
await self._agent_runtime.stop()
|
||||
self._agent_runtime = None
|
||||
|
||||
async def trigger_and_wait(
|
||||
self,
|
||||
entry_point: str,
|
||||
input_data: dict,
|
||||
timeout: float | None = None,
|
||||
session_state: dict | None = None,
|
||||
) -> ExecutionResult | None:
|
||||
"""Execute the graph and wait for completion."""
|
||||
if self._agent_runtime is None:
|
||||
raise RuntimeError("Agent not started. Call start() first.")
|
||||
|
||||
return await self._agent_runtime.trigger_and_wait(
|
||||
entry_point_id=entry_point,
|
||||
input_data=input_data,
|
||||
timeout=timeout,
|
||||
session_state=session_state,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, context: dict, mock_mode=False, session_state=None
|
||||
) -> ExecutionResult:
|
||||
"""Run the agent (convenience method for single execution)."""
|
||||
await self.start(mock_mode=mock_mode)
|
||||
try:
|
||||
result = await self.trigger_and_wait(
|
||||
"default", context, session_state=session_state
|
||||
)
|
||||
return result or ExecutionResult(success=False, error="Execution timeout")
|
||||
finally:
|
||||
await self.stop()
|
||||
|
||||
def info(self):
|
||||
"""Get agent information."""
|
||||
return {
|
||||
"name": metadata.name,
|
||||
"version": metadata.version,
|
||||
"description": metadata.description,
|
||||
"goal": {
|
||||
"name": self.goal.name,
|
||||
"description": self.goal.description,
|
||||
},
|
||||
"nodes": [n.id for n in self.nodes],
|
||||
"edges": [e.id for e in self.edges],
|
||||
"entry_node": self.entry_node,
|
||||
"entry_points": self.entry_points,
|
||||
"pause_nodes": self.pause_nodes,
|
||||
"terminal_nodes": self.terminal_nodes,
|
||||
"client_facing_nodes": [n.id for n in self.nodes if n.client_facing],
|
||||
}
|
||||
|
||||
def validate(self):
|
||||
"""Validate agent structure."""
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
node_ids = {node.id for node in self.nodes}
|
||||
for edge in self.edges:
|
||||
if edge.source not in node_ids:
|
||||
errors.append(f"Edge {edge.id}: source '{edge.source}' not found")
|
||||
if edge.target not in node_ids:
|
||||
errors.append(f"Edge {edge.id}: target '{edge.target}' not found")
|
||||
|
||||
if self.entry_node not in node_ids:
|
||||
errors.append(f"Entry node '{self.entry_node}' not found")
|
||||
|
||||
for terminal in self.terminal_nodes:
|
||||
if terminal not in node_ids:
|
||||
errors.append(f"Terminal node '{terminal}' not found")
|
||||
|
||||
for ep_id, node_id in self.entry_points.items():
|
||||
if node_id not in node_ids:
|
||||
errors.append(
|
||||
f"Entry point '{ep_id}' references unknown node '{node_id}'"
|
||||
)
|
||||
|
||||
return {
|
||||
"valid": len(errors) == 0,
|
||||
"errors": errors,
|
||||
"warnings": warnings,
|
||||
}
|
||||
|
||||
|
||||
# Create default instance
|
||||
default_agent = SDRAgent()
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Runtime configuration for SDR Agent."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from framework.config import RuntimeConfig
|
||||
|
||||
default_config = RuntimeConfig()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentMetadata:
|
||||
name: str = "SDR Agent"
|
||||
version: str = "1.0.0"
|
||||
description: str = (
|
||||
"Automate sales development outreach using AI-powered contact scoring, "
|
||||
"scam detection, and personalized message generation. "
|
||||
"Score contacts by priority, filter suspicious profiles, generate "
|
||||
"personalized outreach messages, and create Gmail drafts for review."
|
||||
)
|
||||
intro_message: str = (
|
||||
"Hi! I'm your SDR (Sales Development Representative) assistant. "
|
||||
"Provide a list of contacts and your outreach goal, and I'll "
|
||||
"score them by priority, filter out suspicious profiles, generate "
|
||||
"personalized messages for each contact, and create Gmail drafts "
|
||||
"for your review. I never send emails automatically — you stay in control. "
|
||||
"To get started, share your contact list and tell me about your outreach goal!"
|
||||
)
|
||||
|
||||
|
||||
metadata = AgentMetadata()
|
||||
@@ -0,0 +1,97 @@
|
||||
[
|
||||
{
|
||||
"name": "Sarah Chen",
|
||||
"email": "sarah.chen@techcorp.io",
|
||||
"company": "TechCorp",
|
||||
"title": "Learning & Development Manager",
|
||||
"linkedin_url": "https://linkedin.com/in/sarah-chen-ld",
|
||||
"connection_degree": "2nd",
|
||||
"is_alumni": true,
|
||||
"school_name": "University of Western Ontario",
|
||||
"connections_count": 843,
|
||||
"mutual_connections": 7,
|
||||
"has_photo": true,
|
||||
"company_domain_verified": true
|
||||
},
|
||||
{
|
||||
"name": "James Okafor",
|
||||
"email": "james.okafor@edventure.co",
|
||||
"company": "EdVenture",
|
||||
"title": "Instructional Designer",
|
||||
"linkedin_url": "https://linkedin.com/in/james-okafor-id",
|
||||
"connection_degree": "1st",
|
||||
"is_alumni": false,
|
||||
"connections_count": 621,
|
||||
"mutual_connections": 12,
|
||||
"has_photo": true,
|
||||
"company_domain_verified": true
|
||||
},
|
||||
{
|
||||
"name": "Emily Zhao",
|
||||
"email": "emily.zhao@univedu.ca",
|
||||
"company": "UniEdu",
|
||||
"title": "Director of Digital Learning",
|
||||
"linkedin_url": "https://linkedin.com/in/emily-zhao-dl",
|
||||
"connection_degree": "2nd",
|
||||
"is_alumni": true,
|
||||
"school_name": "University of Western Ontario",
|
||||
"connections_count": 1204,
|
||||
"mutual_connections": 3,
|
||||
"has_photo": true,
|
||||
"company_domain_verified": true,
|
||||
"active_job_posting": true
|
||||
},
|
||||
{
|
||||
"name": "Marcus Williams",
|
||||
"email": "marcus@growthsales.io",
|
||||
"company": "GrowthSales",
|
||||
"title": "CEO",
|
||||
"linkedin_url": "https://linkedin.com/in/marcus-williams-ceo",
|
||||
"connection_degree": "3rd",
|
||||
"is_alumni": false,
|
||||
"connections_count": 6300,
|
||||
"mutual_connections": 0,
|
||||
"has_photo": true,
|
||||
"company_domain_verified": false
|
||||
},
|
||||
{
|
||||
"name": "Priya Patel",
|
||||
"email": "",
|
||||
"company": "FutureLearn Inc.",
|
||||
"title": "EdTech Product Manager",
|
||||
"linkedin_url": "https://linkedin.com/in/priya-patel-edtech",
|
||||
"connection_degree": "2nd",
|
||||
"is_alumni": false,
|
||||
"connections_count": 512,
|
||||
"mutual_connections": 5,
|
||||
"has_photo": true,
|
||||
"company_domain_verified": true
|
||||
},
|
||||
{
|
||||
"name": "Alex Johnson",
|
||||
"email": "alex@bizopp.biz",
|
||||
"company": "Biz Opportunity Global",
|
||||
"title": "Entrepreneur",
|
||||
"linkedin_url": "https://linkedin.com/in/alex-johnson-biz",
|
||||
"connection_degree": "3rd",
|
||||
"is_alumni": false,
|
||||
"connections_count": 38,
|
||||
"mutual_connections": 0,
|
||||
"has_photo": false,
|
||||
"company_domain_verified": false
|
||||
},
|
||||
{
|
||||
"name": "Natalie Brown",
|
||||
"email": "natalie.brown@learningpro.com",
|
||||
"company": "LearningPro",
|
||||
"title": "HR Learning Specialist",
|
||||
"linkedin_url": "https://linkedin.com/in/natalie-brown-hr",
|
||||
"connection_degree": "1st",
|
||||
"is_alumni": true,
|
||||
"school_name": "University of Western Ontario",
|
||||
"connections_count": 389,
|
||||
"mutual_connections": 9,
|
||||
"has_photo": true,
|
||||
"company_domain_verified": true
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,270 @@
|
||||
{
|
||||
"original_draft": {
|
||||
"agent_name": "sdr_agent",
|
||||
"goal": "Automate sales development outreach: score contacts by priority, filter suspicious profiles, generate personalized messages, and create Gmail drafts for human review.",
|
||||
"description": "",
|
||||
"success_criteria": [
|
||||
"Contacts are correctly scored and ranked by priority factors (alumni status, connection degree, domain verification)",
|
||||
"Suspicious profiles (risk_score >= 7) are correctly identified and excluded from outreach",
|
||||
"Generated messages reference specific profile details (alumni connection, role, company) and match the outreach goal",
|
||||
"Gmail drafts are created for all safe contacts without errors"
|
||||
],
|
||||
"constraints": [
|
||||
"Agent creates Gmail drafts but NEVER sends emails automatically",
|
||||
"Must not process more contacts than the configured max_contacts parameter",
|
||||
"Contacts with risk_score >= 7 must be excluded from outreach"
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"id": "intake",
|
||||
"name": "Intake",
|
||||
"description": "Receive the contact list and outreach goal from the user. Confirm the strategy and batch size before proceeding.",
|
||||
"node_type": "event_loop",
|
||||
"tools": [
|
||||
"load_contacts_from_file"
|
||||
],
|
||||
"input_keys": [
|
||||
"contacts",
|
||||
"outreach_goal",
|
||||
"max_contacts",
|
||||
"user_background"
|
||||
],
|
||||
"output_keys": [
|
||||
"contacts",
|
||||
"outreach_goal",
|
||||
"max_contacts",
|
||||
"user_background"
|
||||
],
|
||||
"success_criteria": "The user has confirmed the contact list, outreach goal, batch size, and their background. All four keys have been written via set_output.",
|
||||
"sub_agents": [],
|
||||
"flowchart_type": "start",
|
||||
"flowchart_shape": "stadium",
|
||||
"flowchart_color": "#8aad3f"
|
||||
},
|
||||
{
|
||||
"id": "score-contacts",
|
||||
"name": "Score Contacts",
|
||||
"description": "Score and rank each contact from 0 to 100 based on priority factors: alumni status, connection degree, domain verification, mutual connections, and active job postings.",
|
||||
"node_type": "event_loop",
|
||||
"tools": [
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"input_keys": [
|
||||
"contacts",
|
||||
"outreach_goal"
|
||||
],
|
||||
"output_keys": [
|
||||
"scored_contacts"
|
||||
],
|
||||
"success_criteria": "Every contact has a priority_score field (0-100) and scored_contacts.jsonl has been written and referenced via set_output.",
|
||||
"sub_agents": [],
|
||||
"flowchart_type": "database",
|
||||
"flowchart_shape": "cylinder",
|
||||
"flowchart_color": "#508878"
|
||||
},
|
||||
{
|
||||
"id": "filter-contacts",
|
||||
"name": "Filter Contacts",
|
||||
"description": "Analyze each contact for authenticity and filter out suspicious profiles. Any contact with a risk score of 7 or higher is skipped.",
|
||||
"node_type": "event_loop",
|
||||
"tools": [
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"input_keys": [
|
||||
"scored_contacts"
|
||||
],
|
||||
"output_keys": [
|
||||
"safe_contacts",
|
||||
"filtered_count"
|
||||
],
|
||||
"success_criteria": "Each contact has a risk_score and recommendation field. Contacts with risk_score >= 7 are excluded. safe_contacts.jsonl and filtered_count are set via set_output.",
|
||||
"sub_agents": [],
|
||||
"flowchart_type": "database",
|
||||
"flowchart_shape": "cylinder",
|
||||
"flowchart_color": "#508878"
|
||||
},
|
||||
{
|
||||
"id": "personalize",
|
||||
"name": "Personalize",
|
||||
"description": "Generate a personalized outreach message for each contact based on their profile, shared background, and the user's outreach goal.",
|
||||
"node_type": "event_loop",
|
||||
"tools": [
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"input_keys": [
|
||||
"safe_contacts",
|
||||
"outreach_goal",
|
||||
"user_background"
|
||||
],
|
||||
"output_keys": [
|
||||
"personalized_contacts"
|
||||
],
|
||||
"success_criteria": "Every safe contact has an outreach_message field of 80-120 words that references a specific hook from their profile. personalized_contacts.jsonl is set via set_output.",
|
||||
"sub_agents": [],
|
||||
"flowchart_type": "database",
|
||||
"flowchart_shape": "cylinder",
|
||||
"flowchart_color": "#508878"
|
||||
},
|
||||
{
|
||||
"id": "send-outreach",
|
||||
"name": "Send Outreach",
|
||||
"description": "Create Gmail draft emails for each contact using their personalized message. Drafts are created for human review \u2014 emails are never sent automatically.",
|
||||
"node_type": "event_loop",
|
||||
"tools": [
|
||||
"gmail_create_draft",
|
||||
"load_data",
|
||||
"append_data"
|
||||
],
|
||||
"input_keys": [
|
||||
"personalized_contacts",
|
||||
"outreach_goal"
|
||||
],
|
||||
"output_keys": [
|
||||
"drafts_created"
|
||||
],
|
||||
"success_criteria": "A Gmail draft has been created for every safe contact. drafts.jsonl records each draft and drafts_created is set via set_output.",
|
||||
"sub_agents": [],
|
||||
"flowchart_type": "database",
|
||||
"flowchart_shape": "cylinder",
|
||||
"flowchart_color": "#508878"
|
||||
},
|
||||
{
|
||||
"id": "report",
|
||||
"name": "Report",
|
||||
"description": "Generate a summary report of the outreach campaign: contacts scored, filtered, messaged, and drafts created. Present to user for review.",
|
||||
"node_type": "event_loop",
|
||||
"tools": [
|
||||
"load_data"
|
||||
],
|
||||
"input_keys": [
|
||||
"drafts_created",
|
||||
"filtered_count",
|
||||
"outreach_goal"
|
||||
],
|
||||
"output_keys": [
|
||||
"summary_report"
|
||||
],
|
||||
"success_criteria": "A campaign summary has been presented to the user listing totals for contacts scored, filtered, messaged, and drafts created. summary_report is set via set_output.",
|
||||
"sub_agents": [],
|
||||
"flowchart_type": "terminal",
|
||||
"flowchart_shape": "stadium",
|
||||
"flowchart_color": "#b5453a"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "edge-0",
|
||||
"source": "intake",
|
||||
"target": "score-contacts",
|
||||
"condition": "on_success",
|
||||
"description": "",
|
||||
"label": ""
|
||||
},
|
||||
{
|
||||
"id": "edge-1",
|
||||
"source": "score-contacts",
|
||||
"target": "filter-contacts",
|
||||
"condition": "on_success",
|
||||
"description": "",
|
||||
"label": ""
|
||||
},
|
||||
{
|
||||
"id": "edge-2",
|
||||
"source": "filter-contacts",
|
||||
"target": "personalize",
|
||||
"condition": "on_success",
|
||||
"description": "",
|
||||
"label": ""
|
||||
},
|
||||
{
|
||||
"id": "edge-3",
|
||||
"source": "personalize",
|
||||
"target": "send-outreach",
|
||||
"condition": "on_success",
|
||||
"description": "",
|
||||
"label": ""
|
||||
},
|
||||
{
|
||||
"id": "edge-4",
|
||||
"source": "send-outreach",
|
||||
"target": "report",
|
||||
"condition": "on_success",
|
||||
"description": "",
|
||||
"label": ""
|
||||
},
|
||||
{
|
||||
"id": "edge-5",
|
||||
"source": "report",
|
||||
"target": "intake",
|
||||
"condition": "on_success",
|
||||
"description": "",
|
||||
"label": ""
|
||||
}
|
||||
],
|
||||
"entry_node": "intake",
|
||||
"terminal_nodes": [
|
||||
"report"
|
||||
],
|
||||
"flowchart_legend": {
|
||||
"start": {
|
||||
"shape": "stadium",
|
||||
"color": "#8aad3f"
|
||||
},
|
||||
"terminal": {
|
||||
"shape": "stadium",
|
||||
"color": "#b5453a"
|
||||
},
|
||||
"process": {
|
||||
"shape": "rectangle",
|
||||
"color": "#b5a575"
|
||||
},
|
||||
"decision": {
|
||||
"shape": "diamond",
|
||||
"color": "#d89d26"
|
||||
},
|
||||
"io": {
|
||||
"shape": "parallelogram",
|
||||
"color": "#d06818"
|
||||
},
|
||||
"document": {
|
||||
"shape": "document",
|
||||
"color": "#c4b830"
|
||||
},
|
||||
"database": {
|
||||
"shape": "cylinder",
|
||||
"color": "#508878"
|
||||
},
|
||||
"subprocess": {
|
||||
"shape": "subroutine",
|
||||
"color": "#887a48"
|
||||
},
|
||||
"browser": {
|
||||
"shape": "hexagon",
|
||||
"color": "#cc8850"
|
||||
}
|
||||
}
|
||||
},
|
||||
"flowchart_map": {
|
||||
"intake": [
|
||||
"intake"
|
||||
],
|
||||
"score-contacts": [
|
||||
"score-contacts"
|
||||
],
|
||||
"filter-contacts": [
|
||||
"filter-contacts"
|
||||
],
|
||||
"personalize": [
|
||||
"personalize"
|
||||
],
|
||||
"send-outreach": [
|
||||
"send-outreach"
|
||||
],
|
||||
"report": [
|
||||
"report"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"hive-tools": {
|
||||
"transport": "stdio",
|
||||
"command": "uv",
|
||||
"args": [
|
||||
"run",
|
||||
"python",
|
||||
"mcp_server.py",
|
||||
"--stdio"
|
||||
],
|
||||
"cwd": "../../../tools",
|
||||
"description": "Hive tools MCP server"
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user