Compare commits

81 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | e581767cab |  |
|  | 0663ee5950 |  |
|  | 4b97baa34b |  |
|  | a89296d397 |  |
|  | d568912ba2 |  |
|  | c4d7980058 |  |
|  | 8549fe8238 |  |
|  | 2b8d85bb95 |  |
|  | 07f7801166 |  |
|  | 1f12a45151 |  |
|  | 936e02e8e6 |  |
|  | d59fe1e109 |  |
|  | 0f0884c2e0 |  |
|  | 764012c598 |  |
|  | fd4dc1a69a |  |
|  | 377cd39c2a |  |
|  | e92caeef24 |  |
|  | b7e6226478 |  |
|  | a995818db2 |  |
|  | 0772b4d300 |  |
|  | 684e0d8dc6 |  |
|  | d284c5d790 |  |
|  | 7a9b9666c4 |  |
|  | a852cb91bf |  |
|  | 2f21e9eb4b |  |
|  | 8390ef8731 |  |
|  | 8d21479c24 |  |
|  | 965dec3ba1 |  |
|  | 7992b862c2 |  |
|  | 44b3e0eaa2 |  |
|  | f480fc2b94 |  |
|  | 2844dbf19f |  |
|  | 22b7e4b0c3 |  |
|  | 5413833a69 |  |
|  | 02e1a4584a |  |
|  | 520840b1dd |  |
|  | ee96147336 |  |
|  | 705cef4dc1 |  |
|  | ab26e64122 |  |
|  | f365e219cb |  |
|  | 01621881c2 |  |
|  | f7639f8572 |  |
|  | fc643060ce |  |
|  | 9aebeb181e |  |
|  | acbbfaaa79 |  |
|  | bf170bce10 |  |
|  | 0a090d058b |  |
|  | 47bfadaad9 |  |
|  | d968dcd44c |  |
|  | 6fdaa9ea50 |  |
|  | 4d251fbdc2 |  |
|  | 6acceed288 |  |
|  | 8dd1d6e3aa |  |
|  | 1da28644a6 |  |
|  | 6452fe7fef |  |
|  | 651d6850a1 |  |
|  | c7fdc92594 |  |
|  | 3da04265a6 |  |
|  | 4c98f0d2d0 |  |
|  | ae921f6cee |  |
|  | 6b506a1c08 |  |
|  | 0c9f4fa97e |  |
|  | 23c66d1059 |  |
|  | b9d529d94e |  |
|  | b799789dbe |  |
|  | 2cd73dfccc |  |
|  | 57d77d5479 |  |
|  | 5814021773 |  |
|  | 4f4cc9c8ce |  |
|  | d9c840eee5 |  |
|  | 88253883a3 |  |
|  | 6ed6e5b286 |  |
|  | 30bb0ad5d8 |  |
|  | cb0845f5ba |  |
|  | ce2525b59c |  |
|  | 1f77ec3831 |  |
|  | 6ab5aa8004 |  |
|  | 4449cd8ee8 |  |
|  | 8b60c03a0a |  |
|  | 0e98023e40 |  |
|  | 48a54b4ee2 |  |
```diff
@@ -2,14 +2,22 @@ name: Bounty completed
 description: Awards points and notifies Discord when a bounty PR is merged

 on:
-  pull_request:
+  pull_request_target:
     types: [closed]

+  workflow_dispatch:
+    inputs:
+      pr_number:
+        description: "PR number to process (for missed bounties)"
+        required: true
+        type: number
+
 jobs:
   bounty-notify:
     if: >
-      github.event.pull_request.merged == true &&
-      contains(join(github.event.pull_request.labels.*.name, ','), 'bounty:')
+      github.event_name == 'workflow_dispatch' ||
+      (github.event.pull_request.merged == true &&
+      contains(join(github.event.pull_request.labels.*.name, ','), 'bounty:'))
     runs-on: ubuntu-latest
     timeout-minutes: 5
     permissions:
```

```diff
@@ -32,6 +40,8 @@ jobs:
       GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
       GITHUB_REPOSITORY_NAME: ${{ github.event.repository.name }}
       DISCORD_WEBHOOK_URL: ${{ secrets.DISCORD_BOUNTY_WEBHOOK_URL }}
       BOT_API_URL: ${{ secrets.BOT_API_URL }}
       BOT_API_KEY: ${{ secrets.BOT_API_KEY }}
+      LURKR_API_KEY: ${{ secrets.LURKR_API_KEY }}
+      LURKR_GUILD_ID: ${{ secrets.LURKR_GUILD_ID }}
-      PR_NUMBER: ${{ github.event.pull_request.number }}
+      PR_NUMBER: ${{ inputs.pr_number || github.event.pull_request.number }}
```
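The new `workflow_dispatch` trigger means missed bounties can be replayed without a merge event. A minimal sketch of driving it through the standard GitHub REST API endpoint for dispatching workflows — the workflow file name `bounty-completed.yml` and the PR number are assumptions, not shown in the diff:

```python
import json
import os
import urllib.request

# Standard GitHub API for firing a workflow_dispatch event:
# POST /repos/{owner}/{repo}/actions/workflows/{workflow_file}/dispatches
url = (
    "https://api.github.com/repos/aden-hive/hive"
    "/actions/workflows/bounty-completed.yml/dispatches"  # file name assumed
)
payload = {"ref": "main", "inputs": {"pr_number": "1234"}}  # illustrative PR number

req = urllib.request.Request(
    url,
    data=json.dumps(payload).encode("utf-8"),
    headers={
        "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",
        "Accept": "application/vnd.github+json",
    },
    method="POST",
)
urllib.request.urlopen(req)  # the API answers 204 No Content on success
```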
@@ -1,126 +0,0 @@

```yaml
name: Link Discord account
description: Auto-creates a PR to add contributor to contributors.yml when a link-discord issue is opened

on:
  issues:
    types: [opened]

jobs:
  link-discord:
    if: contains(github.event.issue.labels.*.name, 'link-discord')
    runs-on: ubuntu-latest
    timeout-minutes: 2
    permissions:
      contents: write
      issues: write
      pull-requests: write

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Parse issue and update contributors.yml
        uses: actions/github-script@v7
        with:
          script: |
            const fs = require('fs');

            const issue = context.payload.issue;
            const githubUsername = issue.user.login;

            // Parse the issue body for form fields
            const body = issue.body || '';

            // Extract Discord ID — look for the numeric value after the "Discord User ID" heading
            const discordMatch = body.match(/### Discord User ID\s*\n\s*(\d{17,20})/);
            if (!discordMatch) {
              await github.rest.issues.createComment({
                ...context.repo,
                issue_number: issue.number,
                body: `Could not find a valid Discord ID in the issue body. Please make sure you entered a numeric ID (17-20 digits), not a username.\n\nExample: \`123456789012345678\``
              });
              await github.rest.issues.update({
                ...context.repo,
                issue_number: issue.number,
                state: 'closed',
                state_reason: 'not_planned'
              });
              return;
            }
            const discordId = discordMatch[1];

            // Extract display name (optional)
            const nameMatch = body.match(/### Display Name \(optional\)\s*\n\s*(.+)/);
            const displayName = nameMatch ? nameMatch[1].trim() : '';

            // Check if user already exists
            const yml = fs.readFileSync('contributors.yml', 'utf-8');
            if (yml.includes(`github: ${githubUsername}`)) {
              await github.rest.issues.createComment({
                ...context.repo,
                issue_number: issue.number,
                body: `@${githubUsername} is already in \`contributors.yml\`. If you need to update your Discord ID, please edit the file directly via PR.`
              });
              await github.rest.issues.update({
                ...context.repo,
                issue_number: issue.number,
                state: 'closed',
                state_reason: 'completed'
              });
              return;
            }

            // Append entry to contributors.yml
            let entry = `  - github: ${githubUsername}\n    discord: "${discordId}"`;
            if (displayName && displayName !== '_No response_') {
              entry += `\n    name: ${displayName}`;
            }
            entry += '\n';

            const updated = yml.trimEnd() + '\n' + entry;
            fs.writeFileSync('contributors.yml', updated);

            // Set outputs for commit step
            core.exportVariable('GITHUB_USERNAME', githubUsername);
            core.exportVariable('DISCORD_ID', discordId);
            core.exportVariable('ISSUE_NUMBER', issue.number.toString());

      - name: Create PR
        run: |
          # Check if there are changes
          if git diff --quiet contributors.yml; then
            echo "No changes to contributors.yml"
            exit 0
          fi

          BRANCH="docs/link-discord-${GITHUB_USERNAME}"
          git config user.name "github-actions[bot]"
          git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
          git checkout -b "$BRANCH"
          git add contributors.yml
          git commit -m "docs: link @${GITHUB_USERNAME} to Discord"
          git push origin "$BRANCH"

          gh pr create \
            --title "docs: link @${GITHUB_USERNAME} to Discord" \
            --body "Adds @${GITHUB_USERNAME} (Discord \`${DISCORD_ID}\`) to \`contributors.yml\` for bounty XP tracking.

          Closes #${ISSUE_NUMBER}" \
            --base main \
            --head "$BRANCH" \
            --label "link-discord"
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

      - name: Notify on issue
        uses: actions/github-script@v7
        with:
          script: |
            const username = process.env.GITHUB_USERNAME;
            const issueNumber = parseInt(process.env.ISSUE_NUMBER);

            await github.rest.issues.createComment({
              ...context.repo,
              issue_number: issueNumber,
              body: `A PR has been created to link your account. A maintainer will merge it shortly — once merged, you'll receive XP and Discord pings when your bounty PRs are merged.`
            });
```
```diff
@@ -35,6 +35,8 @@ jobs:
       GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
       GITHUB_REPOSITORY_NAME: ${{ github.event.repository.name }}
       DISCORD_WEBHOOK_URL: ${{ secrets.DISCORD_BOUNTY_WEBHOOK_URL }}
       BOT_API_URL: ${{ secrets.BOT_API_URL }}
       BOT_API_KEY: ${{ secrets.BOT_API_KEY }}
+      LURKR_API_KEY: ${{ secrets.LURKR_API_KEY }}
+      LURKR_GUILD_ID: ${{ secrets.LURKR_GUILD_ID }}
       SINCE_DATE: ${{ github.event.inputs.since_date || '' }}
```
```diff
@@ -41,7 +41,8 @@ Generate a swarm of worker agents with a coding agent(queen) that control them.

 Visit [adenhq.com](https://adenhq.com) for complete documentation, examples, and guides.

 https://github.com/user-attachments/assets/aad3a035-e7b3-4cac-b13d-4a83c7002c30

+https://github.com/user-attachments/assets/bf10edc3-06ba-48b6-98ba-d069b15fb69d
+
 ## Who Is Hive For?
```
@@ -1,27 +0,0 @@

```yaml
# Identity mapping: GitHub username -> Discord ID
#
# This file links GitHub accounts to Discord accounts for the
# Integration Bounty Program. When a bounty PR is merged, the
# GitHub Action uses this file to ping the contributor on Discord.
#
# HOW TO ADD YOURSELF:
#   Open a "Link Discord Account" issue:
#   https://github.com/aden-hive/hive/issues/new?template=link-discord.yml
#   A GitHub Action will automatically add your entry here.
#
# To find your Discord ID:
#   1. Open Discord Settings > Advanced > Enable Developer Mode
#   2. Right-click your name > Copy User ID
#
# Format:
#   - github: your-github-username
#     discord: "your-discord-id"   # quotes required (it's a number)
#     name: Your Display Name      # optional

contributors:
  # - github: example-user
  #   discord: "123456789012345678"
  #   name: Example User
  - github: TimothyZhang7
    discord: "408460790061072384"
    name: Timothy@Aden
```
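The bounty workflow presumably consumes this file as a plain lookup table. A sketch of that lookup, assuming PyYAML and the layout above (the function itself is illustrative, not taken from the repo):

```python
import yaml  # PyYAML — an assumed dependency


def discord_id_for(github_username: str, path: str = "contributors.yml") -> str | None:
    """Return the Discord ID linked to a GitHub username, or None if unlinked."""
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}
    for entry in data.get("contributors") or []:
        if entry.get("github") == github_username:
            return entry.get("discord")
    return None
```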
@@ -0,0 +1,583 @@

```python
#!/usr/bin/env python3
"""Antigravity authentication CLI.

Implements OAuth2 flow for Google's Antigravity Code Assist gateway.
Credentials are stored in ~/.hive/antigravity-accounts.json.

Usage:
    python -m antigravity_auth auth account add
    python -m antigravity_auth auth account list
    python -m antigravity_auth auth account remove <email>
"""

from __future__ import annotations

import argparse
import json
import logging
import os
import secrets
import socket
import sys
import time
import urllib.parse
import urllib.request
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Any

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)

# OAuth endpoints
_OAUTH_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
_OAUTH_TOKEN_URL = "https://oauth2.googleapis.com/token"

# Scopes for Antigravity/Cloud Code Assist
_OAUTH_SCOPES = [
    "https://www.googleapis.com/auth/cloud-platform",
    "https://www.googleapis.com/auth/userinfo.email",
    "https://www.googleapis.com/auth/userinfo.profile",
]

# Credentials file path in ~/.hive/
_ACCOUNTS_FILE = Path.home() / ".hive" / "antigravity-accounts.json"

# Default project ID
_DEFAULT_PROJECT_ID = "rising-fact-p41fc"
_DEFAULT_REDIRECT_PORT = 51121

# OAuth credentials fetched from the opencode-antigravity-auth project.
# This project reverse-engineered and published the public OAuth credentials
# for Google's Antigravity/Cloud Code Assist API.
# Source: https://github.com/NoeFabris/opencode-antigravity-auth
_CREDENTIALS_URL = (
    "https://raw.githubusercontent.com/NoeFabris/opencode-antigravity-auth/dev/src/constants.ts"
)

# Cached credentials fetched from public source
_cached_client_id: str | None = None
_cached_client_secret: str | None = None


def _fetch_credentials_from_public_source() -> tuple[str | None, str | None]:
    """Fetch OAuth client ID and secret from the public npm package source on GitHub."""
    global _cached_client_id, _cached_client_secret
    if _cached_client_id and _cached_client_secret:
        return _cached_client_id, _cached_client_secret

    try:
        req = urllib.request.Request(
            _CREDENTIALS_URL, headers={"User-Agent": "Hive-Antigravity-Auth/1.0"}
        )
        with urllib.request.urlopen(req, timeout=10) as resp:
            content = resp.read().decode("utf-8")
        import re

        id_match = re.search(r'ANTIGRAVITY_CLIENT_ID\s*=\s*"([^"]+)"', content)
        secret_match = re.search(r'ANTIGRAVITY_CLIENT_SECRET\s*=\s*"([^"]+)"', content)
        if id_match:
            _cached_client_id = id_match.group(1)
        if secret_match:
            _cached_client_secret = secret_match.group(1)
        return _cached_client_id, _cached_client_secret
    except Exception as e:
        logger.debug(f"Failed to fetch credentials from public source: {e}")
        return None, None


def get_client_id() -> str:
    """Get OAuth client ID from env, config, or public source."""
    env_id = os.environ.get("ANTIGRAVITY_CLIENT_ID")
    if env_id:
        return env_id

    # Try hive config
    hive_cfg = Path.home() / ".hive" / "configuration.json"
    if hive_cfg.exists():
        try:
            with open(hive_cfg) as f:
                cfg = json.load(f)
            cfg_id = cfg.get("llm", {}).get("antigravity_client_id")
            if cfg_id:
                return cfg_id
        except Exception:
            pass

    # Fetch from public source
    client_id, _ = _fetch_credentials_from_public_source()
    if client_id:
        return client_id

    raise RuntimeError("Could not obtain Antigravity OAuth client ID")


def get_client_secret() -> str | None:
    """Get OAuth client secret from env, config, or public source."""
    secret = os.environ.get("ANTIGRAVITY_CLIENT_SECRET")
    if secret:
        return secret

    # Try to read from hive config
    hive_cfg = Path.home() / ".hive" / "configuration.json"
    if hive_cfg.exists():
        try:
            with open(hive_cfg) as f:
                cfg = json.load(f)
            secret = cfg.get("llm", {}).get("antigravity_client_secret")
            if secret:
                return secret
        except Exception:
            pass

    # Fetch from public source (npm package on GitHub)
    _, secret = _fetch_credentials_from_public_source()
    return secret


def find_free_port() -> int:
    """Find an available local port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        s.listen(1)
        return s.getsockname()[1]


class OAuthCallbackHandler(BaseHTTPRequestHandler):
    """Handle OAuth callback from browser."""

    auth_code: str | None = None
    state: str | None = None
    error: str | None = None

    def log_message(self, format: str, *args: Any) -> None:
        pass  # Suppress default logging

    def do_GET(self) -> None:
        parsed = urllib.parse.urlparse(self.path)

        if parsed.path == "/oauth-callback":
            query = urllib.parse.parse_qs(parsed.query)

            if "error" in query:
                # Set on the class (matching auth_code below) so wait_for_callback sees it
                OAuthCallbackHandler.error = query["error"][0]
                self._send_response("Authentication failed. You can close this window.")
                return

            if "code" in query and "state" in query:
                OAuthCallbackHandler.auth_code = query["code"][0]
                OAuthCallbackHandler.state = query["state"][0]
                self._send_response(
                    "Authentication successful! You can close this window "
                    "and return to the terminal."
                )
                return

        self._send_response("Waiting for authentication...")

    def _send_response(self, message: str) -> None:
        self.send_response(200)
        self.send_header("Content-Type", "text/html")
        self.end_headers()
        html = f"""<!DOCTYPE html>
<html>
<head><title>Antigravity Auth</title></head>
<body style="font-family: system-ui; display: flex; align-items: center;
             justify-content: center; height: 100vh; margin: 0; background: #1a1a2e;
             color: #eee;">
  <div style="text-align: center;">
    <h2>{message}</h2>
  </div>
</body>
</html>"""
        self.wfile.write(html.encode())


def wait_for_callback(port: int, timeout: int = 300) -> tuple[str | None, str | None, str | None]:
    """Start local server and wait for OAuth callback."""
    server = HTTPServer(("localhost", port), OAuthCallbackHandler)
    server.timeout = 1

    start = time.time()
    while time.time() - start < timeout:
        if OAuthCallbackHandler.auth_code:
            return (
                OAuthCallbackHandler.auth_code,
                OAuthCallbackHandler.state,
                OAuthCallbackHandler.error,
            )
        server.handle_request()

    return None, None, "timeout"


def exchange_code_for_tokens(
    code: str, redirect_uri: str, client_id: str, client_secret: str | None
) -> dict[str, Any] | None:
    """Exchange authorization code for tokens."""
    data = {
        "code": code,
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "grant_type": "authorization_code",
    }
    if client_secret:
        data["client_secret"] = client_secret

    body = urllib.parse.urlencode(data).encode()

    req = urllib.request.Request(
        _OAUTH_TOKEN_URL,
        data=body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )

    try:
        with urllib.request.urlopen(req, timeout=30) as resp:
            return json.loads(resp.read())
    except Exception as e:
        logger.error(f"Token exchange failed: {e}")
        return None


def get_user_email(access_token: str) -> str | None:
    """Get user email from Google API."""
    req = urllib.request.Request(
        "https://www.googleapis.com/oauth2/v2/userinfo",
        headers={"Authorization": f"Bearer {access_token}"},
    )
    try:
        with urllib.request.urlopen(req, timeout=10) as resp:
            data = json.loads(resp.read())
        return data.get("email")
    except Exception:
        return None


def load_accounts() -> dict[str, Any]:
    """Load existing accounts from file."""
    if not _ACCOUNTS_FILE.exists():
        return {"schemaVersion": 4, "accounts": []}
    try:
        with open(_ACCOUNTS_FILE) as f:
            return json.load(f)
    except Exception:
        return {"schemaVersion": 4, "accounts": []}


def save_accounts(data: dict[str, Any]) -> None:
    """Save accounts to file."""
    _ACCOUNTS_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(_ACCOUNTS_FILE, "w") as f:
        json.dump(data, f, indent=2)
    logger.info(f"Saved credentials to {_ACCOUNTS_FILE}")


def validate_credentials(access_token: str, project_id: str = _DEFAULT_PROJECT_ID) -> bool:
    """Test if credentials work by making a simple API call to Antigravity.

    Returns True if credentials are valid, False otherwise.
    """
    endpoint = "https://daily-cloudcode-pa.sandbox.googleapis.com"
    body = {
        "project": project_id,
        "model": "gemini-3-flash",
        "request": {
            "contents": [{"role": "user", "parts": [{"text": "hi"}]}],
            "generationConfig": {"maxOutputTokens": 10},
        },
        "requestType": "agent",
        "userAgent": "antigravity",
        "requestId": "validation-test",
    }
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json",
        "User-Agent": (
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
            "AppleWebKit/537.36 (KHTML, like Gecko) Antigravity/1.18.3"
        ),
        "X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
    }

    try:
        req = urllib.request.Request(
            f"{endpoint}/v1internal:generateContent",
            data=json.dumps(body).encode("utf-8"),
            headers=headers,
            method="POST",
        )
        with urllib.request.urlopen(req, timeout=30) as resp:
            json.loads(resp.read())
        return True
    except Exception:
        return False


def refresh_access_token(
    refresh_token: str, client_id: str, client_secret: str | None
) -> dict | None:
    """Refresh the access token using the refresh token."""
    data = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": client_id,
    }
    if client_secret:
        data["client_secret"] = client_secret

    body = urllib.parse.urlencode(data).encode()
    req = urllib.request.Request(
        _OAUTH_TOKEN_URL,
        data=body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=30) as resp:
            return json.loads(resp.read())
    except Exception as e:
        logger.debug(f"Token refresh failed: {e}")
        return None


def cmd_account_add(args: argparse.Namespace) -> int:
    """Add a new Antigravity account via OAuth2.

    First checks if valid credentials already exist. If so, validates them
    and skips OAuth if they work. Otherwise, proceeds with OAuth flow.
    """
    client_id = get_client_id()
    client_secret = get_client_secret()

    # Check if credentials already exist
    accounts_data = load_accounts()
    accounts = accounts_data.get("accounts", [])

    if accounts:
        account = next((a for a in accounts if a.get("enabled", True) is not False), accounts[0])
        access_token = account.get("access")
        refresh_token_str = account.get("refresh", "")
        refresh_token = refresh_token_str.split("|")[0] if refresh_token_str else None
        project_id = (
            refresh_token_str.split("|")[1] if "|" in refresh_token_str else _DEFAULT_PROJECT_ID
        )
        email = account.get("email", "unknown")
        expires_ms = account.get("expires", 0)
        expires_at = expires_ms / 1000.0 if expires_ms else 0.0

        # Check if token is expired or near expiry
        if access_token and expires_at and time.time() < expires_at - 60:
            # Token still valid, test it
            logger.info(f"Found existing credentials for: {email}")
            logger.info("Validating existing credentials...")
            if validate_credentials(access_token, project_id):
                logger.info("✓ Credentials valid! Skipping OAuth.")
                return 0
            else:
                logger.info("Credentials failed validation, refreshing...")
        elif refresh_token:
            logger.info(f"Found expired credentials for: {email}")
            logger.info("Attempting token refresh...")

            tokens = refresh_access_token(refresh_token, client_id, client_secret)
            if tokens:
                new_access = tokens.get("access_token")
                expires_in = tokens.get("expires_in", 3600)
                if new_access:
                    # Update the account
                    account["access"] = new_access
                    account["expires"] = int((time.time() + expires_in) * 1000)
                    accounts_data["last_refresh"] = time.strftime(
                        "%Y-%m-%dT%H:%M:%SZ", time.gmtime()
                    )
                    save_accounts(accounts_data)

                    # Validate the refreshed token
                    logger.info("Validating refreshed credentials...")
                    if validate_credentials(new_access, project_id):
                        logger.info("✓ Credentials refreshed and validated!")
                        return 0
                    else:
                        logger.info("Refreshed token failed validation, proceeding with OAuth...")
            else:
                logger.info("Token refresh failed, proceeding with OAuth...")

    # No valid credentials, proceed with OAuth
    if not client_secret:
        logger.warning(
            "No client secret configured. Token refresh may fail.\n"
            "Set ANTIGRAVITY_CLIENT_SECRET env var or add "
            "'antigravity_client_secret' to ~/.hive/configuration.json"
        )

    # Use fixed port and path matching Google's expected OAuth redirect URI
    port = _DEFAULT_REDIRECT_PORT
    redirect_uri = f"http://localhost:{port}/oauth-callback"

    # Generate state for CSRF protection
    state = secrets.token_urlsafe(16)

    # Build authorization URL
    params = {
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "response_type": "code",
        "scope": " ".join(_OAUTH_SCOPES),
        "state": state,
        "access_type": "offline",
        "prompt": "consent",
    }
    auth_url = f"{_OAUTH_AUTH_URL}?{urllib.parse.urlencode(params)}"

    logger.info("Opening browser for authentication...")
    logger.info(f"If the browser doesn't open, visit: {auth_url}\n")

    # Open browser
    webbrowser.open(auth_url)

    # Wait for callback
    logger.info(f"Listening for callback on port {port}...")
    code, received_state, error = wait_for_callback(port)

    if error:
        logger.error(f"Authentication failed: {error}")
        return 1

    if not code:
        logger.error("No authorization code received")
        return 1

    if received_state != state:
        logger.error("State mismatch - possible CSRF attack")
        return 1

    # Exchange code for tokens
    logger.info("Exchanging authorization code for tokens...")
    tokens = exchange_code_for_tokens(code, redirect_uri, client_id, client_secret)

    if not tokens:
        return 1

    access_token = tokens.get("access_token")
    refresh_token = tokens.get("refresh_token")
    expires_in = tokens.get("expires_in", 3600)

    if not access_token:
        logger.error("No access token in response")
        return 1

    # Get user email
    email = get_user_email(access_token)
    if email:
        logger.info(f"Authenticated as: {email}")

    # Load existing accounts and add/update
    accounts_data = load_accounts()
    accounts = accounts_data.get("accounts", [])

    # Build new account entry (V4 schema)
    expires_ms = int((time.time() + expires_in) * 1000)
    refresh_entry = f"{refresh_token}|{_DEFAULT_PROJECT_ID}"

    new_account = {
        "access": access_token,
        "refresh": refresh_entry,
        "expires": expires_ms,
        "email": email,
        "enabled": True,
    }

    # Update existing account or add new one
    existing_idx = next((i for i, a in enumerate(accounts) if a.get("email") == email), None)
    if existing_idx is not None:
        accounts[existing_idx] = new_account
        logger.info(f"Updated existing account: {email}")
    else:
        accounts.append(new_account)
        logger.info(f"Added new account: {email}")

    accounts_data["accounts"] = accounts
    accounts_data["schemaVersion"] = 4
    accounts_data["last_refresh"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    save_accounts(accounts_data)
    logger.info("\n✓ Authentication complete!")
    return 0


def cmd_account_list(args: argparse.Namespace) -> int:
    """List all stored accounts."""
    data = load_accounts()
    accounts = data.get("accounts", [])

    if not accounts:
        logger.info("No accounts configured.")
        logger.info("Run 'antigravity auth account add' to add one.")
        return 0

    logger.info("Configured accounts:\n")
    for i, account in enumerate(accounts, 1):
        email = account.get("email", "unknown")
        enabled = "enabled" if account.get("enabled", True) else "disabled"
        logger.info(f"  {i}. {email} ({enabled})")

    return 0


def cmd_account_remove(args: argparse.Namespace) -> int:
    """Remove an account by email."""
    email = args.email
    data = load_accounts()
    accounts = data.get("accounts", [])

    original_len = len(accounts)
    accounts = [a for a in accounts if a.get("email") != email]

    if len(accounts) == original_len:
        logger.error(f"No account found with email: {email}")
        return 1

    data["accounts"] = accounts
    save_accounts(data)
    logger.info(f"Removed account: {email}")
    return 0


def main() -> int:
    parser = argparse.ArgumentParser(
        description="Antigravity authentication CLI",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    subparsers = parser.add_subparsers(dest="command", help="Commands")

    # auth account add
    auth_parser = subparsers.add_parser("auth", help="Authentication commands")
    auth_subparsers = auth_parser.add_subparsers(dest="auth_command")

    account_parser = auth_subparsers.add_parser("account", help="Account management")
    account_subparsers = account_parser.add_subparsers(dest="account_command")

    add_parser = account_subparsers.add_parser("add", help="Add a new account via OAuth2")
    add_parser.set_defaults(func=cmd_account_add)

    list_parser = account_subparsers.add_parser("list", help="List configured accounts")
    list_parser.set_defaults(func=cmd_account_list)

    remove_parser = account_subparsers.add_parser("remove", help="Remove an account")
    remove_parser.add_argument("email", help="Email of account to remove")
    remove_parser.set_defaults(func=cmd_account_remove)

    args = parser.parse_args()

    if hasattr(args, "func"):
        return args.func(args)

    parser.print_help()
    return 0


if __name__ == "__main__":
    sys.exit(main())
```
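One non-obvious convention in this file: the V4 schema packs the GCP project ID into the `refresh` field as `refresh_token|project_id` (see `refresh_entry` in `cmd_account_add` and the `split("|")` calls in the resume path). A pair of illustrative helpers making that symmetric — the names are hypothetical, not part of the file:

```python
def pack_refresh(refresh_token: str, project_id: str) -> str:
    """Encode token and project into the single V4 'refresh' field."""
    return f"{refresh_token}|{project_id}"


def unpack_refresh(
    refresh_entry: str, default_project: str = "rising-fact-p41fc"
) -> tuple[str | None, str]:
    """Split a V4 'refresh' field back into (refresh_token, project_id)."""
    if not refresh_entry:
        return None, default_project
    token, _, project = refresh_entry.partition("|")
    return token or None, project or default_project
```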
```diff
@@ -89,6 +89,16 @@ def main():

     register_testing_commands(subparsers)

+    # Register skill commands (skill list, skill trust, ...)
+    from framework.skills.cli import register_skill_commands
+
+    register_skill_commands(subparsers)
+
+    # Register debugger commands (debugger)
+    from framework.debugger.cli import register_debugger_commands
+
+    register_debugger_commands(subparsers)
+
     args = parser.parse_args()

     if hasattr(args, "func"):
```
@@ -61,6 +61,150 @@ def get_preferred_model() -> str:

```python
    return "anthropic/claude-sonnet-4-20250514"


def get_preferred_worker_model() -> str | None:
    """Return the user's preferred worker LLM model, or None if not configured.

    Reads from the ``worker_llm`` section of ~/.hive/configuration.json.
    Returns None when no worker-specific model is set, so callers can
    fall back to the default (queen) model via ``get_preferred_model()``.
    """
    worker_llm = get_hive_config().get("worker_llm", {})
    if worker_llm.get("provider") and worker_llm.get("model"):
        provider = str(worker_llm["provider"])
        model = str(worker_llm["model"]).strip()
        if provider.lower() == "openrouter" and model.lower().startswith("openrouter/"):
            model = model[len("openrouter/") :]
        if model:
            return f"{provider}/{model}"
    return None


def get_worker_api_key() -> str | None:
    """Return the API key for the worker LLM, falling back to the default key."""
    worker_llm = get_hive_config().get("worker_llm", {})
    if not worker_llm:
        return get_api_key()

    # Worker-specific subscription / env var
    if worker_llm.get("use_claude_code_subscription"):
        try:
            from framework.runner.runner import get_claude_code_token

            token = get_claude_code_token()
            if token:
                return token
        except ImportError:
            pass

    if worker_llm.get("use_codex_subscription"):
        try:
            from framework.runner.runner import get_codex_token

            token = get_codex_token()
            if token:
                return token
        except ImportError:
            pass

    if worker_llm.get("use_kimi_code_subscription"):
        try:
            from framework.runner.runner import get_kimi_code_token

            token = get_kimi_code_token()
            if token:
                return token
        except ImportError:
            pass

    if worker_llm.get("use_antigravity_subscription"):
        try:
            from framework.runner.runner import get_antigravity_token

            token = get_antigravity_token()
            if token:
                return token
        except ImportError:
            pass

    api_key_env_var = worker_llm.get("api_key_env_var")
    if api_key_env_var:
        return os.environ.get(api_key_env_var)

    # Fall back to default key
    return get_api_key()


def get_worker_api_base() -> str | None:
    """Return the api_base for the worker LLM, falling back to the default."""
    worker_llm = get_hive_config().get("worker_llm", {})
    if not worker_llm:
        return get_api_base()

    if worker_llm.get("use_codex_subscription"):
        return "https://chatgpt.com/backend-api/codex"
    if worker_llm.get("use_kimi_code_subscription"):
        return "https://api.kimi.com/coding"
    if worker_llm.get("use_antigravity_subscription"):
        # Antigravity uses AntigravityProvider directly — no api_base needed.
        return None
    if worker_llm.get("api_base"):
        return worker_llm["api_base"]
    if str(worker_llm.get("provider", "")).lower() == "openrouter":
        return OPENROUTER_API_BASE
    return None


def get_worker_llm_extra_kwargs() -> dict[str, Any]:
    """Return extra kwargs for the worker LLM provider."""
    worker_llm = get_hive_config().get("worker_llm", {})
    if not worker_llm:
        return get_llm_extra_kwargs()

    if worker_llm.get("use_claude_code_subscription"):
        api_key = get_worker_api_key()
        if api_key:
            return {
                "extra_headers": {"authorization": f"Bearer {api_key}"},
            }
    if worker_llm.get("use_codex_subscription"):
        api_key = get_worker_api_key()
        if api_key:
            headers: dict[str, str] = {
                "Authorization": f"Bearer {api_key}",
                "User-Agent": "CodexBar",
            }
            try:
                from framework.runner.runner import get_codex_account_id

                account_id = get_codex_account_id()
                if account_id:
                    headers["ChatGPT-Account-Id"] = account_id
            except ImportError:
                pass
            return {
                "extra_headers": headers,
                "store": False,
                "allowed_openai_params": ["store"],
            }
    return {}


def get_worker_max_tokens() -> int:
    """Return max_tokens for the worker LLM, falling back to default."""
    worker_llm = get_hive_config().get("worker_llm", {})
    if worker_llm and "max_tokens" in worker_llm:
        return worker_llm["max_tokens"]
    return get_max_tokens()


def get_worker_max_context_tokens() -> int:
    """Return max_context_tokens for the worker LLM, falling back to default."""
    worker_llm = get_hive_config().get("worker_llm", {})
    if worker_llm and "max_context_tokens" in worker_llm:
        return worker_llm["max_context_tokens"]
    return get_max_context_tokens()


def get_max_tokens() -> int:
    """Return the configured max_tokens, falling back to DEFAULT_MAX_TOKENS."""
    return get_hive_config().get("llm", {}).get("max_tokens", DEFAULT_MAX_TOKENS)
```

@@ -120,6 +264,17 @@ def get_api_key() -> str | None:

```python
        except ImportError:
            pass

    # Antigravity subscription: read OAuth token from accounts JSON
    if llm.get("use_antigravity_subscription"):
        try:
            from framework.runner.runner import get_antigravity_token

            token = get_antigravity_token()
            if token:
                return token
        except ImportError:
            pass

    # Standard env-var path (covers ZAI Code and all API-key providers)
    api_key_env_var = llm.get("api_key_env_var")
    if api_key_env_var:
```

@@ -127,6 +282,86 @@ def get_api_key() -> str | None:

```python
    return None


# OAuth credentials for Antigravity are fetched from the opencode-antigravity-auth project.
# This project reverse-engineered and published the public OAuth credentials
# for Google's Antigravity/Cloud Code Assist API.
# Source: https://github.com/NoeFabris/opencode-antigravity-auth
_ANTIGRAVITY_CREDENTIALS_URL = (
    "https://raw.githubusercontent.com/NoeFabris/opencode-antigravity-auth/dev/src/constants.ts"
)
_antigravity_credentials_cache: tuple[str | None, str | None] = (None, None)


def _fetch_antigravity_credentials() -> tuple[str | None, str | None]:
    """Fetch OAuth client ID and secret from the public npm package source on GitHub."""
    global _antigravity_credentials_cache
    if _antigravity_credentials_cache[0] and _antigravity_credentials_cache[1]:
        return _antigravity_credentials_cache

    import re
    import urllib.request

    try:
        req = urllib.request.Request(
            _ANTIGRAVITY_CREDENTIALS_URL, headers={"User-Agent": "Hive/1.0"}
        )
        with urllib.request.urlopen(req, timeout=10) as resp:
            content = resp.read().decode("utf-8")
        id_match = re.search(r'ANTIGRAVITY_CLIENT_ID\s*=\s*"([^"]+)"', content)
        secret_match = re.search(r'ANTIGRAVITY_CLIENT_SECRET\s*=\s*"([^"]+)"', content)
        client_id = id_match.group(1) if id_match else None
        client_secret = secret_match.group(1) if secret_match else None
        if client_id and client_secret:
            _antigravity_credentials_cache = (client_id, client_secret)
        return client_id, client_secret
    except Exception as e:
        logger.debug("Failed to fetch Antigravity credentials from public source: %s", e)
        return None, None


def get_antigravity_client_id() -> str:
    """Return the Antigravity OAuth application client ID.

    Checked in order:
    1. ``ANTIGRAVITY_CLIENT_ID`` environment variable
    2. ``llm.antigravity_client_id`` in ~/.hive/configuration.json
    3. Fetch from public source (opencode-antigravity-auth project on GitHub)
    """
    env = os.environ.get("ANTIGRAVITY_CLIENT_ID")
    if env:
        return env
    cfg_val = get_hive_config().get("llm", {}).get("antigravity_client_id")
    if cfg_val:
        return cfg_val
    # Fetch from public source
    client_id, _ = _fetch_antigravity_credentials()
    if client_id:
        return client_id
    raise RuntimeError("Could not obtain Antigravity OAuth client ID")


def get_antigravity_client_secret() -> str | None:
    """Return the Antigravity OAuth client secret.

    Checked in order:
    1. ``ANTIGRAVITY_CLIENT_SECRET`` environment variable
    2. ``llm.antigravity_client_secret`` in ~/.hive/configuration.json
    3. Fetch from public source (opencode-antigravity-auth project on GitHub)

    Returns None when not found — token refresh will be skipped and
    the caller must use whatever access token is already available.
    """
    env = os.environ.get("ANTIGRAVITY_CLIENT_SECRET")
    if env:
        return env
    cfg_val = get_hive_config().get("llm", {}).get("antigravity_client_secret") or None
    if cfg_val:
        return cfg_val
    # Fetch from public source
    _, secret = _fetch_antigravity_credentials()
    return secret


def get_gcu_enabled() -> bool:
    """Return whether GCU (browser automation) is enabled in user config."""
    return get_hive_config().get("gcu_enabled", True)
```

@@ -149,6 +384,9 @@ def get_api_base() -> str | None:

```python
    if llm.get("use_kimi_code_subscription"):
        # Kimi Code uses an Anthropic-compatible endpoint (no /v1 suffix).
        return "https://api.kimi.com/coding"
    if llm.get("use_antigravity_subscription"):
        # Antigravity uses AntigravityProvider directly — no api_base needed.
        return None
    if llm.get("api_base"):
        return llm["api_base"]
    if str(llm.get("provider", "")).lower() == "openrouter":
```
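All of the `worker_llm` accessors follow the same shape: return the worker-specific value when a `worker_llm` section exists, otherwise defer to the queen-level getter. A sketch of how a caller might resolve the effective worker model — the caller is hypothetical; the two getters are the ones from the diff above:

```python
from framework.config import get_preferred_model, get_preferred_worker_model


def resolve_worker_model() -> str:
    """Worker model if configured, else the default (queen) model."""
    return get_preferred_worker_model() or get_preferred_model()
```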
@@ -0,0 +1,76 @@

```python
"""CLI command for the LLM debug log viewer."""

import argparse
import subprocess
import sys
from pathlib import Path

_SCRIPT = Path(__file__).resolve().parents[3] / "scripts" / "llm_debug_log_visualizer.py"


def register_debugger_commands(subparsers: argparse._SubParsersAction) -> None:
    """Register the ``hive debugger`` command."""
    parser = subparsers.add_parser(
        "debugger",
        help="Open the LLM debug log viewer",
        description=(
            "Start a local server that lets you browse LLM debug sessions "
            "recorded in ~/.hive/llm_logs. Sessions are loaded on demand so "
            "the browser stays responsive."
        ),
    )
    parser.add_argument(
        "--session",
        help="Execution ID to select initially.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=0,
        help="Port for the local server (0 = auto-pick a free port).",
    )
    parser.add_argument(
        "--logs-dir",
        help="Directory containing JSONL log files (default: ~/.hive/llm_logs).",
    )
    parser.add_argument(
        "--limit-files",
        type=int,
        default=None,
        help="Maximum number of newest log files to scan (default: 200).",
    )
    parser.add_argument(
        "--output",
        help="Write a static HTML file instead of starting a server.",
    )
    parser.add_argument(
        "--no-open",
        action="store_true",
        help="Start the server but do not open a browser.",
    )
    parser.add_argument(
        "--include-tests",
        action="store_true",
        help="Show test/mock sessions (hidden by default).",
    )
    parser.set_defaults(func=cmd_debugger)


def cmd_debugger(args: argparse.Namespace) -> int:
    """Launch the LLM debug log visualizer."""
    cmd: list[str] = [sys.executable, str(_SCRIPT)]
    if args.session:
        cmd += ["--session", args.session]
    if args.port:
        cmd += ["--port", str(args.port)]
    if args.logs_dir:
        cmd += ["--logs-dir", args.logs_dir]
    if args.limit_files is not None:
        cmd += ["--limit-files", str(args.limit_files)]
    if args.output:
        cmd += ["--output", args.output]
    if args.no_open:
        cmd.append("--no-open")
    if args.include_tests:
        cmd.append("--include-tests")
    return subprocess.call(cmd)
```
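Since `cmd_debugger` only forwards flags to the standalone script, `hive debugger --session abc123 --port 8080` should be equivalent to invoking the visualizer directly. A sketch of that equivalence — the execution ID is made up, and the script path is what `_SCRIPT` resolves to relative to the repo root:

```python
import subprocess
import sys

# Roughly what `hive debugger --session abc123 --port 8080` ends up running:
subprocess.call([
    sys.executable,
    "scripts/llm_debug_log_visualizer.py",  # assumed to be run from the repo root
    "--session", "abc123",
    "--port", "8080",
])
```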
@@ -33,6 +33,8 @@ class Message:

```python
    is_transition_marker: bool = False
    # True when this message is real human input (from /chat), not a system prompt
    is_client_input: bool = False
    # True when message contains an activated skill body (AS-10: never prune)
    is_skill_content: bool = False

    def to_llm_dict(self) -> dict[str, Any]:
        """Convert to OpenAI-format message dict."""
```

@@ -409,6 +411,7 @@ class NodeConversation:

```python
        tool_use_id: str,
        content: str,
        is_error: bool = False,
        is_skill_content: bool = False,
    ) -> Message:
        msg = Message(
            seq=self._next_seq,
```

@@ -417,6 +420,7 @@ class NodeConversation:

```python
            tool_use_id=tool_use_id,
            is_error=is_error,
            phase_id=self._current_phase,
            is_skill_content=is_skill_content,
        )
        self._messages.append(msg)
        self._next_seq += 1
```

@@ -610,6 +614,8 @@ class NodeConversation:

```python
                continue
            if msg.is_error:
                continue  # never prune errors
            if msg.is_skill_content:
                continue  # never prune activated skill instructions (AS-10)
            if msg.content.startswith("[Pruned tool result"):
                continue  # already pruned
            # Tiny results (set_output acks, confirmations) — pruning
```
@@ -467,6 +467,8 @@ class EventLoopNode(NodeProtocol):

```python
        stream_id = ctx.stream_id or ctx.node_id
        node_id = ctx.node_id
        execution_id = ctx.execution_id or ""
        # Store skill dirs for AS-9 file-read interception in _execute_tool
        self._skill_dirs: list[str] = ctx.skill_dirs

        # Verdict counters for runtime logging
        _accept_count = _retry_count = _escalate_count = _continue_count = 0
```
```diff
@@ -531,12 +533,28 @@
             _restored_recent_responses = restored.recent_responses
             _restored_tool_fingerprints = restored.recent_tool_fingerprints

-            # Refresh the system prompt with full 3-layer composition.
-            # The stored prompt may be stale after code changes or when
-            # runtime-injected context (e.g. worker identity) has changed.
-            # On resume, we rebuild identity + narrative + focus so the LLM
-            # understands the session history, not just the node directive.
-            from framework.graph.prompt_composer import compose_system_prompt
+            # Refresh the system prompt with full composition including
+            # execution preamble and node-type preamble. The stored
+            # prompt may be stale after code changes or when runtime-
+            # injected context (e.g. worker identity) has changed.
+            from framework.graph.prompt_composer import (
+                EXECUTION_SCOPE_PREAMBLE,
+                compose_system_prompt,
+            )
+
+            _exec_preamble = None
+            if (
+                not ctx.is_subagent_mode
+                and ctx.node_spec.node_type in ("event_loop", "gcu")
+                and ctx.node_spec.output_keys
+            ):
+                _exec_preamble = EXECUTION_SCOPE_PREAMBLE
+
+            _node_type_preamble = None
+            if ctx.node_spec.node_type == "gcu":
+                from framework.graph.gcu import GCU_BROWSER_SYSTEM_PROMPT
+
+                _node_type_preamble = GCU_BROWSER_SYSTEM_PROMPT

             _current_prompt = compose_system_prompt(
                 identity_prompt=ctx.identity_prompt or None,
```

@@ -545,6 +563,8 @@ class EventLoopNode(NodeProtocol):

```python
                accounts_prompt=ctx.accounts_prompt or None,
                skills_catalog_prompt=ctx.skills_catalog_prompt or None,
                protocols_prompt=ctx.protocols_prompt or None,
                execution_preamble=_exec_preamble,
                node_type_preamble=_node_type_preamble,
            )
            if conversation.system_prompt != _current_prompt:
                conversation.update_system_prompt(_current_prompt)
```
@@ -806,6 +826,13 @@ class EventLoopNode(NodeProtocol):

```python
                execution_id,
                extra_data=_iter_meta,
            )
            # Sync max_context_tokens from live config so mid-session model
            # switches are reflected in compaction decisions and the UI bar.
            from framework.config import get_max_context_tokens as _live_mct

            conversation._max_context_tokens = _live_mct()

            await self._publish_context_usage(ctx, conversation, "iteration_start")

            # 6d. Pre-turn compaction check (tiered)
            _compacted_this_iter = False
```
@@ -2477,6 +2504,27 @@ class EventLoopNode(NodeProtocol):

```python
                    results_by_id[tc.tool_use_id] = result

                elif tc.tool_name == "delegate_to_sub_agent":
                    # Guard: in continuous mode the LLM may see delegate
                    # calls from a previous node's conversation history and
                    # attempt to re-use the tool on a node that doesn't own
                    # it. Only accept if the tool was actually offered.
                    if not any(t.name == "delegate_to_sub_agent" for t in tools):
                        logger.warning(
                            "[%s] LLM called delegate_to_sub_agent but tool "
                            "was not offered to this node — rejecting",
                            node_id,
                        )
                        result = ToolResult(
                            tool_use_id=tc.tool_use_id,
                            content=(
                                "ERROR: delegate_to_sub_agent is not available "
                                "on this node. This tool belongs to a different "
                                "node in the workflow."
                            ),
                            is_error=True,
                        )
                        results_by_id[tc.tool_use_id] = result
                        continue
                    # --- Framework-level subagent delegation ---
                    # Queue for parallel execution in Phase 2
                    logger.info(
```

@@ -2726,6 +2774,7 @@ class EventLoopNode(NodeProtocol):

```python
                    tool_use_id=tc.tool_use_id,
                    content=result.content,
                    is_error=result.is_error,
                    is_skill_content=result.is_skill_content,
                )
                if (
                    tc.tool_name in ("ask_user", "ask_user_multiple")
```
@@ -2834,6 +2883,8 @@ class EventLoopNode(NodeProtocol):

```python
                    conversation.usage_ratio() * 100,
                )

            await self._publish_context_usage(ctx, conversation, "post_tool_results")

            # If the turn requested external input (ask_user or queen handoff),
            # return immediately so the outer loop can block before judge eval.
            if user_input_requested or queen_input_requested:
```
@@ -3549,6 +3600,33 @@ class EventLoopNode(NodeProtocol):

```python
                content=f"No tool executor configured for '{tc.tool_name}'",
                is_error=True,
            )

        # AS-9: Intercept file-read tools for skill directories — bypass session sandbox
        _SKILL_READ_TOOLS = {"view_file", "load_data", "read_file"}
        skill_dirs = getattr(self, "_skill_dirs", [])
        if tc.tool_name in _SKILL_READ_TOOLS and skill_dirs:
            _path = tc.tool_input.get("path", "")
            if _path:
                import os
                from pathlib import Path as _Path

                _resolved = os.path.realpath(os.path.abspath(_path))
                if any(_resolved.startswith(os.path.realpath(d)) for d in skill_dirs):
                    try:
                        _content = _Path(_resolved).read_text(encoding="utf-8")
                        _is_skill_md = _resolved.endswith("SKILL.md")
                        return ToolResult(
                            tool_use_id=tc.tool_use_id,
                            content=_content,
                            is_skill_content=_is_skill_md,  # AS-10: protect SKILL.md reads
                        )
                    except Exception as _exc:
                        return ToolResult(
                            tool_use_id=tc.tool_use_id,
                            content=f"Could not read skill resource '{_path}': {_exc}",
                            is_error=True,
                        )

        tool_use = ToolUse(id=tc.tool_use_id, name=tc.tool_name, input=tc.tool_input)
        timeout = self._config.tool_call_timeout_seconds
```
@@ -3980,6 +4058,12 @@ class EventLoopNode(NodeProtocol):

```python
        ratio_before = conversation.usage_ratio()
        phase_grad = getattr(ctx, "continuous_mode", False)

        # Capture pre-compaction message inventory when over budget,
        # since compaction mutates the conversation in place.
        pre_inventory: list[dict[str, Any]] | None = None
        if ratio_before >= 1.0:
            pre_inventory = self._build_message_inventory(conversation)

        # --- Step 1: Prune old tool results (free, no LLM) ---
        protect = max(2000, self._config.max_context_tokens // 12)
        pruned = await conversation.prune_old_tool_results(
```

```diff
@@ -3994,7 +4078,7 @@
             conversation.usage_ratio() * 100,
         )
         if not conversation.needs_compaction():
-            await self._log_compaction(ctx, conversation, ratio_before)
+            await self._log_compaction(ctx, conversation, ratio_before, pre_inventory)
             return

         # --- Step 2: Standard structure-preserving compaction (free, no LLM) ---
```

```diff
@@ -4007,7 +4091,7 @@
             phase_graduated=phase_grad,
         )
         if not conversation.needs_compaction():
-            await self._log_compaction(ctx, conversation, ratio_before)
+            await self._log_compaction(ctx, conversation, ratio_before, pre_inventory)
             return

         # --- Step 3: LLM summary compaction ---
```

```diff
@@ -4034,7 +4118,7 @@
             logger.warning("LLM compaction failed: %s", e)

         if not conversation.needs_compaction():
-            await self._log_compaction(ctx, conversation, ratio_before)
+            await self._log_compaction(ctx, conversation, ratio_before, pre_inventory)
             return

         # --- Step 4: Emergency deterministic summary (LLM failed/unavailable) ---
```

```diff
@@ -4048,7 +4132,7 @@
             keep_recent=1,
             phase_graduated=phase_grad,
         )
-        await self._log_compaction(ctx, conversation, ratio_before)
+        await self._log_compaction(ctx, conversation, ratio_before, pre_inventory)

     # --- LLM compaction with binary-search splitting ----------------------
```
```diff
@@ -4210,13 +4294,59 @@
             "re-doing work.\n"
         )

+    @staticmethod
+    def _build_message_inventory(
+        conversation: NodeConversation,
+    ) -> list[dict[str, Any]]:
+        """Build a per-message size inventory for debug logging."""
+        inventory: list[dict[str, Any]] = []
+        for m in conversation.messages:
+            content_chars = len(m.content)
+            tc_chars = 0
+            tool_name = None
+            if m.tool_calls:
+                for tc in m.tool_calls:
+                    args = tc.get("function", {}).get("arguments", "")
+                    tc_chars += len(args) if isinstance(args, str) else len(json.dumps(args))
+                names = [tc.get("function", {}).get("name", "?") for tc in m.tool_calls]
+                tool_name = ", ".join(names)
+            elif m.role == "tool" and m.tool_use_id:
+                for prev in conversation.messages:
+                    if prev.tool_calls:
+                        for tc in prev.tool_calls:
+                            if tc.get("id") == m.tool_use_id:
+                                tool_name = tc.get("function", {}).get("name", "?")
+                                break
+                    if tool_name:
+                        break
+            entry: dict[str, Any] = {
+                "seq": m.seq,
+                "role": m.role,
+                "content_chars": content_chars,
+            }
+            if tc_chars:
+                entry["tool_call_args_chars"] = tc_chars
+            if tool_name:
+                entry["tool"] = tool_name
+            if m.is_error:
+                entry["is_error"] = True
+            if m.phase_id:
+                entry["phase"] = m.phase_id
+            if content_chars > 2000:
+                entry["preview"] = m.content[:200] + "…"
+            inventory.append(entry)
+        return inventory
+
     async def _log_compaction(
         self,
         ctx: NodeContext,
         conversation: NodeConversation,
         ratio_before: float,
+        pre_inventory: list[dict[str, Any]] | None = None,
     ) -> None:
-        """Log compaction result to runtime logger and event bus."""
+        """Log compaction result to runtime logger, event bus, and debug file."""
+        import os as _os
+
         ratio_after = conversation.usage_ratio()
         before_pct = round(ratio_before * 100)
         after_pct = round(ratio_after * 100)
```
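Each inventory entry is a flat dict, so the debug writer below can rank and total messages without touching `Message` objects. An illustrative entry for an oversized tool result (all values invented):

```python
# Optional keys ("tool", "phase", "is_error", "preview") appear only when relevant.
entry = {
    "seq": 42,
    "role": "tool",
    "content_chars": 18_344,
    "tool": "load_data",
    "phase": "phase-2",
    "preview": "… first 200 chars of the oversized result …",  # only for >2000 chars
}
```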
@@ -4249,19 +4379,103 @@ class EventLoopNode(NodeProtocol):
        if self._event_bus:
            from framework.runtime.event_bus import AgentEvent, EventType

            event_data: dict[str, Any] = {
                "level": level,
                "usage_before": before_pct,
                "usage_after": after_pct,
            }
            if pre_inventory is not None:
                event_data["message_inventory"] = pre_inventory
            await self._event_bus.publish(
                AgentEvent(
                    type=EventType.CONTEXT_COMPACTED,
                    stream_id=ctx.stream_id or ctx.node_id,
                    node_id=ctx.node_id,
                    data={
                        "level": level,
                        "usage_before": before_pct,
                        "usage_after": after_pct,
                    },
                    data=event_data,
                )
            )

        # Emit post-compaction usage update
        await self._publish_context_usage(ctx, conversation, "post_compaction")

        # Write detailed debug log to ~/.hive/compaction_log/ when enabled
        if _os.environ.get("HIVE_COMPACTION_DEBUG"):
            self._write_compaction_debug_log(ctx, before_pct, after_pct, level, pre_inventory)

    @staticmethod
    def _write_compaction_debug_log(
        ctx: NodeContext,
        before_pct: int,
        after_pct: int,
        level: str,
        inventory: list[dict[str, Any]] | None,
    ) -> None:
        """Write detailed compaction analysis to ~/.hive/compaction_log/."""
        log_dir = Path.home() / ".hive" / "compaction_log"
        log_dir.mkdir(parents=True, exist_ok=True)

        ts = datetime.now(UTC).strftime("%Y%m%dT%H%M%S_%f")
        node_label = ctx.node_id.replace("/", "_")
        log_path = log_dir / f"{ts}_{node_label}.md"

        lines: list[str] = [
            f"# Compaction Debug — {ctx.node_id}",
            f"**Time:** {datetime.now(UTC).isoformat()}",
            f"**Node:** {ctx.node_spec.name} (`{ctx.node_id}`)",
        ]
        if ctx.stream_id:
            lines.append(f"**Stream:** {ctx.stream_id}")
        lines.append(f"**Level:** {level}")
        lines.append(f"**Usage:** {before_pct}% → {after_pct}%")
        lines.append("")

        if inventory:
            total_chars = sum(
                e.get("content_chars", 0) + e.get("tool_call_args_chars", 0) for e in inventory
            )
            lines.append(
                f"## Pre-Compaction Message Inventory "
                f"({len(inventory)} messages, {total_chars:,} total chars)"
            )
            lines.append("")
            ranked = sorted(
                inventory,
                key=lambda e: e.get("content_chars", 0) + e.get("tool_call_args_chars", 0),
                reverse=True,
            )
            lines.append("| # | seq | role | tool | chars | % of total | flags |")
            lines.append("|---|-----|------|------|------:|------------|-------|")
            for i, entry in enumerate(ranked, 1):
                chars = entry.get("content_chars", 0) + entry.get("tool_call_args_chars", 0)
                pct = (chars / total_chars * 100) if total_chars else 0
                tool = entry.get("tool", "")
                flags = []
                if entry.get("is_error"):
                    flags.append("error")
                if entry.get("phase"):
                    flags.append(f"phase={entry['phase']}")
                lines.append(
                    f"| {i} | {entry['seq']} | {entry['role']} | {tool} "
                    f"| {chars:,} | {pct:.1f}% | {', '.join(flags)} |"
                )

            large = [e for e in ranked if e.get("preview")]
            if large:
                lines.append("")
                lines.append("### Large message previews")
                for entry in large:
                    lines.append(
                        f"\n**seq={entry['seq']}** ({entry['role']}, {entry.get('tool', '')}):"
                    )
                    lines.append(f"```\n{entry['preview']}\n```")
                lines.append("")

        try:
            log_path.write_text("\n".join(lines), encoding="utf-8")
            logger.debug("Compaction debug log written to %s", log_path)
        except OSError:
            logger.debug("Failed to write compaction debug log to %s", log_path)
    def _build_emergency_summary(
        self,
        ctx: NodeContext,
@@ -4666,6 +4880,36 @@ class EventLoopNode(NodeProtocol):
        if result.inject:
            await conversation.add_user_message(result.inject)

    async def _publish_context_usage(
        self,
        ctx: NodeContext,
        conversation: NodeConversation,
        trigger: str,
    ) -> None:
        """Emit a CONTEXT_USAGE_UPDATED event with current context window state."""
        if not self._event_bus:
            return
        from framework.runtime.event_bus import AgentEvent, EventType

        estimated = conversation.estimate_tokens()
        max_tokens = conversation._max_context_tokens
        ratio = estimated / max_tokens if max_tokens > 0 else 0.0
        await self._event_bus.publish(
            AgentEvent(
                type=EventType.CONTEXT_USAGE_UPDATED,
                stream_id=ctx.stream_id or ctx.node_id,
                node_id=ctx.node_id,
                data={
                    "usage_ratio": round(ratio, 4),
                    "usage_pct": round(ratio * 100),
                    "message_count": conversation.message_count,
                    "estimated_tokens": estimated,
                    "max_context_tokens": max_tokens,
                    "trigger": trigger,
                },
            )
        )

    async def _publish_iteration(
        self,
        stream_id: str,
@@ -4950,7 +5194,20 @@ class EventLoopNode(NodeProtocol):
            write_keys=[],  # Read-only!
        )

        # 2b. Set up report callback (one-way channel to parent / event bus)
        # 2b. Compute instance counter early so node_id is available for the
        # report callback and the NodeContext. Each delegation to the same
        # agent_id gets a unique suffix (instance 1 has no suffix for backward
        # compat; instance 2+ appends ":N").
        self._subagent_instance_counter.setdefault(agent_id, 0)
        self._subagent_instance_counter[agent_id] += 1
        _sa_instance = self._subagent_instance_counter[agent_id]
        if _sa_instance > 1:
            sa_node_id = f"{ctx.node_id}:subagent:{agent_id}:{_sa_instance}"
        else:
            sa_node_id = f"{ctx.node_id}:subagent:{agent_id}"
        subagent_instance = str(_sa_instance)
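        # Worked illustration (not part of this change) of the suffix scheme,
        # using made-up ids: with ctx.node_id "root" and agent_id "researcher",
        # successive delegations yield
        #   root:subagent:researcher      (instance 1, no suffix)
        #   root:subagent:researcher:2    (instance 2)
        #   root:subagent:researcher:3    (instance 3)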
        # 2c. Set up report callback (one-way channel to parent / event bus)
        subagent_reports: list[dict] = []

        async def _report_callback(
@@ -4963,7 +5220,7 @@ class EventLoopNode(NodeProtocol):
            if self._event_bus:
                await self._event_bus.emit_subagent_report(
                    stream_id=ctx.node_id,
                    node_id=f"{ctx.node_id}:subagent:{agent_id}",
                    node_id=sa_node_id,
                    subagent_id=agent_id,
                    message=message,
                    data=data,
@@ -5053,7 +5310,7 @@ class EventLoopNode(NodeProtocol):
        max_iter = min(self._config.max_iterations, 10)
        subagent_ctx = NodeContext(
            runtime=ctx.runtime,
            node_id=f"{ctx.node_id}:subagent:{agent_id}",
            node_id=sa_node_id,
            node_spec=subagent_spec,
            memory=scoped_memory,
            input_data={"task": task, **parent_data},
@@ -5081,10 +5338,7 @@ class EventLoopNode(NodeProtocol):
        # Derive a conversation store for the subagent from the parent's store.
        # Each invocation gets a unique path so that repeated delegate calls
        # (e.g. one per profile) don't restore a stale completed conversation.
        self._subagent_instance_counter.setdefault(agent_id, 0)
        self._subagent_instance_counter[agent_id] += 1
        subagent_instance = str(self._subagent_instance_counter[agent_id])

        # (Instance counter was computed earlier in step 2b.)
        subagent_conv_store = None
        if self._conversation_store is not None:
            from framework.storage.conversation_store import FileConversationStore

@@ -154,6 +154,7 @@ class GraphExecutor:
        iteration_metadata_provider: Callable | None = None,
        skills_catalog_prompt: str = "",
        protocols_prompt: str = "",
        skill_dirs: list[str] | None = None,
    ):
        """
        Initialize the executor.

@@ -181,6 +182,7 @@ class GraphExecutor:
                system prompt (for phase switching)
            skills_catalog_prompt: Available skills catalog for system prompt
            protocols_prompt: Default skill operational protocols for system prompt
            skill_dirs: Skill base directories for Tier 3 resource access
        """
        self.runtime = runtime
        self.llm = llm
@@ -204,6 +206,7 @@ class GraphExecutor:
        self.iteration_metadata_provider = iteration_metadata_provider
        self.skills_catalog_prompt = skills_catalog_prompt
        self.protocols_prompt = protocols_prompt
        self.skill_dirs: list[str] = skill_dirs or []

        if protocols_prompt:
            self.logger.info(
@@ -1845,6 +1848,9 @@ class GraphExecutor:

        existing_underscore = [k for k in memory._data if k.startswith("_")]
        extra_keys = set(_skill_keys) | set(existing_underscore)
        # Only inject into read_keys when it was already non-empty — an empty
        # read_keys means "allow all reads" and injecting skill keys would
        # inadvertently restrict reads to skill keys only.
        for k in extra_keys:
            if read_keys and k not in read_keys:
                read_keys.append(k)
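# Standalone sketch of the read_keys rule above (illustrative, not part of
# this change): an empty list means "allow all reads", so extra keys are
# appended only when the list already restricts reads.
def _inject_extra_keys_demo(read_keys: list[str], extra_keys: set[str]) -> list[str]:
    for k in sorted(extra_keys):
        if read_keys and k not in read_keys:
            read_keys.append(k)
    return read_keys

assert _inject_extra_keys_demo([], {"_skill_a"}) == []  # allow-all stays allow-all
assert _inject_extra_keys_demo(["docs"], {"_skill_a"}) == ["docs", "_skill_a"]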
@@ -1899,6 +1905,7 @@ class GraphExecutor:
            iteration_metadata_provider=self.iteration_metadata_provider,
            skills_catalog_prompt=self.skills_catalog_prompt,
            protocols_prompt=self.protocols_prompt,
            skill_dirs=self.skill_dirs,
        )

VALID_NODE_TYPES = {

@@ -167,14 +167,6 @@ class Goal(BaseModel):

        return met_weight >= total_weight * 0.9  # 90% threshold

    def check_constraint(self, constraint_id: str, value: Any) -> bool:
        """Check if a specific constraint is satisfied."""
        for c in self.constraints:
            if c.id == constraint_id:
                # This would be expanded with actual evaluation logic
                return True
        return True

    def to_prompt_context(self) -> str:
        """Generate context string for LLM prompts.


@@ -568,6 +568,7 @@ class NodeContext:
    # Skill system prompts — injected by the skill discovery pipeline
    skills_catalog_prompt: str = ""  # Available skills XML catalog
    protocols_prompt: str = ""  # Default skill operational protocols
    skill_dirs: list[str] = field(default_factory=list)  # Skill base dirs for resource access

    # Per-iteration metadata provider — when set, EventLoopNode merges
    # the returned dict into node_loop_iteration event data. Used by

@@ -152,6 +152,8 @@ def compose_system_prompt(
    accounts_prompt: str | None = None,
    skills_catalog_prompt: str | None = None,
    protocols_prompt: str | None = None,
    execution_preamble: str | None = None,
    node_type_preamble: str | None = None,
) -> str:
    """Compose the multi-layer system prompt.

@@ -162,6 +164,10 @@ def compose_system_prompt(
        accounts_prompt: Connected accounts block (sits between identity and narrative).
        skills_catalog_prompt: Available skills catalog XML (Agent Skills standard).
        protocols_prompt: Default skill operational protocols section.
        execution_preamble: EXECUTION_SCOPE_PREAMBLE for worker nodes
            (prepended before focus so the LLM knows its pipeline scope).
        node_type_preamble: Node-type-specific preamble, e.g. GCU browser
            best-practices prompt (prepended before focus).

    Returns:
        Composed system prompt with all layers present, plus current datetime.
@@ -188,6 +194,15 @@
    if narrative:
        parts.append(f"\n--- Context (what has happened so far) ---\n{narrative}")

    # Execution scope preamble (worker nodes — tells the LLM it is one
    # step in a multi-node pipeline and should not overreach)
    if execution_preamble:
        parts.append(f"\n{execution_preamble}")

    # Node-type preamble (e.g. GCU browser best-practices)
    if node_type_preamble:
        parts.append(f"\n{node_type_preamble}")

    # Layer 3: Focus (current phase directive)
    if focus_prompt:
        parts.append(f"\n--- Current Focus ---\n{focus_prompt}")
@@ -0,0 +1,706 @@
"""Antigravity (Google internal Cloud Code Assist) LLM provider.

Antigravity is Google's unified gateway API that routes requests to Gemini,
Claude, and GPT-OSS models through a single Gemini-style interface. It is
NOT the public ``generativelanguage.googleapis.com`` API.

Authentication uses Google OAuth2. Token refresh is done directly with the
OAuth client secret — no local proxy required.

Credential sources (checked in order):
1. ``~/.hive/antigravity-accounts.json`` (native OAuth implementation)
2. Antigravity IDE SQLite state DB (macOS / Linux)
"""

from __future__ import annotations

import json
import logging
import re
import time
import uuid
from collections.abc import AsyncIterator, Callable, Iterator
from pathlib import Path
from typing import Any

from framework.llm.provider import LLMProvider, LLMResponse, Tool
from framework.llm.stream_events import (
    FinishEvent,
    StreamErrorEvent,
    StreamEvent,
    TextDeltaEvent,
    TextEndEvent,
    ToolCallEvent,
)

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

_TOKEN_URL = "https://oauth2.googleapis.com/token"

# Fallback order: daily sandbox → autopush sandbox → production
_ENDPOINTS = [
    "https://daily-cloudcode-pa.sandbox.googleapis.com",
    "https://autopush-cloudcode-pa.sandbox.googleapis.com",
    "https://cloudcode-pa.googleapis.com",
]
_DEFAULT_PROJECT_ID = "rising-fact-p41fc"
_TOKEN_REFRESH_BUFFER_SECS = 60

# Credentials file in ~/.hive/ (native implementation)
_ACCOUNTS_FILE = Path.home() / ".hive" / "antigravity-accounts.json"
_IDE_STATE_DB_MAC = (
    Path.home()
    / "Library"
    / "Application Support"
    / "Antigravity"
    / "User"
    / "globalStorage"
    / "state.vscdb"
)
_IDE_STATE_DB_LINUX = (
    Path.home() / ".config" / "Antigravity" / "User" / "globalStorage" / "state.vscdb"
)
_IDE_STATE_DB_KEY = "antigravityUnifiedStateSync.oauthToken"

_BASE_HEADERS: dict[str, str] = {
    # Mimic the Antigravity Electron app so the API accepts the request.
    "User-Agent": (
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 "
        "(KHTML, like Gecko) Antigravity/1.18.3 Chrome/138.0.7204.235 "
        "Electron/37.3.1 Safari/537.36"
    ),
    "X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
    "Client-Metadata": '{"ideType":"ANTIGRAVITY","platform":"MACOS","pluginType":"GEMINI"}',
}


# ---------------------------------------------------------------------------
# Credential loading helpers
# ---------------------------------------------------------------------------


def _load_from_json_file() -> tuple[str | None, str | None, str, float]:
    """Read credentials from JSON accounts file.

    Reads from ~/.hive/antigravity-accounts.json.

    Returns ``(access_token | None, refresh_token | None, project_id, expires_at)``.
    ``expires_at`` is a Unix timestamp (seconds); 0.0 means unknown.
    """
    if not _ACCOUNTS_FILE.exists():
        return None, None, _DEFAULT_PROJECT_ID, 0.0
    try:
        with open(_ACCOUNTS_FILE, encoding="utf-8") as fh:
            data = json.load(fh)
    except (OSError, json.JSONDecodeError) as exc:
        logger.debug("Failed to read Antigravity accounts file: %s", exc)
        return None, None, _DEFAULT_PROJECT_ID, 0.0

    accounts = data.get("accounts", [])
    if not accounts:
        return None, None, _DEFAULT_PROJECT_ID, 0.0

    account = next((a for a in accounts if a.get("enabled", True) is not False), accounts[0])
    schema_version = data.get("schemaVersion", 1)

    if schema_version >= 4:
        # V4 schema: refresh = "refreshToken|projectId[|managedProjectId]"
        refresh_str = account.get("refresh", "")
        parts = refresh_str.split("|") if refresh_str else []
        refresh_token: str | None = parts[0] if parts else None
        project_id = parts[1] if len(parts) >= 2 and parts[1] else _DEFAULT_PROJECT_ID

        access_token: str | None = account.get("access")
        expires_ms: int = account.get("expires", 0)
        expires_at = float(expires_ms) / 1000.0 if expires_ms else 0.0

        # Treat near-expiry tokens as absent so _ensure_token() triggers a refresh.
        if access_token and expires_at and time.time() >= expires_at - _TOKEN_REFRESH_BUFFER_SECS:
            access_token = None
            expires_at = 0.0

        return access_token, refresh_token, project_id, expires_at
    else:
        # V1–V3 schema: plain accessToken / refreshToken fields
        access_token = account.get("accessToken")
        refresh_token = account.get("refreshToken")
        # Estimate expiry from last_refresh + 1 h
        last_refresh_str: str | None = data.get("last_refresh")
        expires_at = 0.0
        if last_refresh_str:
            try:
                from datetime import datetime  # noqa: PLC0415

                ts = datetime.fromisoformat(last_refresh_str.replace("Z", "+00:00")).timestamp()
                expires_at = ts + 3600.0
                if time.time() >= expires_at - _TOKEN_REFRESH_BUFFER_SECS:
                    access_token = None
            except (ValueError, TypeError):
                pass
        return access_token, refresh_token, _DEFAULT_PROJECT_ID, expires_at
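# Illustrative sketch (not part of this change): how the V4 "refresh" field
# documented above decomposes. The token value is fabricated.
_demo_refresh = "1//0example-token|my-project-id"
_demo_parts = _demo_refresh.split("|")
assert _demo_parts[0] == "1//0example-token"  # refresh_token
assert _demo_parts[1] == "my-project-id"      # project_id (else _DEFAULT_PROJECT_ID)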
def _load_from_ide_db() -> tuple[str | None, str | None, float]:
    """Extract ``(access_token, refresh_token, expires_at)`` from the IDE SQLite DB."""
    import base64  # noqa: PLC0415
    import sqlite3  # noqa: PLC0415

    for db_path in (_IDE_STATE_DB_MAC, _IDE_STATE_DB_LINUX):
        if not db_path.exists():
            continue
        try:
            con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
            try:
                row = con.execute(
                    "SELECT value FROM ItemTable WHERE key = ?",
                    (_IDE_STATE_DB_KEY,),
                ).fetchone()
            finally:
                con.close()
            if not row:
                continue

            blob = base64.b64decode(row[0])
            candidates = re.findall(rb"[A-Za-z0-9+/=_\-]{40,}", blob)
            access_token: str | None = None
            refresh_token: str | None = None
            for candidate in candidates:
                try:
                    padded = candidate + b"=" * (-len(candidate) % 4)
                    inner = base64.urlsafe_b64decode(padded)
                except Exception:
                    continue
                if not access_token:
                    m = re.search(rb"ya29\.[A-Za-z0-9_\-\.]+", inner)
                    if m:
                        access_token = m.group(0).decode("ascii")
                if not refresh_token:
                    m = re.search(rb"1//[A-Za-z0-9_\-\.]+", inner)
                    if m:
                        refresh_token = m.group(0).decode("ascii")
                if access_token and refresh_token:
                    break

            if access_token:
                # Estimate expiry from DB mtime (IDE refreshes while running)
                mtime = db_path.stat().st_mtime
                expires_at = mtime + 3600.0
                return access_token, refresh_token, expires_at
        except Exception as exc:
            logger.debug("Failed to read Antigravity IDE state DB: %s", exc)
            continue
    return None, None, 0.0


def _do_token_refresh(refresh_token: str) -> tuple[str, float] | None:
    """POST to Google OAuth endpoint and return ``(new_access_token, expires_at)``.

    The client secret is sourced via ``get_antigravity_client_secret()`` (env var,
    config file, or npm package fallback). When unavailable the refresh is attempted
    without it — Google will reject it for web-app clients, but the npm fallback in
    ``get_antigravity_client_secret()`` should ensure the secret is found at runtime.

    Returns None when the HTTP request fails.
    """
    from framework.config import get_antigravity_client_secret  # noqa: PLC0415

    client_secret = get_antigravity_client_secret()
    if not client_secret:
        logger.debug(
            "Antigravity client secret not configured — attempting refresh without it. "
            "Set ANTIGRAVITY_CLIENT_SECRET or run quickstart to configure."
        )

    import urllib.error  # noqa: PLC0415
    import urllib.parse  # noqa: PLC0415
    import urllib.request  # noqa: PLC0415

    from framework.config import get_antigravity_client_id  # noqa: PLC0415

    params: dict[str, str] = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": get_antigravity_client_id(),
    }
    if client_secret:
        params["client_secret"] = client_secret
    body = urllib.parse.urlencode(params).encode("utf-8")

    req = urllib.request.Request(
        _TOKEN_URL,
        data=body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=15) as resp:  # noqa: S310
            payload = json.loads(resp.read())
        access_token: str = payload["access_token"]
        expires_in: int = payload.get("expires_in", 3600)
        logger.debug("Antigravity token refreshed successfully")
        return access_token, time.time() + expires_in
    except Exception as exc:
        logger.debug("Antigravity token refresh failed: %s", exc)
        return None


# ---------------------------------------------------------------------------
# Message conversion helpers
# ---------------------------------------------------------------------------


def _clean_tool_name(name: str) -> str:
    """Sanitize a tool name for the Antigravity function-calling schema."""
    name = re.sub(r"[/\s]", "_", name)
    if name and not (name[0].isalpha() or name[0] == "_"):
        name = "_" + name
    return name[:64]
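# Quick checks of _clean_tool_name (illustrative, not part of this change):
assert _clean_tool_name("mcp/server tool") == "mcp_server_tool"  # "/" and spaces → "_"
assert _clean_tool_name("9lives") == "_9lives"                   # leading digit gets "_" prefix
assert len(_clean_tool_name("x" * 100)) == 64                    # truncated to 64 chars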
def _to_gemini_contents(
    messages: list[dict[str, Any]],
    thought_sigs: dict[str, str] | None = None,
) -> list[dict[str, Any]]:
    """Convert OpenAI-format messages to Gemini-style ``contents`` array."""
    # Pre-build a map tool_call_id → function_name from assistant messages.
    # Tool result messages (role="tool") only carry tool_call_id, not the name,
    # but Gemini requires functionResponse.name to match the functionCall.name.
    tc_id_to_name: dict[str, str] = {}
    for msg in messages:
        if msg.get("role") == "assistant":
            for tc in msg.get("tool_calls") or []:
                tc_id = tc.get("id")
                fn_name = tc.get("function", {}).get("name", "")
                if tc_id and fn_name:
                    tc_id_to_name[tc_id] = fn_name

    contents: list[dict[str, Any]] = []
    # Consecutive tool-result messages must be batched into one user turn.
    pending_tool_parts: list[dict[str, Any]] = []

    def _flush_tool_results() -> None:
        if pending_tool_parts:
            contents.append({"role": "user", "parts": list(pending_tool_parts)})
            pending_tool_parts.clear()

    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content")

        if role == "system":
            continue  # Handled via systemInstruction, not in contents.

        if role == "tool":
            # OpenAI tool result → Gemini functionResponse part.
            result_str = content if isinstance(content, str) else str(content or "")
            tc_id = msg.get("tool_call_id", "")
            # Look up function name from the pre-built map; fall back to msg.name.
            fn_name = tc_id_to_name.get(tc_id) or msg.get("name", "")
            pending_tool_parts.append(
                {
                    "functionResponse": {
                        "name": fn_name,
                        "id": tc_id,
                        "response": {"content": result_str},
                    }
                }
            )
            continue

        _flush_tool_results()

        gemini_role = "model" if role == "assistant" else "user"
        parts: list[dict[str, Any]] = []

        if isinstance(content, str) and content:
            parts.append({"text": content})
        elif isinstance(content, list):
            for block in content:
                if not isinstance(block, dict):
                    continue
                if block.get("type") == "text":
                    text = block.get("text", "")
                    if text:
                        parts.append({"text": text})
                # Other block types (image_url etc.) skipped.

        # Assistant messages may carry OpenAI-style tool_calls.
        for tc in msg.get("tool_calls") or []:
            fn = tc.get("function", {})
            try:
                args = json.loads(fn.get("arguments", "{}") or "{}")
            except (json.JSONDecodeError, TypeError):
                args = {}
            tc_id = tc.get("id", str(uuid.uuid4()))
            fc_part: dict[str, Any] = {
                "functionCall": {
                    "name": fn.get("name", ""),
                    "args": args,
                    "id": tc_id,
                }
            }
            if thought_sigs:
                sig = thought_sigs.get(tc_id, "")
                if sig:
                    fc_part["thoughtSignature"] = sig  # part-level, not inside functionCall
            parts.append(fc_part)

        if parts:
            contents.append({"role": gemini_role, "parts": parts})

    _flush_tool_results()

    # Gemini requires the first turn to be a user turn. Drop any leading
    # model messages so the API doesn't reject with a 400.
    while contents and contents[0].get("role") == "model":
        contents.pop(0)

    return contents
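# Illustrative round trip (not part of this change): one OpenAI-style tool
# call as _to_gemini_contents converts it. Ids and names are made up.
_demo_msgs = [
    {"role": "user", "content": "What is 2+2?"},
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {"id": "tc_1", "function": {"name": "calc", "arguments": '{"expr": "2+2"}'}}
        ],
    },
    {"role": "tool", "tool_call_id": "tc_1", "content": "4"},
]
assert _to_gemini_contents(_demo_msgs) == [
    {"role": "user", "parts": [{"text": "What is 2+2?"}]},
    {"role": "model", "parts": [{"functionCall": {"name": "calc", "args": {"expr": "2+2"}, "id": "tc_1"}}]},
    {"role": "user", "parts": [{"functionResponse": {"name": "calc", "id": "tc_1", "response": {"content": "4"}}}]},
]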
# ---------------------------------------------------------------------------
# Response parsing helpers
# ---------------------------------------------------------------------------


def _map_finish_reason(reason: str) -> str:
    return {"STOP": "stop", "MAX_TOKENS": "max_tokens", "OTHER": "tool_use"}.get(
        (reason or "").upper(), "stop"
    )


def _parse_complete_response(raw: dict[str, Any], model: str) -> LLMResponse:
    """Parse a non-streaming Antigravity response dict → LLMResponse."""
    payload: dict[str, Any] = raw.get("response", raw)
    candidates: list[dict[str, Any]] = payload.get("candidates", [])
    usage: dict[str, Any] = payload.get("usageMetadata", {})

    text_parts: list[str] = []
    if candidates:
        for part in candidates[0].get("content", {}).get("parts", []):
            if "text" in part and not part.get("thought"):
                text_parts.append(part["text"])

    return LLMResponse(
        content="".join(text_parts),
        model=payload.get("modelVersion", model),
        input_tokens=usage.get("promptTokenCount", 0),
        output_tokens=usage.get("candidatesTokenCount", 0),
        stop_reason=_map_finish_reason(candidates[0].get("finishReason", "") if candidates else ""),
        raw_response=raw,
    )


def _parse_sse_stream(
    response: Any,
    model: str,
    on_thought_signature: Callable[[str, str], None] | None = None,
) -> Iterator[StreamEvent]:
    """Parse Antigravity SSE response line-by-line → StreamEvents.

    Each SSE line looks like::

        data: {"response": {"candidates": [...], "usageMetadata": {...}}, "traceId": "..."}
    """
    accumulated = ""
    input_tokens = 0
    output_tokens = 0
    finish_reason = ""

    for raw_line in response:
        line: str = raw_line.decode("utf-8", errors="replace").rstrip("\r\n")
        if not line.startswith("data:"):
            continue
        data_str = line[5:].strip()
        if not data_str or data_str == "[DONE]":
            continue
        try:
            data: dict[str, Any] = json.loads(data_str)
        except json.JSONDecodeError:
            continue

        # The outer envelope is {"response": {...}, "traceId": "..."}.
        payload: dict[str, Any] = data.get("response", data)

        usage = payload.get("usageMetadata", {})
        if usage:
            input_tokens = usage.get("promptTokenCount", input_tokens)
            output_tokens = usage.get("candidatesTokenCount", output_tokens)

        for candidate in payload.get("candidates", []):
            fr = candidate.get("finishReason", "")
            if fr:
                finish_reason = fr

            for part in candidate.get("content", {}).get("parts", []):
                if "text" in part and not part.get("thought"):
                    delta: str = part["text"]
                    accumulated += delta
                    yield TextDeltaEvent(content=delta, snapshot=accumulated)
                elif "functionCall" in part:
                    fc: dict[str, Any] = part["functionCall"]
                    tool_use_id = fc.get("id") or str(uuid.uuid4())
                    thought_sig = part.get("thoughtSignature", "")  # sibling of functionCall
                    if thought_sig and on_thought_signature:
                        on_thought_signature(tool_use_id, thought_sig)
                    args = fc.get("args", {})
                    if isinstance(args, str):
                        try:
                            args = json.loads(args)
                        except json.JSONDecodeError:
                            args = {}
                    yield ToolCallEvent(
                        tool_use_id=tool_use_id,
                        tool_name=fc.get("name", ""),
                        tool_input=args,
                    )

    if accumulated:
        yield TextEndEvent(full_text=accumulated)
    yield FinishEvent(
        stop_reason=_map_finish_reason(finish_reason),
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        model=model,
    )


# ---------------------------------------------------------------------------
# Provider
# ---------------------------------------------------------------------------


class AntigravityProvider(LLMProvider):
    """LLM provider for Google's internal Antigravity Code Assist gateway.

    No local proxy required. Handles OAuth token refresh, Gemini-format
    request/response conversion, and SSE streaming directly.
    """

    def __init__(self, model: str = "gemini-3-flash") -> None:
        # Strip any provider prefix ("openai/gemini-3-flash" → "gemini-3-flash").
        if "/" in model:
            model = model.split("/", 1)[1]
        self.model = model

        self._access_token: str | None = None
        self._refresh_token: str | None = None
        self._project_id: str = _DEFAULT_PROJECT_ID
        self._token_expires_at: float = 0.0
        self._thought_sigs: dict[str, str] = {}  # tool_use_id → thoughtSignature

        self._init_credentials()

    # --- Credential management -------------------------------------------- #

    def _init_credentials(self) -> None:
        """Load credentials from the best available source."""
        access, refresh, project_id, expires_at = _load_from_json_file()
        if refresh:
            self._refresh_token = refresh
            self._project_id = project_id
            self._access_token = access
            self._token_expires_at = expires_at
            return

        # Fall back to IDE state DB.
        access, refresh, expires_at = _load_from_ide_db()
        if access:
            self._access_token = access
            self._refresh_token = refresh
            self._token_expires_at = expires_at

    def has_credentials(self) -> bool:
        """Return True if any credential is available."""
        return bool(self._access_token or self._refresh_token)

    def _ensure_token(self) -> str:
        """Return a valid access token, refreshing via OAuth if needed."""
        if (
            self._access_token
            and self._token_expires_at
            and time.time() < self._token_expires_at - _TOKEN_REFRESH_BUFFER_SECS
        ):
            return self._access_token

        if self._refresh_token:
            result = _do_token_refresh(self._refresh_token)
            if result:
                self._access_token, self._token_expires_at = result
                return self._access_token

        if self._access_token:
            logger.warning("Using potentially stale Antigravity access token")
            return self._access_token

        raise RuntimeError(
            "No valid Antigravity credentials. "
            "Run: uv run python core/antigravity_auth.py auth account add"
        )

    # --- Request building -------------------------------------------------- #

    def _build_body(
        self,
        messages: list[dict[str, Any]],
        system: str,
        tools: list[Tool] | None,
        max_tokens: int,
    ) -> dict[str, Any]:
        contents = _to_gemini_contents(messages, self._thought_sigs)
        inner: dict[str, Any] = {
            "contents": contents,
            "generationConfig": {"maxOutputTokens": max_tokens},
        }
        if system:
            inner["systemInstruction"] = {"parts": [{"text": system}]}
        if tools:
            inner["tools"] = [
                {
                    "functionDeclarations": [
                        {
                            "name": _clean_tool_name(t.name),
                            "description": t.description,
                            "parameters": t.parameters
                            or {
                                "type": "object",
                                "properties": {},
                            },
                        }
                        for t in tools
                    ]
                }
            ]
        return {
            "project": self._project_id,
            "model": self.model,
            "request": inner,
            "requestType": "agent",
            "userAgent": "antigravity",
            "requestId": f"agent-{uuid.uuid4()}",
        }

    # --- HTTP transport ---------------------------------------------------- #

    def _post(self, body: dict[str, Any], *, streaming: bool) -> Any:
        """POST to the Antigravity endpoint, falling back through the endpoint list."""
        import urllib.error  # noqa: PLC0415
        import urllib.request  # noqa: PLC0415

        token = self._ensure_token()
        body_bytes = json.dumps(body).encode("utf-8")
        path = (
            "/v1internal:streamGenerateContent?alt=sse"
            if streaming
            else "/v1internal:generateContent"
        )
        headers = {
            **_BASE_HEADERS,
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        if streaming:
            headers["Accept"] = "text/event-stream"

        last_exc: Exception | None = None
        for base_url in _ENDPOINTS:
            url = f"{base_url}{path}"
            req = urllib.request.Request(url, data=body_bytes, headers=headers, method="POST")
            try:
                return urllib.request.urlopen(req, timeout=120)  # noqa: S310
            except urllib.error.HTTPError as exc:
                if exc.code in (401, 403) and self._refresh_token:
                    # Token rejected — refresh once and retry this endpoint.
                    result = _do_token_refresh(self._refresh_token)
                    if result:
                        self._access_token, self._token_expires_at = result
                        headers["Authorization"] = f"Bearer {self._access_token}"
                        req2 = urllib.request.Request(
                            url, data=body_bytes, headers=headers, method="POST"
                        )
                        try:
                            return urllib.request.urlopen(req2, timeout=120)  # noqa: S310
                        except urllib.error.HTTPError as exc2:
                            last_exc = exc2
                            continue
                    last_exc = exc
                    continue
                elif exc.code >= 500:
                    last_exc = exc
                    continue
                # Include the API response body in the exception for easier debugging.
                try:
                    err_body = exc.read().decode("utf-8", errors="replace")
                except Exception:
                    err_body = "(unreadable)"
                raise RuntimeError(f"Antigravity HTTP {exc.code} from {url}: {err_body}") from exc
            except (urllib.error.URLError, OSError) as exc:
                last_exc = exc
                continue

        raise RuntimeError(
            f"All Antigravity endpoints failed. Last error: {last_exc}"
        ) from last_exc

    # --- LLMProvider interface --------------------------------------------- #

    def complete(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 1024,
        response_format: dict[str, Any] | None = None,
        json_mode: bool = False,
        max_retries: int | None = None,
    ) -> LLMResponse:
        if json_mode:
            suffix = "\n\nPlease respond with a valid JSON object."
            system = (system + suffix) if system else suffix.strip()

        body = self._build_body(messages, system, tools, max_tokens)
        resp = self._post(body, streaming=False)
        return _parse_complete_response(json.loads(resp.read()), self.model)

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator[StreamEvent]:
        import asyncio  # noqa: PLC0415
        import concurrent.futures  # noqa: PLC0415

        loop = asyncio.get_running_loop()
        queue: asyncio.Queue[StreamEvent | None] = asyncio.Queue()

        def _blocking_work() -> None:
            try:
                body = self._build_body(messages, system, tools, max_tokens)
                http_resp = self._post(body, streaming=True)
                for event in _parse_sse_stream(
                    http_resp, self.model, self._thought_sigs.__setitem__
                ):
                    loop.call_soon_threadsafe(queue.put_nowait, event)
            except Exception as exc:
                logger.error("Antigravity stream error: %s", exc)
                loop.call_soon_threadsafe(queue.put_nowait, StreamErrorEvent(error=str(exc)))
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, None)  # sentinel

        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        fut = loop.run_in_executor(executor, _blocking_work)
        try:
            while True:
                event = await queue.get()
                if event is None:
                    break
                yield event
        finally:
            await fut
            executor.shutdown(wait=False)
@@ -9,8 +9,10 @@ See: https://docs.litellm.ai/docs/providers

import ast
import asyncio
import hashlib
import json
import logging
import os
import re
import time
from collections.abc import AsyncIterator
@@ -46,7 +48,10 @@ def _patch_litellm_anthropic_oauth() -> None:
    """
    try:
        from litellm.llms.anthropic.common_utils import AnthropicModelInfo
        from litellm.types.llms.anthropic import ANTHROPIC_OAUTH_TOKEN_PREFIX
        from litellm.types.llms.anthropic import (
            ANTHROPIC_OAUTH_BETA_HEADER,
            ANTHROPIC_OAUTH_TOKEN_PREFIX,
        )
    except ImportError:
        logger.warning(
            "Could not apply litellm Anthropic OAuth patch — litellm internals may have "
@@ -71,9 +76,27 @@ def _patch_litellm_anthropic_oauth() -> None:
            api_key=api_key,
            api_base=api_base,
        )
        # Check both authorization header and x-api-key for OAuth tokens.
        # litellm's optionally_handle_anthropic_oauth only checks headers["authorization"],
        # but hive passes OAuth tokens via api_key — so litellm puts them into x-api-key.
        # Anthropic rejects OAuth tokens in x-api-key; they must go in Authorization: Bearer.
        auth = result.get("authorization", "")
        if auth.startswith(f"Bearer {ANTHROPIC_OAUTH_TOKEN_PREFIX}"):
        x_api_key = result.get("x-api-key", "")
        oauth_prefix = f"Bearer {ANTHROPIC_OAUTH_TOKEN_PREFIX}"
        auth_is_oauth = auth.startswith(oauth_prefix)
        key_is_oauth = x_api_key.startswith(ANTHROPIC_OAUTH_TOKEN_PREFIX)
        if auth_is_oauth or key_is_oauth:
            token = x_api_key if key_is_oauth else auth.removeprefix("Bearer ").strip()
            result.pop("x-api-key", None)
            result["authorization"] = f"Bearer {token}"
            # Merge the OAuth beta header with any existing beta headers.
            existing_beta = result.get("anthropic-beta", "")
            beta_parts = (
                [b.strip() for b in existing_beta.split(",") if b.strip()] if existing_beta else []
            )
            if ANTHROPIC_OAUTH_BETA_HEADER not in beta_parts:
                beta_parts.append(ANTHROPIC_OAUTH_BETA_HEADER)
            result["anthropic-beta"] = ",".join(beta_parts)
        return result

    AnthropicModelInfo.validate_environment = _patched_validate_environment
@@ -132,6 +155,9 @@ def _patch_litellm_metadata_nonetype() -> None:
if litellm is not None:
    _patch_litellm_anthropic_oauth()
    _patch_litellm_metadata_nonetype()
    # Let litellm silently drop params unsupported by the target provider
    # (e.g. stream_options for Anthropic) instead of forwarding them verbatim.
    litellm.drop_params = True

RATE_LIMIT_MAX_RETRIES = 10
RATE_LIMIT_BACKOFF_BASE = 2  # seconds
@@ -162,6 +188,53 @@ def _model_supports_cache_control(model: str) -> bool:
# enforces a coding-agent whitelist that blocks unknown User-Agents.
KIMI_API_BASE = "https://api.kimi.com/coding"

# Claude Code OAuth subscription: the Anthropic API requires a specific
# User-Agent and a billing integrity header for OAuth-authenticated requests.
CLAUDE_CODE_VERSION = "2.1.76"
CLAUDE_CODE_USER_AGENT = f"claude-code/{CLAUDE_CODE_VERSION}"
_CLAUDE_CODE_BILLING_SALT = "59cf53e54c78"


def _sample_js_code_unit(text: str, idx: int) -> str:
    """Return the character at UTF-16 code unit index *idx*, matching JS semantics."""
    encoded = text.encode("utf-16-le")
    unit_offset = idx * 2
    if unit_offset + 2 > len(encoded):
        return "0"
    code_unit = int.from_bytes(encoded[unit_offset : unit_offset + 2], "little")
    return chr(code_unit)
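# Why UTF-16 indexing matters (illustrative, not part of this change): in JS,
# "😀a"[2] is "a" because the emoji occupies two UTF-16 code units.
assert _sample_js_code_unit("😀a", 2) == "a"   # surrogate pair shifts the index
assert _sample_js_code_unit("abc", 10) == "0"  # out-of-range sentinel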
def _claude_code_billing_header(messages: list[dict[str, Any]]) -> str:
    """Build the billing integrity system block required by Anthropic's OAuth path."""
    # Find the first user message text
    first_text = ""
    for msg in messages:
        if msg.get("role") != "user":
            continue
        content = msg.get("content")
        if isinstance(content, str):
            first_text = content
            break
        if isinstance(content, list):
            for block in content:
                if isinstance(block, dict) and block.get("type") == "text" and block.get("text"):
                    first_text = block["text"]
                    break
            if first_text:
                break

    sampled = "".join(_sample_js_code_unit(first_text, i) for i in (4, 7, 20))
    version_hash = hashlib.sha256(
        f"{_CLAUDE_CODE_BILLING_SALT}{sampled}{CLAUDE_CODE_VERSION}".encode()
    ).hexdigest()
    entrypoint = os.environ.get("CLAUDE_CODE_ENTRYPOINT", "").strip() or "cli"
    return (
        f"x-anthropic-billing-header: cc_version={CLAUDE_CODE_VERSION}.{version_hash[:3]}; "
        f"cc_entrypoint={entrypoint}; cch=00000;"
    )


# Empty-stream retries use a short fixed delay, not the rate-limit backoff.
# Conversation-structure issues are deterministic — long waits don't help.
EMPTY_STREAM_MAX_RETRIES = 3
@@ -441,11 +514,19 @@ class LiteLLMProvider(LLMProvider):
        self.api_key = api_key
        self.api_base = api_base or self._default_api_base_for_model(_original_model)
        self.extra_kwargs = kwargs
        # Detect Claude Code OAuth subscription by checking the api_key prefix.
        self._claude_code_oauth = bool(api_key and api_key.startswith("sk-ant-oat"))
        if self._claude_code_oauth:
            # Anthropic requires a specific User-Agent for OAuth requests.
            eh = self.extra_kwargs.setdefault("extra_headers", {})
            eh.setdefault("user-agent", CLAUDE_CODE_USER_AGENT)
        # The Codex ChatGPT backend (chatgpt.com/backend-api/codex) rejects
        # several standard OpenAI params: max_output_tokens, stream_options.
        self._codex_backend = bool(
            self.api_base and "chatgpt.com/backend-api/codex" in self.api_base
        )
        # Antigravity routes through a local OpenAI-compatible proxy — no patches needed.
        self._antigravity = bool(self.api_base and "localhost:8069" in self.api_base)

        if litellm is None:
            raise ImportError(
@@ -808,6 +889,9 @@ class LiteLLMProvider(LLMProvider):
            return await self._collect_stream_to_response(stream_iter)

        full_messages: list[dict[str, Any]] = []
        if self._claude_code_oauth:
            billing = _claude_code_billing_header(messages)
            full_messages.append({"role": "system", "content": billing})
        if system:
            sys_msg: dict[str, Any] = {"role": "system", "content": system}
            if _model_supports_cache_control(self.model):
@@ -869,6 +953,11 @@ class LiteLLMProvider(LLMProvider):
            },
        }

    def _is_anthropic_model(self) -> bool:
        """Return True when the configured model targets Anthropic."""
        model = (self.model or "").lower()
        return model.startswith("anthropic/") or model.startswith("claude-")

    def _is_minimax_model(self) -> bool:
        """Return True when the configured model targets MiniMax."""
        model = (self.model or "").lower()
@@ -1479,6 +1568,9 @@ class LiteLLMProvider(LLMProvider):
            return

        full_messages: list[dict[str, Any]] = []
        if self._claude_code_oauth:
            billing = _claude_code_billing_header(messages)
            full_messages.append({"role": "system", "content": billing})
        if system:
            sys_msg: dict[str, Any] = {"role": "system", "content": system}
            if _model_supports_cache_control(self.model):
@@ -1516,9 +1608,12 @@ class LiteLLMProvider(LLMProvider):
            "messages": full_messages,
            "max_tokens": max_tokens,
            "stream": True,
            "stream_options": {"include_usage": True},
            **self.extra_kwargs,
        }
        # stream_options is OpenAI-specific; Anthropic rejects it with 400.
        # Only include it for providers that support it.
        if not self._is_anthropic_model():
            kwargs["stream_options"] = {"include_usage": True}
        if self.api_key:
            kwargs["api_key"] = self.api_key
        if self.api_base:

@@ -45,6 +45,7 @@ class ToolResult:
    tool_use_id: str
    content: str
    is_error: bool = False
    is_skill_content: bool = False  # AS-10: marks activated skill body, protected from pruning


class LLMProvider(ABC):

@@ -208,7 +208,12 @@ def configure_logging(

    # Suppress noisy LiteLLM INFO logs (model/provider line + Provider List URL
    # printed on every single completion call). Warnings and errors still show.
    logging.getLogger("LiteLLM").setLevel(logging.WARNING)
    # Honour LITELLM_LOG env var so users can opt-in to debug output.
    _litellm_level = os.getenv("LITELLM_LOG", "").upper()
    if _litellm_level and hasattr(logging, _litellm_level):
        logging.getLogger("LiteLLM").setLevel(getattr(logging, _litellm_level))
    else:
        logging.getLogger("LiteLLM").setLevel(logging.WARNING)

    # When in JSON mode, configure known third-party loggers to use JSON formatter
    # This ensures libraries like LiteLLM, httpcore also output clean JSON

@@ -1,7 +1,7 @@
"""MCP Client for connecting to Model Context Protocol servers.

This module provides a client for connecting to MCP servers and invoking their tools.
Supports both STDIO and HTTP transports using the official MCP Python SDK.
Supports STDIO, HTTP, UNIX socket, and SSE transports using the official MCP Python SDK.
"""

import asyncio
@@ -22,7 +22,7 @@ class MCPServerConfig:
    """Configuration for an MCP server connection."""

    name: str
    transport: Literal["stdio", "http"]
    transport: Literal["stdio", "http", "unix", "sse"]

    # For STDIO transport
    command: str | None = None
@@ -33,6 +33,7 @@ class MCPServerConfig:
    # For HTTP transport
    url: str | None = None
    headers: dict[str, str] = field(default_factory=dict)
    socket_path: str | None = None

    # Optional metadata
    description: str = ""
@@ -52,7 +53,7 @@ class MCPClient:
    """
    Client for communicating with MCP servers.

    Supports both STDIO and HTTP transports using the official MCP SDK.
    Supports STDIO, HTTP, UNIX socket, and SSE transports using the official MCP SDK.
    Manages the connection lifecycle and provides methods to list and invoke tools.
    """

@@ -68,6 +69,7 @@ class MCPClient:
        self._read_stream = None
        self._write_stream = None
        self._stdio_context = None  # Context manager for stdio_client
        self._sse_context = None  # Context manager for sse_client
        self._errlog_handle = None  # Track errlog file handle for cleanup
        self._http_client: httpx.Client | None = None
        self._tools: dict[str, MCPTool] = {}
@@ -141,6 +143,10 @@ class MCPClient:
            self._connect_stdio()
        elif self.config.transport == "http":
            self._connect_http()
        elif self.config.transport == "unix":
            self._connect_unix()
        elif self.config.transport == "sse":
            self._connect_sse()
        else:
            raise ValueError(f"Unsupported transport: {self.config.transport}")

@@ -266,10 +272,94 @@ class MCPClient:
            logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}")
            # Continue anyway, server might not have health endpoint

    def _connect_unix(self) -> None:
        """Connect to MCP server via UNIX domain socket transport."""
        if not self.config.url:
            raise ValueError("url is required for UNIX transport")
        if not self.config.socket_path:
            raise ValueError("socket_path is required for UNIX transport")

        self._http_client = httpx.Client(
            base_url=self.config.url,
            headers=self.config.headers,
            timeout=30.0,
            transport=httpx.HTTPTransport(uds=self.config.socket_path),
        )

        try:
            response = self._http_client.get("/health")
            response.raise_for_status()
            logger.info(
                "Connected to MCP server '%s' via UNIX socket at %s",
                self.config.name,
                self.config.socket_path,
            )
        except Exception as e:
            logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}")
            # Continue anyway, server might not have health endpoint

    def _connect_sse(self) -> None:
        """Connect to MCP server via SSE transport using MCP SDK with persistent session."""
        if not self.config.url:
            raise ValueError("url is required for SSE transport")

        try:
            loop_started = threading.Event()
            connection_ready = threading.Event()
            connection_error = []

            def run_event_loop():
                """Run event loop in background thread."""
                self._loop = asyncio.new_event_loop()
                asyncio.set_event_loop(self._loop)
                loop_started.set()

                async def init_connection():
                    try:
                        from mcp import ClientSession
                        from mcp.client.sse import sse_client

                        self._sse_context = sse_client(
                            self.config.url,
                            headers=self.config.headers,
                            timeout=30.0,
                        )
                        (
                            self._read_stream,
                            self._write_stream,
                        ) = await self._sse_context.__aenter__()

                        self._session = ClientSession(self._read_stream, self._write_stream)
                        await self._session.__aenter__()
                        await self._session.initialize()

                        connection_ready.set()
                    except Exception as e:
                        connection_error.append(e)
                        connection_ready.set()

                self._loop.create_task(init_connection())
                self._loop.run_forever()

            self._loop_thread = threading.Thread(target=run_event_loop, daemon=True)
            self._loop_thread.start()

            loop_started.wait(timeout=5)
            if not loop_started.is_set():
                raise RuntimeError("Event loop failed to start")

            connection_ready.wait(timeout=10)
            if connection_error:
                raise connection_error[0]

            logger.info(f"Connected to MCP server '{self.config.name}' via SSE")
        except Exception as e:
            raise RuntimeError(f"Failed to connect to MCP server: {e}") from e

    def _discover_tools(self) -> None:
        """Discover available tools from the MCP server."""
        try:
            if self.config.transport == "stdio":
            if self.config.transport in {"stdio", "sse"}:
                tools_list = self._run_async(self._list_tools_stdio_async())
            else:
                tools_list = self._list_tools_http()
@@ -371,9 +461,37 @@ class MCPClient:
        if self.config.transport == "stdio":
            with self._stdio_call_lock:
                return self._run_async(self._call_tool_stdio_async(tool_name, arguments))
        elif self.config.transport == "sse":
            return self._call_tool_with_retry(
                lambda: self._run_async(self._call_tool_stdio_async(tool_name, arguments))
            )
        elif self.config.transport == "unix":
            return self._call_tool_with_retry(lambda: self._call_tool_http(tool_name, arguments))
        else:
            return self._call_tool_http(tool_name, arguments)

    def _call_tool_with_retry(self, call: Any) -> Any:
        """Retry transient MCP transport failures once after reconnecting."""
        if self.config.transport == "stdio":
            return call()

        if self.config.transport not in {"unix", "sse"}:
            return call()

        try:
            return call()
        except (httpx.ConnectError, httpx.ReadTimeout) as original_error:
            logger.warning(
                "Retrying MCP tool call after transport error from '%s': %s",
                self.config.name,
                original_error,
            )
            self._reconnect()
            try:
                return call()
            except (httpx.ConnectError, httpx.ReadTimeout) as retry_error:
                raise original_error from retry_error

    async def _call_tool_stdio_async(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Call tool via STDIO protocol using persistent session."""
        if not self._session:
@@ -433,18 +551,24 @@ class MCPClient:
        except Exception as e:
            raise RuntimeError(f"Failed to call tool via HTTP: {e}") from e

    def _reconnect(self) -> None:
        """Reconnect to the configured MCP server."""
        logger.info(f"Reconnecting to MCP server '{self.config.name}'...")
        self.disconnect()
        self.connect()

    _CLEANUP_TIMEOUT = 10
    _THREAD_JOIN_TIMEOUT = 12

    async def _cleanup_stdio_async(self) -> None:
        """Async cleanup for STDIO session and context managers.
        """Async cleanup for persistent MCP session and context managers.

        Cleanup order is critical:
        - The session must be closed BEFORE the stdio_context because the session
          depends on the streams provided by stdio_context.
        - This mirrors the initialization order in _connect_stdio(), where
          stdio_context is entered first (providing streams), then the session is
          created with those streams and entered.
        - The session must be closed BEFORE the transport context manager because the
          session depends on the streams provided by that context.
        - This mirrors the initialization order in _connect_stdio() / _connect_sse(),
          where the transport context is entered first (providing streams), then the
          session is created with those streams and entered.
        - Do not change this ordering without carefully considering these dependencies.
        """
        # First: close session (depends on stdio_context streams)
@@ -477,6 +601,16 @@ class MCPClient:
        finally:
            self._stdio_context = None

        try:
            if self._sse_context:
                await self._sse_context.__aexit__(None, None, None)
        except asyncio.CancelledError:
            logger.debug("SSE context cleanup was cancelled; proceeding with best-effort shutdown")
        except Exception as e:
            logger.warning(f"Error closing SSE context: {e}")
        finally:
            self._sse_context = None

        # Third: close errlog file handle if we opened one
        if self._errlog_handle is not None:
            try:
@@ -552,6 +686,7 @@ class MCPClient:
        # Setting None to None is safe and ensures clean state.
        self._session = None
        self._stdio_context = None
        self._sse_context = None
        self._read_stream = None
        self._write_stream = None
        self._loop = None

@@ -0,0 +1,255 @@
"""Shared MCP client connection management."""

import logging
import threading
from typing import Any

import httpx

from framework.runner.mcp_client import MCPClient, MCPServerConfig

logger = logging.getLogger(__name__)


class MCPConnectionManager:
    """Process-wide MCP client pool keyed by server name."""

    _instance = None
    _lock = threading.Lock()

    def __init__(self) -> None:
        self._pool: dict[str, MCPClient] = {}
        self._refcounts: dict[str, int] = {}
        self._configs: dict[str, MCPServerConfig] = {}
        self._pool_lock = threading.Lock()
        # Transition events keep callers from racing a connect/reconnect/disconnect.
        self._transitions: dict[str, threading.Event] = {}

    @classmethod
    def get_instance(cls) -> "MCPConnectionManager":
        """Return the process-level singleton instance."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    @staticmethod
    def _is_connected(client: MCPClient | None) -> bool:
        return bool(client and getattr(client, "_connected", False))

    def acquire(self, config: MCPServerConfig) -> MCPClient:
        """Get or create a shared connection and increment its refcount."""
        server_name = config.name

        while True:
            should_connect = False
            transition_event: threading.Event | None = None

            with self._pool_lock:
                client = self._pool.get(server_name)
                if self._is_connected(client) and server_name not in self._transitions:
                    new_refcount = self._refcounts.get(server_name, 0) + 1
                    self._refcounts[server_name] = new_refcount
                    self._configs[server_name] = config
                    logger.debug(
                        "Reusing pooled connection for MCP server '%s' (refcount=%d)",
                        server_name,
                        new_refcount,
                    )
                    return client

                transition_event = self._transitions.get(server_name)
                if transition_event is None:
                    transition_event = threading.Event()
                    self._transitions[server_name] = transition_event
                    self._configs[server_name] = config
                    should_connect = True

            if not should_connect:
                transition_event.wait()
                continue

            client = MCPClient(config)
            try:
                client.connect()
            except Exception:
                with self._pool_lock:
                    current = self._transitions.get(server_name)
                    if current is transition_event:
                        self._transitions.pop(server_name, None)
                        if (
                            server_name not in self._pool
                            and self._refcounts.get(server_name, 0) <= 0
                        ):
                            self._configs.pop(server_name, None)
                transition_event.set()
                raise

            with self._pool_lock:
                current = self._transitions.get(server_name)
                if current is transition_event:
                    self._pool[server_name] = client
                    self._refcounts[server_name] = self._refcounts.get(server_name, 0) + 1
                    self._configs[server_name] = config
                    self._transitions.pop(server_name, None)
                    transition_event.set()
                    return client

            client.disconnect()

    def release(self, server_name: str) -> None:
        """Decrement refcount and disconnect when the last user releases."""
        while True:
            disconnect_client: MCPClient | None = None
            transition_event: threading.Event | None = None
            should_disconnect = False

            with self._pool_lock:
                transition_event = self._transitions.get(server_name)
                if transition_event is None:
                    refcount = self._refcounts.get(server_name, 0)
                    if refcount <= 0:
                        return
                    if refcount > 1:
                        self._refcounts[server_name] = refcount - 1
                        return

                    disconnect_client = self._pool.pop(server_name, None)
                    self._refcounts.pop(server_name, None)
                    transition_event = threading.Event()
                    self._transitions[server_name] = transition_event
                    should_disconnect = True

            if not should_disconnect:
                transition_event.wait()
                continue

            try:
                if disconnect_client is not None:
                    disconnect_client.disconnect()
            finally:
                with self._pool_lock:
                    current = self._transitions.get(server_name)
                    if current is transition_event:
                        self._transitions.pop(server_name, None)
                transition_event.set()
            return
|
||||
|
||||
def health_check(self, server_name: str) -> bool:
|
||||
"""Return True when the pooled connection appears healthy."""
|
||||
while True:
|
||||
with self._pool_lock:
|
||||
transition_event = self._transitions.get(server_name)
|
||||
if transition_event is None:
|
||||
client = self._pool.get(server_name)
|
||||
config = self._configs.get(server_name)
|
||||
break
|
||||
|
||||
transition_event.wait()
|
||||
|
||||
if client is None or config is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
if config.transport == "stdio":
|
||||
client.list_tools()
|
||||
return True
|
||||
|
||||
if not config.url:
|
||||
return False
|
||||
|
||||
client_kwargs: dict[str, Any] = {
|
||||
"base_url": config.url,
|
||||
"headers": config.headers,
|
||||
"timeout": 5.0,
|
||||
}
|
||||
if config.transport == "unix":
|
||||
if not config.socket_path:
|
||||
return False
|
||||
client_kwargs["transport"] = httpx.HTTPTransport(uds=config.socket_path)
|
||||
|
||||
with httpx.Client(**client_kwargs) as http_client:
|
||||
response = http_client.get("/health")
|
||||
response.raise_for_status()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def reconnect(self, server_name: str) -> MCPClient:
|
||||
"""Force a disconnect and replace the pooled client with a fresh one."""
|
||||
while True:
|
||||
transition_event: threading.Event | None = None
|
||||
old_client: MCPClient | None = None
|
||||
|
||||
with self._pool_lock:
|
||||
transition_event = self._transitions.get(server_name)
|
||||
if transition_event is None:
|
||||
config = self._configs.get(server_name)
|
||||
if config is None:
|
||||
raise KeyError(f"Unknown MCP server: {server_name}")
|
||||
old_client = self._pool.get(server_name)
|
||||
refcount = self._refcounts.get(server_name, 0)
|
||||
transition_event = threading.Event()
|
||||
self._transitions[server_name] = transition_event
|
||||
break
|
||||
|
||||
transition_event.wait()
|
||||
|
||||
if old_client is not None:
|
||||
old_client.disconnect()
|
||||
|
||||
new_client = MCPClient(config)
|
||||
try:
|
||||
new_client.connect()
|
||||
except Exception:
|
||||
with self._pool_lock:
|
||||
current = self._transitions.get(server_name)
|
||||
if current is transition_event:
|
||||
self._pool.pop(server_name, None)
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
raise
|
||||
|
||||
with self._pool_lock:
|
||||
current = self._transitions.get(server_name)
|
||||
if current is transition_event:
|
||||
self._pool[server_name] = new_client
|
||||
self._refcounts[server_name] = max(refcount, 1)
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
return new_client
|
||||
|
||||
new_client.disconnect()
|
||||
return self.acquire(config)
|
||||
|
||||
def cleanup_all(self) -> None:
|
||||
"""Disconnect all pooled clients and clear manager state."""
|
||||
while True:
|
||||
with self._pool_lock:
|
||||
if self._transitions:
|
||||
pending = list(self._transitions.values())
|
||||
else:
|
||||
cleanup_events = {name: threading.Event() for name in self._pool}
|
||||
clients = list(self._pool.items())
|
||||
self._transitions.update(cleanup_events)
|
||||
self._pool.clear()
|
||||
self._refcounts.clear()
|
||||
self._configs.clear()
|
||||
break
|
||||
|
||||
for event in pending:
|
||||
event.wait()
|
||||
|
||||
for _server_name, client in clients:
|
||||
try:
|
||||
client.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
with self._pool_lock:
|
||||
for server_name, event in cleanup_events.items():
|
||||
current = self._transitions.get(server_name)
|
||||
if current is event:
|
||||
self._transitions.pop(server_name, None)
|
||||
event.set()
|
||||
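The manager above is easiest to read through its intended call pattern: acquire() connects (or reuses) and bumps the refcount, release() drops it and disconnects on the last user, and health_check()/reconnect() recover a dead connection. A minimal usage sketch — the MCPServerConfig keyword arguments are an assumption based on the fields referenced above, not confirmed by this diff:

from framework.runner.mcp_client import MCPServerConfig
from framework.runner.mcp_connection_manager import MCPConnectionManager

# Hypothetical stdio server config; only fields referenced above are assumed.
config = MCPServerConfig(name="github", transport="stdio")

manager = MCPConnectionManager.get_instance()
client = manager.acquire(config)  # first acquire connects; refcount becomes 1
try:
    if not manager.health_check("github"):
        client = manager.reconnect("github")  # swap in a fresh pooled client
finally:
    manager.release("github")  # last release disconnects and clears pool state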
@@ -1,6 +1,6 @@
"""Pre-load validation for agent graphs.

Runs structural and credential checks before MCP servers are spawned.
Runs structural, credential, and skill-trust checks before MCP servers are spawned.
Fails fast with actionable error messages.
"""

@@ -169,6 +169,9 @@ def run_preload_validation(
    1. Graph structure (includes GCU subagent-only checks) — non-recoverable
    2. Credentials — potentially recoverable via interactive setup

    Skill discovery and trust gating (AS-13) happen later in runner._setup()
    so they have access to agent-level skill configuration.

    Raises PreloadValidationError for structural issues.
    Raises CredentialError for credential issues.
    """

@@ -552,6 +552,319 @@ def get_kimi_code_token() -> str | None:
    return None


# ---------------------------------------------------------------------------
# Antigravity subscription token helpers
# ---------------------------------------------------------------------------

# Antigravity IDE (native macOS/Linux app) stores OAuth tokens in its
# VSCode-style SQLite state database under the key
# "antigravityUnifiedStateSync.oauthToken" as a base64-encoded protobuf blob.
ANTIGRAVITY_IDE_STATE_DB = (
    Path.home()
    / "Library"
    / "Application Support"
    / "Antigravity"
    / "User"
    / "globalStorage"
    / "state.vscdb"
)
# Linux fallback for the IDE state DB
ANTIGRAVITY_IDE_STATE_DB_LINUX = (
    Path.home() / ".config" / "Antigravity" / "User" / "globalStorage" / "state.vscdb"
)
# Antigravity credentials stored by native OAuth implementation
ANTIGRAVITY_AUTH_FILE = Path.home() / ".hive" / "antigravity-accounts.json"

ANTIGRAVITY_OAUTH_TOKEN_URL = "https://oauth2.googleapis.com/token"
_ANTIGRAVITY_TOKEN_LIFETIME_SECS = 3600  # Google access tokens expire in 1 hour
_ANTIGRAVITY_IDE_STATE_DB_KEY = "antigravityUnifiedStateSync.oauthToken"


def _read_antigravity_ide_credentials() -> dict | None:
    """Read credentials from the Antigravity IDE's SQLite state database.

    The Antigravity desktop IDE (VSCode-based) stores its OAuth token as a
    base64-encoded protobuf blob in a SQLite database. The access token is
    a standard Google OAuth ``ya29.*`` bearer token.

    Returns:
        Dict with ``accessToken`` and optionally ``refreshToken`` keys,
        plus ``_source: "ide"`` to skip file-based save on refresh.
        Returns None if the database is absent or the key is not found.
    """
    import re
    import sqlite3

    for db_path in (ANTIGRAVITY_IDE_STATE_DB, ANTIGRAVITY_IDE_STATE_DB_LINUX):
        if not db_path.exists():
            continue
        try:
            con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
            try:
                row = con.execute(
                    "SELECT value FROM ItemTable WHERE key = ?",
                    (_ANTIGRAVITY_IDE_STATE_DB_KEY,),
                ).fetchone()
            finally:
                con.close()

            if not row:
                continue

            import base64

            blob = base64.b64decode(row[0])

            # The protobuf blob contains the access token (ya29.*) and
            # refresh token (1//*) as length-prefixed UTF-8 strings.
            # Decode the inner base64 layer and extract with regex.
            inner_b64_candidates = re.findall(rb"[A-Za-z0-9+/=_\-]{40,}", blob)
            access_token: str | None = None
            refresh_token: str | None = None
            for candidate in inner_b64_candidates:
                try:
                    padded = candidate + b"=" * (-len(candidate) % 4)
                    inner = base64.urlsafe_b64decode(padded)
                except Exception:
                    continue
                if not access_token:
                    m = re.search(rb"ya29\.[A-Za-z0-9_\-\.]+", inner)
                    if m:
                        access_token = m.group(0).decode("ascii")
                if not refresh_token:
                    m = re.search(rb"1//[A-Za-z0-9_\-\.]+", inner)
                    if m:
                        refresh_token = m.group(0).decode("ascii")
                if access_token and refresh_token:
                    break

            if access_token:
                return {
                    "accounts": [
                        {
                            "accessToken": access_token,
                            "refreshToken": refresh_token or "",
                        }
                    ],
                    "_source": "ide",
                    "_db_path": str(db_path),
                }
        except Exception as exc:
            logger.debug("Failed to read Antigravity IDE state DB: %s", exc)
            continue

    return None


def _read_antigravity_credentials() -> dict | None:
    """Read Antigravity auth data from all supported credential sources.

    Checks in order:
    1. Antigravity IDE SQLite state database (native macOS/Linux app)
    2. Native OAuth credentials file (~/.hive/antigravity-accounts.json)

    Returns:
        Auth data dict with an ``accounts`` list on success, None otherwise.
    """
    # 1. Native Antigravity IDE (primary on macOS)
    ide_creds = _read_antigravity_ide_credentials()
    if ide_creds:
        return ide_creds

    # 2. Native OAuth credentials file
    if ANTIGRAVITY_AUTH_FILE.exists():
        try:
            with open(ANTIGRAVITY_AUTH_FILE, encoding="utf-8") as f:
                data = json.load(f)
            accounts = data.get("accounts", [])
            if accounts and isinstance(accounts[0], dict):
                return data
        except (json.JSONDecodeError, OSError):
            pass
    return None


def _is_antigravity_token_expired(auth_data: dict) -> bool:
    """Check whether the Antigravity access token is expired or near expiry.

    For IDE-sourced credentials: uses the state DB's mtime as last_refresh
    since the IDE keeps the DB fresh while it's running.
    For JSON-sourced credentials: uses the ``last_refresh`` field or file mtime.
    """
    import time
    from datetime import datetime

    now = time.time()

    if auth_data.get("_source") == "ide":
        # The IDE refreshes tokens automatically while running.
        # Use the DB file's mtime as a proxy for when the token was last updated.
        try:
            db_path = Path(auth_data.get("_db_path", str(ANTIGRAVITY_IDE_STATE_DB)))
            last_refresh: float = db_path.stat().st_mtime
        except OSError:
            return True
        expires_at = last_refresh + _ANTIGRAVITY_TOKEN_LIFETIME_SECS
        return now >= (expires_at - _TOKEN_REFRESH_BUFFER_SECS)

    last_refresh_val: float | str | None = auth_data.get("last_refresh")
    if last_refresh_val is None:
        try:
            last_refresh_val = ANTIGRAVITY_AUTH_FILE.stat().st_mtime
        except OSError:
            return True
    elif isinstance(last_refresh_val, str):
        try:
            last_refresh_val = datetime.fromisoformat(
                last_refresh_val.replace("Z", "+00:00")
            ).timestamp()
        except (ValueError, TypeError):
            return True

    expires_at = float(last_refresh_val) + _ANTIGRAVITY_TOKEN_LIFETIME_SECS
    return now >= (expires_at - _TOKEN_REFRESH_BUFFER_SECS)


def _refresh_antigravity_token(refresh_token: str) -> dict | None:
    """Refresh the Antigravity access token via Google OAuth.

    POSTs form-encoded ``grant_type=refresh_token`` to the Google token
    endpoint using Antigravity's public OAuth client ID.

    Returns:
        Parsed response dict (containing ``access_token``) on success,
        None on any error.
    """
    import urllib.error
    import urllib.parse
    import urllib.request

    from framework.config import get_antigravity_client_id, get_antigravity_client_secret

    client_id = get_antigravity_client_id()
    client_secret = get_antigravity_client_secret()
    params: dict = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": client_id,
    }
    if client_secret:
        params["client_secret"] = client_secret

    data = urllib.parse.urlencode(params).encode("utf-8")

    req = urllib.request.Request(
        ANTIGRAVITY_OAUTH_TOKEN_URL,
        data=data,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )

    try:
        with urllib.request.urlopen(req, timeout=15) as resp:  # noqa: S310
            return json.loads(resp.read())
    except (urllib.error.URLError, json.JSONDecodeError, TimeoutError, OSError) as exc:
        logger.debug("Antigravity token refresh failed: %s", exc)
        return None


def _save_refreshed_antigravity_credentials(auth_data: dict, token_data: dict) -> None:
    """Write refreshed tokens back to the Antigravity JSON credentials file.

    Skipped for IDE-sourced credentials (the IDE manages its own DB).
    Updates ``accounts[0].accessToken`` (and ``refreshToken`` if present),
    then persists ``last_refresh`` as an ISO-8601 UTC string.
    """
    from datetime import datetime

    # IDE manages its own state — we do not write back to its SQLite DB
    if auth_data.get("_source") == "ide":
        return

    try:
        accounts = auth_data.get("accounts", [])
        if not accounts:
            return
        account = accounts[0]
        account["accessToken"] = token_data["access_token"]
        if "refresh_token" in token_data:
            account["refreshToken"] = token_data["refresh_token"]
        auth_data["accounts"] = accounts
        auth_data["last_refresh"] = datetime.now(UTC).isoformat()

        ANTIGRAVITY_AUTH_FILE.parent.mkdir(parents=True, exist_ok=True)
        fd = os.open(ANTIGRAVITY_AUTH_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(auth_data, f, indent=2)
        logger.debug("Antigravity credentials refreshed and saved")
    except (OSError, KeyError) as exc:
        logger.debug("Failed to save refreshed Antigravity credentials: %s", exc)


def get_antigravity_token() -> str | None:
    """Get the OAuth access token from an Antigravity subscription.

    Credential sources checked in order:
    1. Antigravity IDE SQLite state DB (native app, macOS/Linux)
    2. antigravity-auth CLI JSON file

    For IDE credentials the token is read directly (the IDE refreshes it
    automatically while running). For JSON credentials an automatic OAuth
    refresh is attempted when the token is near expiry.

    Returns:
        The ``ya29.*`` Google OAuth access token, or None if unavailable.
    """
    auth_data = _read_antigravity_credentials()
    if not auth_data:
        return None

    accounts = auth_data.get("accounts", [])
    if not accounts:
        return None
    account = accounts[0]

    access_token = account.get("accessToken")
    if not access_token:
        return None

    if not _is_antigravity_token_expired(auth_data):
        return access_token

    # Token is expired or near expiry — attempt a refresh
    refresh_token = account.get("refreshToken")
    if not refresh_token:
        logger.warning(
            "Antigravity token expired and no refresh token available. "
            "Re-open the Antigravity IDE to refresh, or run 'antigravity-auth accounts add'."
        )
        return access_token  # return stale token; proxy may still accept it briefly

    logger.info("Antigravity token expired or near expiry, refreshing...")
    token_data = _refresh_antigravity_token(refresh_token)

    if token_data and "access_token" in token_data:
        _save_refreshed_antigravity_credentials(auth_data, token_data)
        return token_data["access_token"]

    logger.warning(
        "Antigravity token refresh failed. "
        "Re-open the Antigravity IDE or run 'antigravity-auth accounts add'."
    )
    return access_token


def _is_antigravity_proxy_available() -> bool:
    """Return True if antigravity-auth serve is running on localhost:8069."""
    import socket

    try:
        with socket.create_connection(("localhost", 8069), timeout=0.5):
            return True
    except (OSError, TimeoutError):
        return False


@dataclass
class AgentInfo:
    """Information about an exported agent."""
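The expiry logic above is simple clock arithmetic: a token is treated as stale once the current time passes last_refresh + 3600 seconds minus a refresh buffer. A self-contained sketch of that rule — the 300-second buffer is an assumed value standing in for _TOKEN_REFRESH_BUFFER_SECS, which is defined elsewhere in the file:

import time

_ASSUMED_REFRESH_BUFFER_SECS = 300  # stand-in for _TOKEN_REFRESH_BUFFER_SECS

def needs_refresh(last_refresh: float, lifetime_secs: int = 3600) -> bool:
    # Stale once now crosses (last_refresh + lifetime) - buffer.
    expires_at = last_refresh + lifetime_secs
    return time.time() >= expires_at - _ASSUMED_REFRESH_BUFFER_SECS

# A token refreshed 55 minutes ago is already inside a 5-minute buffer window.
print(needs_refresh(time.time() - 55 * 60))  # True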
@@ -1141,7 +1454,10 @@ class AgentRunner:

        # Create LLM provider
        # Uses LiteLLM which auto-detects the provider from model name
        if self.mock_mode:
        # Skip if already injected (e.g. worker agents with a pre-built LLM)
        if self._llm is not None:
            pass  # LLM already configured externally
        elif self.mock_mode:
            # Use mock LLM for testing without real API calls
            from framework.llm.mock import MockLLMProvider

@@ -1155,6 +1471,7 @@ class AgentRunner:
        use_claude_code = llm_config.get("use_claude_code_subscription", False)
        use_codex = llm_config.get("use_codex_subscription", False)
        use_kimi_code = llm_config.get("use_kimi_code_subscription", False)
        use_antigravity = llm_config.get("use_antigravity_subscription", False)
        api_base = llm_config.get("api_base")

        api_key = None
@@ -1176,6 +1493,8 @@ class AgentRunner:
            if not api_key:
                print("Warning: Kimi Code subscription configured but no key found.")
                print("Run 'kimi /login' to authenticate, then try again.")
        elif use_antigravity:
            pass  # AntigravityProvider handles credentials internally

        if api_key and use_claude_code:
            # Use litellm's built-in Anthropic OAuth support.
@@ -1214,6 +1533,19 @@ class AgentRunner:
                api_key=api_key,
                api_base=api_base,
            )
        elif use_antigravity:
            # Direct OAuth to Google's internal Cloud Code Assist gateway.
            # No local proxy required — AntigravityProvider handles token
            # refresh and Gemini-format request/response conversion natively.
            from framework.llm.antigravity import AntigravityProvider  # noqa: PLC0415

            provider = AntigravityProvider(model=self.model)
            if not provider.has_credentials():
                print(
                    "Warning: Antigravity credentials not found. "
                    "Run: uv run python core/antigravity_auth.py auth account add"
                )
            self._llm = provider
        else:
            # Local models (e.g. Ollama) don't need an API key
            if self._is_local_model(self.model):
@@ -1340,7 +1672,7 @@ class AgentRunner:
        except Exception:
            pass  # Best-effort — agent works without account info

        # Skill configuration — the runtime handles discovery, loading, and
        # Skill configuration — the runtime handles discovery, loading, trust-gating and
        # prompt rasterization. The runner just builds the config.
        from framework.skills.config import SkillsConfig
        from framework.skills.manager import SkillsManagerConfig
@@ -1351,6 +1683,7 @@ class AgentRunner:
                skills=getattr(self, "_agent_skills", None),
            ),
            project_root=self.agent_path,
            interactive=self._interactive,
        )

        self._setup_agent_runtime(
@@ -1462,6 +1795,9 @@ class AgentRunner:
        accounts_data: list[dict] | None = None,
        tool_provider_map: dict[str, str] | None = None,
        event_bus=None,
        skills_catalog_prompt: str = "",
        protocols_prompt: str = "",
        skill_dirs: list[str] | None = None,
        skills_manager_config=None,
    ) -> None:
        """Set up multi-entry-point execution using AgentRuntime."""

@@ -54,6 +54,8 @@ class ToolRegistry:
    def __init__(self):
        self._tools: dict[str, RegisteredTool] = {}
        self._mcp_clients: list[Any] = []  # List of MCPClient instances
        self._mcp_client_servers: dict[int, str] = {}  # client id -> server name
        self._mcp_managed_clients: set[int] = set()  # client ids acquired from the manager
        self._session_context: dict[str, Any] = {}  # Auto-injected context for tools
        self._provider_index: dict[str, set[str]] = {}  # provider -> tool names
        # MCP resync tracking
@@ -480,6 +482,7 @@ class ToolRegistry:
    def register_mcp_server(
        self,
        server_config: dict[str, Any],
        use_connection_manager: bool = True,
    ) -> int:
        """
        Register an MCP server and discover its tools.
@@ -495,12 +498,14 @@ class ToolRegistry:
                - url: Server URL (for http)
                - headers: HTTP headers (for http)
                - description: Server description (optional)
            use_connection_manager: When True, reuse a shared client keyed by server name

        Returns:
            Number of tools registered from this server
        """
        try:
            from framework.runner.mcp_client import MCPClient, MCPServerConfig
            from framework.runner.mcp_connection_manager import MCPConnectionManager

            # Build config object
            config = MCPServerConfig(
@@ -516,11 +521,18 @@ class ToolRegistry:
            )

            # Create and connect client
            client = MCPClient(config)
            client.connect()
            if use_connection_manager:
                client = MCPConnectionManager.get_instance().acquire(config)
            else:
                client = MCPClient(config)
                client.connect()

            # Store client for cleanup
            self._mcp_clients.append(client)
            client_id = id(client)
            self._mcp_client_servers[client_id] = config.name
            if use_connection_manager:
                self._mcp_managed_clients.add(client_id)

            # Register each tool
            server_name = server_config["name"]
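With the new flag, register_mcp_server() routes through the shared pool by default and only falls back to a private client when use_connection_manager=False. A hedged sketch of the call — the import path and the "command" key for a stdio server are assumptions, not confirmed by this diff:

from framework.tools.registry import ToolRegistry  # import path assumed

registry = ToolRegistry()
count = registry.register_mcp_server(
    {"name": "github", "transport": "stdio", "command": "gh-mcp-server"},  # "command" key assumed
    use_connection_manager=True,  # reuse the pooled client keyed by server name
)
print(f"Registered {count} tools from 'github'")

# cleanup() (below) releases pooled clients back to the manager instead of
# disconnecting them, so other registries sharing the connection keep working.
registry.cleanup()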
@@ -720,12 +732,7 @@ class ToolRegistry:
        logger.info("%s — resyncing MCP servers", reason)

        # 1. Disconnect existing MCP clients
        for client in self._mcp_clients:
            try:
                client.disconnect()
            except Exception as e:
                logger.warning(f"Error disconnecting MCP client during resync: {e}")
        self._mcp_clients.clear()
        self._cleanup_mcp_clients("during resync")

        # 2. Remove MCP-registered tools
        for name in self._mcp_tool_names:
@@ -740,12 +747,28 @@ class ToolRegistry:

    def cleanup(self) -> None:
        """Clean up all MCP client connections."""
        self._cleanup_mcp_clients()

    def _cleanup_mcp_clients(self, context: str = "") -> None:
        """Disconnect or release all tracked MCP clients for this registry."""
        if context:
            context = f" {context}"

        for client in self._mcp_clients:
            client_id = id(client)
            server_name = self._mcp_client_servers.get(client_id, client.config.name)
            try:
                client.disconnect()
                if client_id in self._mcp_managed_clients:
                    from framework.runner.mcp_connection_manager import MCPConnectionManager

                    MCPConnectionManager.get_instance().release(server_name)
                else:
                    client.disconnect()
            except Exception as e:
                logger.warning(f"Error disconnecting MCP client: {e}")
                logger.warning(f"Error disconnecting MCP client{context}: {e}")
        self._mcp_clients.clear()
        self._mcp_client_servers.clear()
        self._mcp_managed_clients.clear()

    def __del__(self):
        """Destructor to ensure cleanup."""

@@ -137,6 +137,7 @@ class AgentRuntime:
        # Deprecated — pass skills_manager_config instead.
        skills_catalog_prompt: str = "",
        protocols_prompt: str = "",
        skill_dirs: list[str] | None = None,
    ):
        """
        Initialize agent runtime.
@@ -158,6 +159,9 @@ class AgentRuntime:
            event_bus: Optional external EventBus. If provided, the runtime shares
                this bus instead of creating its own. Used by SessionManager to
                share a single bus between queen, worker, and judge.
            skills_catalog_prompt: Available skills catalog for system prompt
            protocols_prompt: Default skill operational protocols for system prompt
            skill_dirs: Skill base directories for Tier 3 resource access
            skills_manager_config: Skill configuration — the runtime owns
                discovery, loading, and prompt rendering internally.
            skills_catalog_prompt: Deprecated. Pre-rendered skills catalog.
@@ -195,6 +199,8 @@ class AgentRuntime:
        self._skills_manager = SkillsManager()
        self._skills_manager.load()

        self.skill_dirs: list[str] = self._skills_manager.allowlisted_dirs

        # Primary graph identity
        self._graph_id: str = graph_id or "primary"

@@ -341,6 +347,7 @@ class AgentRuntime:
                tool_provider_map=self._tool_provider_map,
                skills_catalog_prompt=self.skills_catalog_prompt,
                protocols_prompt=self.protocols_prompt,
                skill_dirs=self.skill_dirs,
            )
            await stream.start()
            self._streams[ep_id] = stream
@@ -977,6 +984,7 @@ class AgentRuntime:
            tool_provider_map=self._tool_provider_map,
            skills_catalog_prompt=self.skills_catalog_prompt,
            protocols_prompt=self.protocols_prompt,
            skill_dirs=self.skill_dirs,
        )
        if self._running:
            await stream.start()
@@ -1760,6 +1768,7 @@ def create_agent_runtime(
    # Deprecated — pass skills_manager_config instead.
    skills_catalog_prompt: str = "",
    protocols_prompt: str = "",
    skill_dirs: list[str] | None = None,
) -> AgentRuntime:
    """
    Create and configure an AgentRuntime with entry points.
@@ -1786,6 +1795,9 @@ def create_agent_runtime(
        accounts_data: Raw account data for per-node prompt generation.
        tool_provider_map: Tool name to provider name mapping for account routing.
        event_bus: Optional external EventBus to share with other components.
        skills_catalog_prompt: Available skills catalog for system prompt.
        protocols_prompt: Default skill operational protocols for system prompt.
        skill_dirs: Skill base directories for Tier 3 resource access.
        skills_manager_config: Skill configuration — the runtime owns
            discovery, loading, and prompt rendering internally.
        skills_catalog_prompt: Deprecated. Pre-rendered skills catalog.
@@ -1819,6 +1831,7 @@ def create_agent_runtime(
        skills_manager_config=skills_manager_config,
        skills_catalog_prompt=skills_catalog_prompt,
        protocols_prompt=protocols_prompt,
        skill_dirs=skill_dirs,
    )

    for spec in entry_points:

@@ -117,6 +117,7 @@ class EventType(StrEnum):

    # Context management
    CONTEXT_COMPACTED = "context_compacted"
    CONTEXT_USAGE_UPDATED = "context_usage_updated"

    # External triggers
    WEBHOOK_RECEIVED = "webhook_received"

@@ -188,6 +188,7 @@ class ExecutionStream:
        tool_provider_map: dict[str, str] | None = None,
        skills_catalog_prompt: str = "",
        protocols_prompt: str = "",
        skill_dirs: list[str] | None = None,
    ):
        """
        Initialize execution stream.
@@ -213,6 +214,7 @@ class ExecutionStream:
            tool_provider_map: Tool name to provider name mapping for account routing
            skills_catalog_prompt: Available skills catalog for system prompt
            protocols_prompt: Default skill operational protocols for system prompt
            skill_dirs: Skill base directories for Tier 3 resource access
        """
        self.stream_id = stream_id
        self.entry_spec = entry_spec
@@ -236,6 +238,7 @@ class ExecutionStream:
        self._tool_provider_map = tool_provider_map
        self._skills_catalog_prompt = skills_catalog_prompt
        self._protocols_prompt = protocols_prompt
        self._skill_dirs: list[str] = skill_dirs or []

        _es_logger = logging.getLogger(__name__)
        if protocols_prompt:
@@ -696,6 +699,7 @@ class ExecutionStream:
            tool_provider_map=self._tool_provider_map,
            skills_catalog_prompt=self._skills_catalog_prompt,
            protocols_prompt=self._protocols_prompt,
            skill_dirs=self._skill_dirs,
        )
        # Track executor so inject_input() can reach EventLoopNode instances
        self._active_executors[execution_id] = executor

@@ -8,6 +8,7 @@ write. Errors are silently swallowed — this must never break the agent.

import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import IO, Any
@@ -47,6 +48,9 @@ def log_llm_turn(
    Never raises.
    """
    try:
        # Skip logging during test runs to avoid polluting real logs.
        if os.environ.get("PYTEST_CURRENT_TEST") or os.environ.get("HIVE_DISABLE_LLM_LOGS"):
            return
        global _log_file, _log_ready  # noqa: PLW0603
        if not _log_ready:
            _log_file = _open_log()

@@ -37,6 +37,7 @@ DEFAULT_EVENT_TYPES = [
    EventType.NODE_RETRY,
    EventType.NODE_TOOL_DOOM_LOOP,
    EventType.CONTEXT_COMPACTED,
    EventType.CONTEXT_USAGE_UPDATED,
    EventType.WORKER_LOADED,
    EventType.CREDENTIALS_REQUIRED,
    EventType.SUBAGENT_REPORT,

@@ -96,8 +96,7 @@ class SessionManager:

        Internal helper — use create_session() or create_session_with_worker().
        """
        from framework.config import RuntimeConfig
        from framework.llm.litellm import LiteLLMProvider
        from framework.config import RuntimeConfig, get_hive_config
        from framework.runtime.event_bus import EventBus

        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -111,12 +110,20 @@ class SessionManager:
        rc = RuntimeConfig(model=model or self._model or RuntimeConfig().model)

        # Session owns these — shared with queen and worker
        llm = LiteLLMProvider(
            model=rc.model,
            api_key=rc.api_key,
            api_base=rc.api_base,
            **rc.extra_kwargs,
        )
        llm_config = get_hive_config().get("llm", {})
        if llm_config.get("use_antigravity_subscription"):
            from framework.llm.antigravity import AntigravityProvider

            llm = AntigravityProvider(model=rc.model)
        else:
            from framework.llm.litellm import LiteLLMProvider

            llm = LiteLLMProvider(
                model=rc.model,
                api_key=rc.api_key,
                api_base=rc.api_base,
                **rc.extra_kwargs,
            )
        event_bus = EventBus()

        session = Session(
@@ -287,7 +294,17 @@ class SessionManager:
        try:
            # Blocking I/O — load in executor
            loop = asyncio.get_running_loop()
            resolved_model = model or self._model

            # Prioritize: explicit model arg > worker-specific model > session default
            from framework.config import (
                get_preferred_worker_model,
                get_worker_api_base,
                get_worker_api_key,
                get_worker_llm_extra_kwargs,
            )

            worker_model = get_preferred_worker_model()
            resolved_model = model or worker_model or self._model
            runner = await loop.run_in_executor(
                None,
                lambda: AgentRunner.load(
@@ -299,6 +316,30 @@ class SessionManager:
                ),
            )

            # If a worker-specific model is configured, build an LLM provider
            # with the correct worker credentials so _setup() doesn't fall back
            # to the queen's llm config (which may be a different provider).
            if worker_model and not model:
                from framework.config import get_hive_config

                worker_llm_cfg = get_hive_config().get("worker_llm", {})
                if worker_llm_cfg.get("use_antigravity_subscription"):
                    from framework.llm.antigravity import AntigravityProvider

                    runner._llm = AntigravityProvider(model=resolved_model)
                else:
                    from framework.llm.litellm import LiteLLMProvider

                    worker_api_key = get_worker_api_key()
                    worker_api_base = get_worker_api_base()
                    worker_extra = get_worker_llm_extra_kwargs()
                    runner._llm = LiteLLMProvider(
                        model=resolved_model,
                        api_key=worker_api_key,
                        api_base=worker_api_base,
                        **worker_extra,
                    )

            # Setup with session's event bus
            if runner._agent_runtime is None:
                await loop.run_in_executor(
@@ -793,10 +834,11 @@ class SessionManager:
            exec_id = event.execution_id

            if event.type == _ET.EXECUTION_STARTED:
                # New run on this execution_id — reset cooldown so the first
                # iteration always produces a mid-run snapshot.
                # New run on this execution_id — start the cooldown timer so
                # mid-run snapshots don't fire immediately at session start.
                # The first snapshot will happen after _DIGEST_COOLDOWN seconds.
                if exec_id:
                    _last_digest.pop(exec_id, None)
                    _last_digest[exec_id] = _time.monotonic()

            elif event.type in (
                _ET.EXECUTION_COMPLETED,
@@ -923,6 +965,7 @@ class SessionManager:
        # then use max+1 as offset so resumed sessions produce monotonically
        # increasing iteration values — preventing frontend message ID collisions.
        iteration_offset = 0
        last_phase = ""
        events_path = queen_dir / "events.jsonl"
        try:
            if events_path.exists():
@@ -934,17 +977,25 @@ class SessionManager:
                        continue
                    try:
                        evt = json.loads(line)
                        it = evt.get("data", {}).get("iteration")
                        data = evt.get("data", {})
                        it = data.get("iteration")
                        if isinstance(it, int) and it > max_iter:
                            max_iter = it
                        # Track the latest queen phase from QUEEN_PHASE_CHANGED events
                        if evt.get("type") == "queen_phase_changed":
                            phase = data.get("phase")
                            if phase:
                                last_phase = phase
                    except (json.JSONDecodeError, TypeError):
                        continue
                if max_iter >= 0:
                    iteration_offset = max_iter + 1
                    logger.info(
                        "Session '%s' resuming with iteration_offset=%d (from events.jsonl max)",
                        "Session '%s' resuming with iteration_offset=%d"
                        " (from events.jsonl max), last phase: %s",
                        session.id,
                        iteration_offset,
                        last_phase or "unknown",
                    )
        except OSError:
            pass
@@ -996,10 +1047,17 @@ class SessionManager:
        _consolidation_session_dir = queen_dir

        async def _on_compaction(_event) -> None:
            # Only consolidate on queen compactions — worker and subagent
            # compactions are frequent and don't warrant a memory update.
            if getattr(_event, "stream_id", None) != "queen":
                return
            from framework.agents.queen.queen_memory import consolidate_queen_memory

            await consolidate_queen_memory(
                session.id, _consolidation_session_dir, _consolidation_llm
            asyncio.create_task(
                consolidate_queen_memory(
                    session.id, _consolidation_session_dir, _consolidation_llm
                ),
                name=f"queen-memory-consolidation-{session.id}",
            )

        from framework.runtime.event_bus import EventType as _ET
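The compaction-handler change swaps a blocking await for asyncio.create_task, so memory consolidation runs in the background and the event bus is never stalled for the duration of an LLM call. A minimal sketch of the pattern with illustrative names:

import asyncio

async def slow_consolidation() -> None:
    await asyncio.sleep(5)  # stands in for an LLM-backed memory update

async def on_compaction_event() -> None:
    # Fire-and-forget: the handler returns immediately; the name makes the
    # task identifiable in asyncio debug output.
    task = asyncio.create_task(slow_consolidation(), name="queen-memory-consolidation-demo")
    # Keep a reference (here, awaited for demo purposes) so the task
    # isn't garbage-collected mid-flight.
    await task

asyncio.run(on_compaction_event())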
@@ -1,8 +1,8 @@
"""Hive Agent Skills — discovery, parsing, and injection of SKILL.md packages.
"""Hive Agent Skills — discovery, parsing, trust gating, and injection of SKILL.md packages.

Implements the open Agent Skills standard (agentskills.io) for portable
skill discovery and activation, plus built-in default skills for runtime
operational discipline.
operational discipline, and AS-13 trust gating for project-scope skills.
"""

from framework.skills.catalog import SkillCatalog
@@ -10,7 +10,10 @@ from framework.skills.config import DefaultSkillConfig, SkillsConfig
from framework.skills.defaults import DefaultSkillManager
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
from framework.skills.manager import SkillsManager, SkillsManagerConfig
from framework.skills.models import TrustStatus
from framework.skills.parser import ParsedSkill, parse_skill_md
from framework.skills.skill_errors import SkillError, SkillErrorCode, log_skill_error
from framework.skills.trust import TrustedRepoStore, TrustGate

__all__ = [
    "DefaultSkillConfig",
@@ -22,5 +25,11 @@ __all__ = [
    "SkillsConfig",
    "SkillsManager",
    "SkillsManagerConfig",
    "TrustGate",
    "TrustedRepoStore",
    "TrustStatus",
    "parse_skill_md",
    "SkillError",
    "SkillErrorCode",
    "log_skill_error",
]

@@ -10,6 +10,7 @@ import logging
from xml.sax.saxutils import escape

from framework.skills.parser import ParsedSkill
from framework.skills.skill_errors import SkillErrorCode, log_skill_error

logger = logging.getLogger(__name__)

@@ -76,6 +77,7 @@ class SkillCatalog:
            lines.append(f"  <name>{escape(skill.name)}</name>")
            lines.append(f"  <description>{escape(skill.description)}</description>")
            lines.append(f"  <location>{escape(skill.location)}</location>")
            lines.append(f"  <base_dir>{escape(skill.base_dir)}</base_dir>")
            lines.append("  </skill>")
        lines.append("</available_skills>")

@@ -96,7 +98,14 @@ class SkillCatalog:
        for name in skill_names:
            skill = self.get(name)
            if skill is None:
                logger.warning("Pre-activated skill '%s' not found in catalog", name)
                log_skill_error(
                    logger,
                    "warning",
                    SkillErrorCode.SKILL_NOT_FOUND,
                    what=f"Pre-activated skill '{name}' not found in catalog",
                    why="The skill was listed for pre-activation but was not discovered.",
                    fix=f"Check that a SKILL.md for '{name}' exists in a scanned directory.",
                )
                continue
            if self.is_activated(name):
                continue  # Already activated, skip duplicate

@@ -0,0 +1,120 @@
"""CLI commands for the Hive skill system.

Phase 1 commands (AS-13):
    hive skill list — list discovered skills across all scopes
    hive skill trust <path> — permanently trust a project repo's skills

Full CLI suite (CLI-1 through CLI-13) is Phase 2.
"""

from __future__ import annotations

import subprocess
import sys
from pathlib import Path


def register_skill_commands(subparsers) -> None:
    """Register the ``hive skill`` subcommand group."""
    skill_parser = subparsers.add_parser("skill", help="Manage skills")
    skill_sub = skill_parser.add_subparsers(dest="skill_command", required=True)

    # hive skill list
    list_parser = skill_sub.add_parser("list", help="List discovered skills across all scopes")
    list_parser.add_argument(
        "--project-dir",
        default=None,
        metavar="PATH",
        help="Project directory to scan (default: current directory)",
    )
    list_parser.set_defaults(func=cmd_skill_list)

    # hive skill trust
    trust_parser = skill_sub.add_parser(
        "trust",
        help="Permanently trust a project repository so its skills load without prompting",
    )
    trust_parser.add_argument(
        "project_path",
        help="Path to the project directory (must contain a .git with a remote origin)",
    )
    trust_parser.set_defaults(func=cmd_skill_trust)


def cmd_skill_list(args) -> int:
    """List all discovered skills grouped by scope."""
    from framework.skills.discovery import DiscoveryConfig, SkillDiscovery

    project_dir = Path(args.project_dir).resolve() if args.project_dir else Path.cwd()
    skills = SkillDiscovery(DiscoveryConfig(project_root=project_dir)).discover()

    if not skills:
        print("No skills discovered.")
        return 0

    scope_headers = {
        "project": "PROJECT SKILLS",
        "user": "USER SKILLS",
        "framework": "FRAMEWORK SKILLS",
    }

    for scope in ("project", "user", "framework"):
        scope_skills = [s for s in skills if s.source_scope == scope]
        if not scope_skills:
            continue
        print(f"\n{scope_headers[scope]}")
        print("─" * 40)
        for skill in scope_skills:
            print(f"  • {skill.name}")
            print(f"    {skill.description}")
            print(f"    {skill.location}")

    return 0


def cmd_skill_trust(args) -> int:
    """Permanently trust a project repository's skills."""
    from framework.skills.trust import TrustedRepoStore, _normalize_remote_url

    project_path = Path(args.project_path).resolve()

    if not project_path.exists():
        print(f"Error: path does not exist: {project_path}", file=sys.stderr)
        return 1

    if not (project_path / ".git").exists():
        print(
            f"Error: {project_path} is not a git repository (no .git directory).",
            file=sys.stderr,
        )
        return 1

    try:
        result = subprocess.run(
            ["git", "-C", str(project_path), "remote", "get-url", "origin"],
            capture_output=True,
            text=True,
            timeout=3,
        )
        if result.returncode != 0:
            print(
                "Error: no remote 'origin' configured in this repository.",
                file=sys.stderr,
            )
            return 1
        remote_url = result.stdout.strip()
    except subprocess.TimeoutExpired:
        print("Error: git remote lookup timed out.", file=sys.stderr)
        return 1
    except (FileNotFoundError, OSError) as e:
        print(f"Error reading git remote: {e}", file=sys.stderr)
        return 1

    repo_key = _normalize_remote_url(remote_url)
    store = TrustedRepoStore()
    store.trust(repo_key, project_path=str(project_path))

    print(f"✓ Trusted: {repo_key}")
    print("  Stored in ~/.hive/trusted_repos.json")
    print("  Skills from this repository will load without prompting in future runs.")
    return 0
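register_skill_commands() is designed to plug into an existing argparse tree. A sketch of the wiring — the top-level hive parser shown here is an assumption, not part of this diff:

import argparse

parser = argparse.ArgumentParser(prog="hive")  # assumed top-level CLI parser
subparsers = parser.add_subparsers(dest="command", required=True)
register_skill_commands(subparsers)

# Each subcommand stores its handler via set_defaults(func=...), so dispatch
# is a single attribute lookup.
args = parser.parse_args(["skill", "list", "--project-dir", "."])
raise SystemExit(args.func(args))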
@@ -11,6 +11,7 @@ from pathlib import Path

from framework.skills.config import SkillsConfig
from framework.skills.parser import ParsedSkill, parse_skill_md
from framework.skills.skill_errors import SkillErrorCode, log_skill_error

logger = logging.getLogger(__name__)

@@ -60,12 +61,14 @@ class DefaultSkillManager:
        self._config = config or SkillsConfig()
        self._skills: dict[str, ParsedSkill] = {}
        self._loaded = False
        self._error_count = 0

    def load(self) -> None:
        """Load all enabled default skill SKILL.md files."""
        if self._loaded:
            return

        error_count = 0
        for skill_name, dir_name in SKILL_REGISTRY.items():
            if not self._config.is_default_enabled(skill_name):
                logger.info("Default skill '%s' disabled by config", skill_name)
@@ -73,17 +76,34 @@ class DefaultSkillManager:

            skill_path = _DEFAULT_SKILLS_DIR / dir_name / "SKILL.md"
            if not skill_path.is_file():
                logger.error("Default skill SKILL.md not found: %s", skill_path)
                log_skill_error(
                    logger,
                    "error",
                    SkillErrorCode.SKILL_NOT_FOUND,
                    what=f"Default skill SKILL.md not found: '{skill_path}'",
                    why=f"The framework skill '{skill_name}' is missing its SKILL.md file.",
                    fix="Reinstall the hive framework — this file is part of the package.",
                )
                error_count += 1
                continue

            parsed = parse_skill_md(skill_path, source_scope="framework")
            if parsed is None:
                logger.error("Failed to parse default skill: %s", skill_path)
                log_skill_error(
                    logger,
                    "error",
                    SkillErrorCode.SKILL_PARSE_ERROR,
                    what=f"Failed to parse default skill '{skill_name}'",
                    why=f"parse_skill_md returned None for '{skill_path}'.",
                    fix="Reinstall the hive framework — this file may be corrupted.",
                )
                error_count += 1
                continue

            self._skills[skill_name] = parsed

        self._loaded = True
        self._error_count = error_count

    def build_protocols_prompt(self) -> str:
        """Build the combined operational protocols section.
@@ -127,8 +147,23 @@ class DefaultSkillManager:
        """Log which default skills are active and their configuration."""
        if not self._skills:
            logger.info("Default skills: all disabled")
            return

        # DX-3: Per-skill structured startup log
        for skill_name in SKILL_REGISTRY:
            if skill_name in self._skills:
                overrides = self._config.get_default_overrides(skill_name)
                status = f"loaded overrides={overrides}" if overrides else "loaded"
            elif not self._config.is_default_enabled(skill_name):
                status = "disabled"
            else:
                status = "error"
            logger.info(
                "skill_startup name=%s scope=framework status=%s",
                skill_name,
                status,
            )

        # Original active skills log line (preserved for backward compatibility)
        active = []
        for skill_name in SKILL_REGISTRY:
            if skill_name in self._skills:
@@ -138,7 +173,21 @@ class DefaultSkillManager:
            else:
                active.append(skill_name)

        logger.info("Default skills active: %s", ", ".join(active))
        if active:
            logger.info("Default skills active: %s", ", ".join(active))

        # DX-3: Summary line with error count
        total = len(SKILL_REGISTRY)
        active_count = len(self._skills)
        error_count = getattr(self, "_error_count", 0)
        disabled_count = total - active_count - error_count
        logger.info(
            "Skills: %d default (%d active, %d disabled, %d error)",
            total,
            active_count,
            disabled_count,
            error_count,
        )

    @property
    def active_skill_names(self) -> list[str]:

@@ -11,6 +11,7 @@ from dataclasses import dataclass
from pathlib import Path

from framework.skills.parser import ParsedSkill, parse_skill_md
from framework.skills.skill_errors import SkillErrorCode, log_skill_error

logger = logging.getLogger(__name__)

@@ -172,11 +173,13 @@ class SkillDiscovery:
        for skill in skills:
            if skill.name in seen:
                existing = seen[skill.name]
                logger.warning(
                    "Skill name collision: '%s' from %s overrides %s",
                    skill.name,
                    skill.location,
                    existing.location,
                log_skill_error(
                    logger,
                    "warning",
                    SkillErrorCode.SKILL_COLLISION,
                    what=f"Skill name collision: '{skill.name}'",
                    why=f"'{skill.location}' overrides '{existing.location}'.",
                    fix="Rename one of the conflicting skill directories to use a unique name.",
                )
            seen[skill.name] = skill


@@ -42,11 +42,14 @@ class SkillsManagerConfig:
            When ``None``, community discovery is skipped.
        skip_community_discovery: Explicitly skip community scanning
            even when ``project_root`` is set.
        interactive: Whether trust gating can prompt the user interactively.
            When ``False``, untrusted project skills are silently skipped.
    """

    skills_config: SkillsConfig = field(default_factory=SkillsConfig)
    project_root: Path | None = None
    skip_community_discovery: bool = False
    interactive: bool = True


class SkillsManager:
@@ -63,6 +66,7 @@ class SkillsManager:
        self._loaded = False
        self._catalog_prompt: str = ""
        self._protocols_prompt: str = ""
        self._allowlisted_dirs: list[str] = []

    # ------------------------------------------------------------------
    # Factory for backwards-compat bridge
@@ -85,6 +89,7 @@ class SkillsManager:
        mgr._loaded = True  # skip load()
        mgr._catalog_prompt = skills_catalog_prompt
        mgr._protocols_prompt = protocols_prompt
        mgr._allowlisted_dirs = []
        return mgr

    # ------------------------------------------------------------------
@@ -113,9 +118,18 @@ class SkillsManager:
        # 1. Community skill discovery (when project_root is available)
        catalog_prompt = ""
        if self._config.project_root is not None and not self._config.skip_community_discovery:
            from framework.skills.trust import TrustGate

            discovery = SkillDiscovery(DiscoveryConfig(project_root=self._config.project_root))
            discovered = discovery.discover()

            # Trust-gate project-scope skills (AS-13)
            discovered = TrustGate(interactive=self._config.interactive).filter_and_gate(
                discovered, project_dir=self._config.project_root
            )

            catalog = SkillCatalog(discovered)
            self._allowlisted_dirs = catalog.allowlisted_dirs
            catalog_prompt = catalog.to_prompt()

            # Pre-activated community skills
@@ -132,6 +146,16 @@ class SkillsManager:
        default_mgr.load()
        default_mgr.log_active_skills()
        protocols_prompt = default_mgr.build_protocols_prompt()
        # DX-3: Community skill startup summary
        if self._config.project_root is not None and not self._config.skip_community_discovery:
            community_count = len(catalog._skills) if catalog_prompt else 0
            pre_activated_count = len(skills_config.skills) if skills_config.skills else 0
            logger.info(
                "Skills: %d community (%d catalog, %d pre-activated)",
                community_count,
                community_count,
                pre_activated_count,
            )

        # 3. Cache
        self._catalog_prompt = catalog_prompt
@@ -160,6 +184,11 @@ class SkillsManager:
        """Default skill operational protocols for system prompt injection."""
        return self._protocols_prompt

    @property
    def allowlisted_dirs(self) -> list[str]:
        """Skill base directories for Tier 3 resource access (AS-6)."""
        return self._allowlisted_dirs

    @property
    def is_loaded(self) -> bool:
        return self._loaded

@@ -0,0 +1,52 @@
"""Data models for the Hive skill system (Agent Skills standard)."""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import StrEnum
from pathlib import Path


class SkillScope(StrEnum):
    """Where a skill was discovered."""

    PROJECT = "project"
    USER = "user"
    FRAMEWORK = "framework"


class TrustStatus(StrEnum):
    """Trust state of a skill entry."""

    TRUSTED = "trusted"
    PENDING_CONSENT = "pending_consent"
    DENIED = "denied"


@dataclass
class SkillEntry:
    """In-memory record for a discovered skill (PRD §4.2)."""

    name: str
    """Skill name from SKILL.md frontmatter."""

    description: str
    """Skill description from SKILL.md frontmatter."""

    location: Path
    """Absolute path to SKILL.md."""

    base_dir: Path
    """Parent directory of SKILL.md (skill root)."""

    source_scope: SkillScope
    """Which scope this skill was found in."""

    trust_status: TrustStatus = TrustStatus.TRUSTED
    """Trust state; project-scope skills start as PENDING_CONSENT before gating."""

    # Optional frontmatter fields
    license: str | None = None
    compatibility: list[str] = field(default_factory=list)
    allowed_tools: list[str] = field(default_factory=list)
    metadata: dict = field(default_factory=dict)
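For orientation, a SkillEntry for a freshly discovered project-scope skill would look like this before the TrustGate runs — the skill name and paths are hypothetical; the field names come from the dataclass above:

from pathlib import Path

entry = SkillEntry(
    name="web-research",
    description="Structured web research workflow",
    location=Path("/repo/.hive/skills/web-research/SKILL.md"),
    base_dir=Path("/repo/.hive/skills/web-research"),
    source_scope=SkillScope.PROJECT,
    trust_status=TrustStatus.PENDING_CONSENT,  # flipped to TRUSTED or DENIED by gating
)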
@@ -13,6 +13,8 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.skills.skill_errors import SkillErrorCode, log_skill_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum name length before a warning is logged
|
||||
@@ -74,17 +76,38 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
logger.error("Failed to read %s: %s", path, exc)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"Failed to read '{path}'",
|
||||
why=str(exc),
|
||||
fix="Check the file exists and has read permissions.",
|
||||
)
|
||||
return None
|
||||
|
||||
if not content.strip():
|
||||
logger.error("Empty SKILL.md: %s", path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_PARSE_ERROR,
|
||||
what=f"Invalid SKILL.md at '{path}'",
|
||||
why="The file exists but contains no content.",
|
||||
fix="Add valid YAML frontmatter and a markdown body to the SKILL.md.",
|
||||
)
|
||||
return None
|
||||
|
||||
# Split on --- delimiters (first two occurrences)
|
||||
parts = content.split("---", 2)
|
||||
if len(parts) < 3:
|
||||
logger.error("SKILL.md missing YAML frontmatter delimiters (---): %s", path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_PARSE_ERROR,
|
||||
what=f"Invalid SKILL.md at '{path}'",
|
||||
why="Missing YAML frontmatter (---).",
|
||||
fix="Wrap the frontmatter with --- on its own line at the top and bottom.",
|
||||
)
|
||||
return None
|
||||
|
||||
# parts[0] is content before first --- (should be empty or whitespace)
|
||||
@@ -94,7 +117,14 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
    body = parts[2].strip()

    if not raw_yaml:
        logger.error("Empty YAML frontmatter in %s", path)
        log_skill_error(
            logger,
            "error",
            SkillErrorCode.SKILL_PARSE_ERROR,
            what=f"Invalid SKILL.md at '{path}'",
            why="The --- delimiters are present but the YAML block is empty.",
            fix="Add at least 'name' and 'description' fields to the frontmatter.",
        )
        return None

    # Parse YAML
@@ -108,19 +138,47 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
    try:
        fixed = _try_fix_yaml(raw_yaml)
        frontmatter = yaml.safe_load(fixed)
        logger.warning("Fixed YAML parse issues in %s (unquoted colons)", path)
        log_skill_error(
            logger,
            "warning",
            SkillErrorCode.SKILL_YAML_FIXUP,
            what=f"Auto-fixed YAML in '{path}'",
            why="Unquoted colon values detected in frontmatter.",
            fix='Wrap values containing colons in quotes, e.g. description: "Use for: research"',
        )
    except yaml.YAMLError as exc:
        logger.error("Unparseable YAML in %s: %s", path, exc)
        log_skill_error(
            logger,
            "error",
            SkillErrorCode.SKILL_PARSE_ERROR,
            what=f"Invalid SKILL.md at '{path}'",
            why=str(exc),
            fix="Validate the YAML frontmatter at https://yaml-online-parser.appspot.com/",
        )
        return None

    if not isinstance(frontmatter, dict):
        logger.error("YAML frontmatter is not a mapping in %s", path)
        log_skill_error(
            logger,
            "error",
            SkillErrorCode.SKILL_PARSE_ERROR,
            what=f"Invalid SKILL.md at '{path}'",
            why="YAML frontmatter is not a key-value mapping.",
            fix="Ensure the frontmatter is valid YAML with key: value pairs.",
        )
        return None

    # Required: description
    description = frontmatter.get("description")
    if not description or not str(description).strip():
        logger.error("Missing or empty 'description' in %s — skipping skill", path)
        log_skill_error(
            logger,
            "error",
            SkillErrorCode.SKILL_MISSING_DESCRIPTION,
            what=f"Missing 'description' in '{path}'",
            why="The 'description' field is required but is absent or empty.",
            fix="Add a non-empty 'description' field to the YAML frontmatter.",
        )
        return None

    # Required: name (fallback to parent directory name)
@@ -128,7 +186,14 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
    parent_dir_name = path.parent.name
    if not name or not str(name).strip():
        name = parent_dir_name
        logger.warning("Missing 'name' in %s — using directory name '%s'", path, name)
        log_skill_error(
            logger,
            "warning",
            SkillErrorCode.SKILL_NAME_MISMATCH,
            what=f"Missing 'name' in '{path}' — using directory name '{name}'",
            why="The 'name' field is absent from the YAML frontmatter.",
            fix=f"Add 'name: {name}' to the frontmatter to make this explicit.",
        )
    else:
        name = str(name).strip()

@@ -137,11 +202,13 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
        logger.warning("Skill name exceeds %d chars in %s: '%s'", _MAX_NAME_LENGTH, path, name)

    if name != parent_dir_name and not name.endswith(f".{parent_dir_name}"):
        logger.warning(
            "Skill name '%s' doesn't match parent directory '%s' in %s",
            name,
            parent_dir_name,
            path,
        log_skill_error(
            logger,
            "warning",
            SkillErrorCode.SKILL_NAME_MISMATCH,
            what=f"Name mismatch in '{path}'",
            why=f"Skill name '{name}' doesn't match directory '{parent_dir_name}'.",
            fix=f"Rename the directory to '{name}' or set name to '{parent_dir_name}'.",
        )

    return ParsedSkill(

@@ -0,0 +1,70 @@
"""Structured error codes and diagnostics for the Hive skill system.

Implements DX-1 (structured error codes) and DX-2 (what/why/fix format)
from the skill system PRD §7.5.
"""

from __future__ import annotations

import logging
from enum import Enum


class SkillErrorCode(Enum):
    """Standardized error codes for skill system operations (DX-1)."""

    SKILL_NOT_FOUND = "SKILL_NOT_FOUND"
    SKILL_PARSE_ERROR = "SKILL_PARSE_ERROR"
    SKILL_ACTIVATION_FAILED = "SKILL_ACTIVATION_FAILED"
    SKILL_MISSING_DESCRIPTION = "SKILL_MISSING_DESCRIPTION"
    SKILL_YAML_FIXUP = "SKILL_YAML_FIXUP"
    SKILL_NAME_MISMATCH = "SKILL_NAME_MISMATCH"
    SKILL_COLLISION = "SKILL_COLLISION"


class SkillError(Exception):
    """Structured exception for skill system errors (DX-2).

    Raised in strict validation paths. Also used as the base
    format contract for log_skill_error() log messages.
    """

    def __init__(self, code: SkillErrorCode, what: str, why: str, fix: str):
        self.code = code
        self.what = what
        self.why = why
        self.fix = fix
        self.message = (
            f"[{self.code.value}]\nWhat failed: {self.what}\nWhy: {self.why}\nFix: {self.fix}"
        )
        super().__init__(self.message)


def log_skill_error(
    logger: logging.Logger,
    level: str,
    code: SkillErrorCode,
    what: str,
    why: str,
    fix: str,
) -> None:
    """Emit a structured skill diagnostic log with consistent format (DX-2).

    Args:
        logger: The module logger to emit to.
        level: Log level string — 'error', 'warning', or 'info'.
        code: Structured error code.
        what: What failed (specific skill name and path).
        why: Root cause.
        fix: Concrete next step for the developer.
    """
    msg = f"[{code.value}] What failed: {what} | Why: {why} | Fix: {fix}"
    getattr(logger, level)(
        msg,
        extra={
            "skill_error_code": code.value,
            "what": what,
            "why": why,
            "fix": fix,
        },
    )
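
# Usage sketch (assumes only the definitions above; skill name and messages are
# illustrative):
#
#     log_skill_error(
#         logging.getLogger("framework.skills"),
#         "error",
#         SkillErrorCode.SKILL_NOT_FOUND,
#         what="Skill 'research' not found",
#         why="No SKILL.md was discovered for that name.",
#         fix="Create a skill directory containing a SKILL.md with name/description.",
#     )
#     # emits: [SKILL_NOT_FOUND] What failed: Skill 'research' not found | Why: ... | Fix: ...
#
#     raise SkillError(SkillErrorCode.SKILL_PARSE_ERROR, what=..., why=..., fix=...)
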
@@ -0,0 +1,477 @@
"""Trust gating for project-level skills (PRD AS-13).

Project-level skills from untrusted repositories require explicit user consent
before their instructions are loaded into the agent's system prompt.
Framework and user-scope skills are always trusted.

Trusted repos are persisted at ~/.hive/trusted_repos.json.
"""

from __future__ import annotations

import json
import logging
import subprocess
import sys
from collections.abc import Callable
from dataclasses import dataclass
from datetime import UTC, datetime
from enum import StrEnum
from pathlib import Path
from urllib.parse import urlparse

from framework.skills.parser import ParsedSkill

logger = logging.getLogger(__name__)

# Env var to bypass trust gating in CI/headless pipelines (opt-in).
_ENV_TRUST_ALL = "HIVE_TRUST_PROJECT_SKILLS"

# Env var for comma-separated own-remote glob patterns (e.g. "github.com/myorg/*").
_ENV_OWN_REMOTES = "HIVE_OWN_REMOTES"

_TRUSTED_REPOS_PATH = Path.home() / ".hive" / "trusted_repos.json"
_NOTICE_SENTINEL_PATH = Path.home() / ".hive" / ".skill_trust_notice_shown"


# ---------------------------------------------------------------------------
# Trusted repo store
# ---------------------------------------------------------------------------


@dataclass
class TrustedRepoEntry:
    repo_key: str
    added_at: datetime
    project_path: str = ""


class TrustedRepoStore:
    """Persists permanently-trusted repo keys to ~/.hive/trusted_repos.json."""

    def __init__(self, path: Path | None = None) -> None:
        self._path = path or _TRUSTED_REPOS_PATH
        self._entries: dict[str, TrustedRepoEntry] = {}
        self._loaded = False

    def is_trusted(self, repo_key: str) -> bool:
        self._ensure_loaded()
        return repo_key in self._entries

    def trust(self, repo_key: str, project_path: str = "") -> None:
        self._ensure_loaded()
        self._entries[repo_key] = TrustedRepoEntry(
            repo_key=repo_key,
            added_at=datetime.now(tz=UTC),
            project_path=project_path,
        )
        self._save()
        logger.info("skill_trust_store: trusted repo_key=%s", repo_key)

    def revoke(self, repo_key: str) -> bool:
        self._ensure_loaded()
        if repo_key in self._entries:
            del self._entries[repo_key]
            self._save()
            logger.info("skill_trust_store: revoked repo_key=%s", repo_key)
            return True
        return False

    def list_entries(self) -> list[TrustedRepoEntry]:
        self._ensure_loaded()
        return list(self._entries.values())

    def _ensure_loaded(self) -> None:
        if not self._loaded:
            self._load()
            self._loaded = True

    def _load(self) -> None:
        try:
            data = json.loads(self._path.read_text(encoding="utf-8"))
            for raw in data.get("entries", []):
                repo_key = raw.get("repo_key", "")
                if not repo_key:
                    continue
                try:
                    added_at = datetime.fromisoformat(raw["added_at"])
                except (KeyError, ValueError):
                    added_at = datetime.now(tz=UTC)
                self._entries[repo_key] = TrustedRepoEntry(
                    repo_key=repo_key,
                    added_at=added_at,
                    project_path=raw.get("project_path", ""),
                )
        except FileNotFoundError:
            pass
        except Exception as e:
            logger.warning(
                "skill_trust_store: could not read %s (%s); treating as empty",
                self._path,
                e,
            )

    def _save(self) -> None:
        self._path.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "version": 1,
            "entries": [
                {
                    "repo_key": e.repo_key,
                    "added_at": e.added_at.isoformat(),
                    "project_path": e.project_path,
                }
                for e in self._entries.values()
            ],
        }
        # Atomic write: write to .tmp then rename
        tmp = self._path.with_suffix(".tmp")
        tmp.write_text(json.dumps(data, indent=2), encoding="utf-8")
        tmp.replace(self._path)


# ---------------------------------------------------------------------------
# Trust classification
# ---------------------------------------------------------------------------


class ProjectTrustClassification(StrEnum):
    ALWAYS_TRUSTED = "always_trusted"
    TRUSTED_BY_USER = "trusted_by_user"
    UNTRUSTED = "untrusted"


class ProjectTrustDetector:
    """Classifies a project directory as trusted or untrusted.

    Algorithm (PRD §4.1 trust note):
    1. No project_dir → ALWAYS_TRUSTED
    2. No .git directory → ALWAYS_TRUSTED (not a git repo)
    3. No remote 'origin' → ALWAYS_TRUSTED (local-only repo)
    4. Remote URL → repo_key; in TrustedRepoStore → TRUSTED_BY_USER
    5. Localhost remote → ALWAYS_TRUSTED
    6. ~/.hive/own_remotes match → ALWAYS_TRUSTED
    7. HIVE_OWN_REMOTES env match → ALWAYS_TRUSTED
    8. None of the above → UNTRUSTED
    """

    def __init__(self, store: TrustedRepoStore | None = None) -> None:
        self._store = store or TrustedRepoStore()

    def classify(self, project_dir: Path | None) -> tuple[ProjectTrustClassification, str]:
        """Return (classification, repo_key).

        repo_key is an empty string for ALWAYS_TRUSTED cases without a remote.
        """
        if project_dir is None or not project_dir.exists():
            return ProjectTrustClassification.ALWAYS_TRUSTED, ""

        if not (project_dir / ".git").exists():
            return ProjectTrustClassification.ALWAYS_TRUSTED, ""

        remote_url = self._get_remote_origin(project_dir)
        if not remote_url:
            return ProjectTrustClassification.ALWAYS_TRUSTED, ""

        repo_key = _normalize_remote_url(remote_url)

        # Explicitly trusted by user
        if self._store.is_trusted(repo_key):
            return ProjectTrustClassification.TRUSTED_BY_USER, repo_key

        # Localhost remotes are always trusted
        if _is_localhost_remote(remote_url):
            return ProjectTrustClassification.ALWAYS_TRUSTED, repo_key

        # User-configured own-remote patterns
        if self._matches_own_remotes(repo_key):
            return ProjectTrustClassification.ALWAYS_TRUSTED, repo_key

        return ProjectTrustClassification.UNTRUSTED, repo_key

    def _get_remote_origin(self, project_dir: Path) -> str:
        """Run git remote get-url origin. Returns an empty string on any failure."""
        try:
            result = subprocess.run(
                ["git", "-C", str(project_dir), "remote", "get-url", "origin"],
                capture_output=True,
                text=True,
                timeout=3,
            )
            if result.returncode == 0:
                return result.stdout.strip()
        except subprocess.TimeoutExpired:
            logger.warning(
                "skill_trust: git remote lookup timed out for %s; treating as trusted",
                project_dir,
            )
        except (FileNotFoundError, OSError):
            pass  # git not found or other OS error
        return ""

    def _matches_own_remotes(self, repo_key: str) -> bool:
        """Check repo_key against user-configured own-remote glob patterns."""
        import fnmatch
        import os

        patterns: list[str] = []

        # From env var
        raw = os.environ.get(_ENV_OWN_REMOTES, "")
        if raw:
            patterns.extend(p.strip() for p in raw.split(",") if p.strip())

        # From ~/.hive/own_remotes file
        own_remotes_file = Path.home() / ".hive" / "own_remotes"
        if own_remotes_file.is_file():
            try:
                for line in own_remotes_file.read_text(encoding="utf-8").splitlines():
                    line = line.strip()
                    if line and not line.startswith("#"):
                        patterns.append(line)
            except OSError:
                pass

        return any(fnmatch.fnmatch(repo_key, p) for p in patterns)
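
# Usage sketch (assumes the classes above; the path is illustrative):
#
#     store = TrustedRepoStore()
#     detector = ProjectTrustDetector(store)
#     classification, repo_key = detector.classify(Path("/work/some-clone"))
#     if classification is ProjectTrustClassification.UNTRUSTED:
#         store.trust(repo_key, project_path="/work/some-clone")  # persist consent
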

# ---------------------------------------------------------------------------
# URL helpers (public so CLI can reuse)
# ---------------------------------------------------------------------------


def _normalize_remote_url(url: str) -> str:
    """Normalize a git remote URL to a canonical ``host/org/repo`` key.

    Examples:
        git@github.com:org/repo.git → github.com/org/repo
        https://github.com/org/repo → github.com/org/repo
        ssh://git@github.com/org/repo.git → github.com/org/repo
    """
    url = url.strip()

    # SCP-style SSH: git@github.com:org/repo.git
    if url.startswith("git@") and ":" in url and "://" not in url:
        url = url[4:]  # strip git@
        url = url.replace(":", "/", 1)
    elif "://" in url:
        parsed = urlparse(url)
        host = parsed.hostname or ""
        path = parsed.path.lstrip("/")
        url = f"{host}/{path}"

    # Strip .git suffix
    if url.endswith(".git"):
        url = url[:-4]

    return url.lower().strip("/")


def _is_localhost_remote(remote_url: str) -> bool:
    """Return True if the remote points to a local host."""
    local_hosts = {"localhost", "127.0.0.1", "::1"}
    try:
        if "://" in remote_url:
            parsed = urlparse(remote_url)
            return (parsed.hostname or "").lower() in local_hosts
        # SCP-style: git@localhost:org/repo
        if "@" in remote_url:
            host_part = remote_url.split("@", 1)[1].split(":")[0]
            return host_part.lower() in local_hosts
    except Exception:
        pass
    return False
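
# Quick checks for the helpers above (values follow the docstring examples):
#
#     assert _normalize_remote_url("git@github.com:Org/Repo.git") == "github.com/org/repo"
#     assert _normalize_remote_url("https://github.com/org/repo") == "github.com/org/repo"
#     assert _is_localhost_remote("git@localhost:org/repo") is True
#     assert _is_localhost_remote("https://github.com/org/repo") is False
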

# ---------------------------------------------------------------------------
# Trust gate
# ---------------------------------------------------------------------------


class TrustGate:
    """Filters the skill list, running a consent flow for untrusted project-scope skills.

    Framework and user-scope skills are always allowed through.
    Project-scope skills from untrusted repos require consent.
    """

    def __init__(
        self,
        store: TrustedRepoStore | None = None,
        detector: ProjectTrustDetector | None = None,
        interactive: bool = True,
        print_fn: Callable[[str], None] | None = None,
        input_fn: Callable[[str], str] | None = None,
    ) -> None:
        self._store = store or TrustedRepoStore()
        self._detector = detector or ProjectTrustDetector(self._store)
        self._interactive = interactive
        self._print = print_fn or print
        self._input = input_fn or input

    def filter_and_gate(
        self,
        skills: list[ParsedSkill],
        project_dir: Path | None,
    ) -> list[ParsedSkill]:
        """Return the subset of skills that are trusted for loading.

        - Framework and user-scope skills: always included.
        - Project-scope skills: classified; consent prompt shown if untrusted.
        """
        import os

        # Separate project skills from always-trusted scopes
        always_trusted = [s for s in skills if s.source_scope != "project"]
        project_skills = [s for s in skills if s.source_scope == "project"]

        if not project_skills:
            return always_trusted

        # Env-var CI override: trust all project skills for this invocation
        if os.environ.get(_ENV_TRUST_ALL, "").strip() == "1":
            logger.info(
                "skill_trust: %s=1 set; trusting %d project skill(s) without consent",
                _ENV_TRUST_ALL,
                len(project_skills),
            )
            return always_trusted + project_skills

        classification, repo_key = self._detector.classify(project_dir)

        if classification in (
            ProjectTrustClassification.ALWAYS_TRUSTED,
            ProjectTrustClassification.TRUSTED_BY_USER,
        ):
            logger.info(
                "skill_trust: project skills trusted classification=%s repo=%s count=%d",
                classification,
                repo_key or "(no remote)",
                len(project_skills),
            )
            return always_trusted + project_skills

        # UNTRUSTED — need consent
        if not self._interactive or not sys.stdin.isatty():
            logger.warning(
                "skill_trust: skipping %d project-scope skill(s) from untrusted repo "
                "'%s' (non-interactive mode). "
                "To trust permanently run: hive skill trust %s",
                len(project_skills),
                repo_key,
                project_dir or ".",
            )
            logger.info(
                "skill_trust_decision repo=%s skills=%d decision=denied mode=headless",
                repo_key,
                len(project_skills),
            )
            return always_trusted

        # Interactive consent flow
        decision = self._run_consent_flow(project_skills, project_dir, repo_key)

        logger.info(
            "skill_trust_decision repo=%s skills=%d decision=%s mode=interactive",
            repo_key,
            len(project_skills),
            decision,
        )

        if decision == "session":
            return always_trusted + project_skills

        if decision == "permanent":
            self._store.trust(repo_key, project_path=str(project_dir or ""))
            return always_trusted + project_skills

        # denied
        return always_trusted

    def _run_consent_flow(
        self,
        project_skills: list[ParsedSkill],
        project_dir: Path | None,
        repo_key: str,
    ) -> str:
        """Show the security notice (once) and consent prompt.

        Return 'session' | 'permanent' | 'denied'.
        """
        from framework.credentials.setup import Colors

        if not sys.stdout.isatty():
            Colors.disable()

        self._maybe_show_security_notice(Colors)
        self._print_consent_prompt(project_skills, project_dir, repo_key, Colors)
        return self._prompt_consent(Colors)

    def _maybe_show_security_notice(self, Colors) -> None:  # noqa: N803
        """Show the one-time security notice if not already shown (NFR-5)."""
        if _NOTICE_SENTINEL_PATH.exists():
            return
        self._print("")
        self._print(
            f"{Colors.YELLOW}Security notice:{Colors.NC} Skills inject instructions "
            "into the agent's system prompt."
        )
        self._print(
            " Only load skills from sources you trust. "
            "Registry skills at tier 'verified' or 'official' have been audited."
        )
        self._print("")
        try:
            _NOTICE_SENTINEL_PATH.parent.mkdir(parents=True, exist_ok=True)
            _NOTICE_SENTINEL_PATH.touch()
        except OSError:
            pass

    def _print_consent_prompt(
        self,
        project_skills: list[ParsedSkill],
        project_dir: Path | None,
        repo_key: str,
        Colors,  # noqa: N803
    ) -> None:
        p = self._print
        p("")
        p(f"{Colors.YELLOW}{'=' * 60}{Colors.NC}")
        p(f"{Colors.BOLD} SKILL TRUST REQUIRED{Colors.NC}")
        p(f"{Colors.YELLOW}{'=' * 60}{Colors.NC}")
        p("")
        proj_label = str(project_dir) if project_dir else "this project"
        p(
            f" The project at {Colors.CYAN}{proj_label}{Colors.NC} wants to load "
            f"{len(project_skills)} skill(s)"
        )
        p(" that will inject instructions into the agent's system prompt.")
        if repo_key:
            p(f" Source: {Colors.BOLD}{repo_key}{Colors.NC}")
        p("")
        p(" Skills requesting access:")
        for skill in project_skills:
            p(f" {Colors.CYAN}•{Colors.NC} {Colors.BOLD}{skill.name}{Colors.NC}")
            p(f' "{skill.description}"')
            p(f" {Colors.DIM}{skill.location}{Colors.NC}")
        p("")
        p(" Options:")
        p(f" {Colors.CYAN}1){Colors.NC} Trust this session only")
        p(f" {Colors.CYAN}2){Colors.NC} Trust permanently — remember for future runs")
        p(
            f" {Colors.DIM}3) Deny"
            f" — skip all project-scope skills from this repo{Colors.NC}"
        )
        p(f"{Colors.YELLOW}{'─' * 60}{Colors.NC}")

    def _prompt_consent(self, Colors) -> str:  # noqa: N803
        """Prompt until a valid choice is entered. Returns 'session'|'permanent'|'denied'."""
        mapping = {"1": "session", "2": "permanent", "3": "denied"}
        while True:
            try:
                choice = self._input("Select option (1-3): ").strip()
                if choice in mapping:
                    return mapping[choice]
            except (KeyboardInterrupt, EOFError):
                return "denied"
            self._print(f"{Colors.RED}Invalid choice. Enter 1, 2, or 3.{Colors.NC}")
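
# Usage sketch (assumes the gate above; `discovered_skills` is illustrative):
#
#     gate = TrustGate(interactive=sys.stdin.isatty())
#     loaded = gate.filter_and_gate(discovered_skills, project_dir=Path.cwd())
#     # Framework/user skills pass straight through; project skills from an
#     # untrusted remote are included only after the user picks option 1 or 2.
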
@@ -324,6 +324,7 @@ export type EventTypeName =
  | "node_retry"
  | "edge_traversed"
  | "context_compacted"
  | "context_usage_updated"
  | "webhook_received"
  | "custom"
  | "escalation_requested"

@@ -1,8 +1,16 @@
import { memo, useState, useRef, useEffect } from "react";
import { memo, useState, useRef, useEffect, useMemo } from "react";
import { Send, Square, Crown, Cpu, Check, Loader2 } from "lucide-react";

export interface ContextUsageEntry {
  usagePct: number;
  messageCount: number;
  estimatedTokens: number;
  maxTokens: number;
}
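// Illustrative entry (hypothetical values): a session at 42% of a 200k window
// would arrive as { usagePct: 42, messageCount: 18, estimatedTokens: 84000, maxTokens: 200000 }.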
import MarkdownContent from "@/components/MarkdownContent";
import QuestionWidget from "@/components/QuestionWidget";
import MultiQuestionWidget from "@/components/MultiQuestionWidget";
import ParallelSubagentBubble, { type SubagentGroup } from "@/components/ParallelSubagentBubble";

export interface ChatMessage {
  id: string;
@@ -18,6 +26,10 @@ export interface ChatMessage {
  createdAt?: number;
  /** Queen phase active when this message was created */
  phase?: "planning" | "building" | "staging" | "running";
  /** Backend node_id that produced this message — used for subagent grouping */
  nodeId?: string;
  /** Backend execution_id for this message */
  executionId?: string;
}

interface ChatPanelProps {
@@ -47,6 +59,8 @@ interface ChatPanelProps {
  onQuestionDismiss?: () => void;
  /** Queen operating phase — shown as a tag on queen messages */
  queenPhase?: "planning" | "building" | "staging" | "running";
  /** Context window usage for queen and workers */
  contextUsage?: Record<string, ContextUsageEntry>;
}

const queenColor = "hsl(45,95%,58%)";
@@ -241,7 +255,7 @@ const MessageBubble = memo(function MessageBubble({ msg, queenPhase }: { msg: Ch
  );
}, (prev, next) => prev.msg.id === next.msg.id && prev.msg.content === next.msg.content && prev.msg.phase === next.msg.phase && prev.queenPhase === next.queenPhase);

export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting, isBusy, activeThread, disabled, onCancel, pendingQuestion, pendingOptions, pendingQuestions, onQuestionSubmit, onMultiQuestionSubmit, onQuestionDismiss, queenPhase }: ChatPanelProps) {
export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting, isBusy, activeThread, disabled, onCancel, pendingQuestion, pendingOptions, pendingQuestions, onQuestionSubmit, onMultiQuestionSubmit, onQuestionDismiss, queenPhase, contextUsage }: ChatPanelProps) {
  const [input, setInput] = useState("");
  const [readMap, setReadMap] = useState<Record<string, number>>({});
  const bottomRef = useRef<HTMLDivElement>(null);
@@ -260,6 +274,84 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
    return true;
  });

  // Group subagent messages into parallel bubbles.
  // A subagent message has a nodeId containing ":subagent:".
  // The run only ends on hard boundaries (user messages, run_dividers),
  // so interleaved queen/tool/system messages don't fragment the bubble.
  type RenderItem =
    | { kind: "message"; msg: ChatMessage }
    | { kind: "parallel"; groupId: string; groups: SubagentGroup[] };

  const renderItems = useMemo<RenderItem[]>(() => {
    const items: RenderItem[] = [];
    let i = 0;
    while (i < threadMessages.length) {
      const msg = threadMessages[i];
      const isSubagent = msg.nodeId?.includes(":subagent:");
      if (!isSubagent) {
        items.push({ kind: "message", msg });
        i++;
        continue;
      }

      // Start a subagent run. Collect all subagent messages, allowing
      // non-subagent messages in between (they render as normal items
      // before the bubble). Only break on hard boundaries.
      const subagentMsgs: ChatMessage[] = [];
      const interleaved: { idx: number; msg: ChatMessage }[] = [];
      const firstId = msg.id;

      while (i < threadMessages.length) {
        const m = threadMessages[i];
        const isSa = m.nodeId?.includes(":subagent:");

        if (isSa) {
          subagentMsgs.push(m);
          i++;
          continue;
        }

        // Hard boundary — stop the run
        if (m.type === "user" || m.type === "run_divider") break;

        // Worker message from a non-subagent node means the graph has
        // moved on to the next stage. Close the bubble even if some
        // subagents are still streaming in the background.
        if (m.role === "worker" && m.nodeId && !m.nodeId.includes(":subagent:")) break;

        // Soft interruption (queen output, system, tool_status without
        // nodeId) — render it normally but keep the subagent run going
        interleaved.push({ idx: items.length + interleaved.length, msg: m });
        i++;
      }

      // Emit interleaved messages first (before the bubble)
      for (const { msg: im } of interleaved) {
        items.push({ kind: "message", msg: im });
      }

      // Build the single parallel bubble from all collected subagent msgs
      if (subagentMsgs.length > 0) {
        const byNode = new Map<string, ChatMessage[]>();
        for (const m of subagentMsgs) {
          const nid = m.nodeId!;
          if (!byNode.has(nid)) byNode.set(nid, []);
          byNode.get(nid)!.push(m);
        }
        const groups: SubagentGroup[] = [];
        for (const [nodeId, msgs] of byNode) {
          groups.push({
            nodeId,
            messages: msgs,
            contextUsage: contextUsage?.[nodeId],
          });
        }
        items.push({ kind: "parallel", groupId: `par-${firstId}`, groups });
      }
    }
    return items;
  }, [threadMessages, contextUsage]);
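
  // Worked example (hypothetical sequence): [user, subA, subB, queen, subA, worker]
  // becomes [user, queen, parallel{A: 2 msgs, B: 1 msg}, worker] — the queen message
  // is hoisted before the bubble and the non-subagent worker message closes the run.
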
  // Mark current thread as read
  useEffect(() => {
    const count = messages.filter((m) => m.thread === activeThread).length;
@@ -305,11 +397,17 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting

      {/* Messages */}
      <div ref={scrollRef} onScroll={handleScroll} className="flex-1 overflow-auto px-5 py-4 space-y-3">
        {threadMessages.map((msg) => (
          <div key={msg.id}>
            <MessageBubble msg={msg} queenPhase={queenPhase} />
          </div>
        ))}
        {renderItems.map((item) =>
          item.kind === "parallel" ? (
            <div key={item.groupId}>
              <ParallelSubagentBubble groupId={item.groupId} groups={item.groups} />
            </div>
          ) : (
            <div key={item.msg.id}>
              <MessageBubble msg={item.msg} queenPhase={queenPhase} />
            </div>
          )
        )}

        {/* Show typing indicator while waiting for first queen response (disabled + empty chat) */}
        {(isWaiting || (disabled && threadMessages.length === 0)) && (
@@ -356,6 +454,57 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
        <div ref={bottomRef} />
      </div>

      {/* Context window usage bar — sits between messages and input */}
      {(() => {
        if (!contextUsage) return null;
        const queenUsage = contextUsage["__queen__"];
        const workerEntries = Object.entries(contextUsage).filter(([k]) => k !== "__queen__");
        const workerUsage = workerEntries.length > 0
          ? workerEntries.reduce((best, [, v]) => (v.usagePct > best.usagePct ? v : best), workerEntries[0][1])
          : undefined;
        if (!queenUsage && !workerUsage) return null;
        return (
          <div className="flex items-center gap-3 mx-4 px-3 py-1 rounded-lg bg-muted/30 border border-border/20 group/ctx flex-shrink-0">
            {queenUsage && (
              <div className="flex items-center gap-2 flex-1 min-w-0" title={`Queen: ${(queenUsage.estimatedTokens / 1000).toFixed(1)}k / ${(queenUsage.maxTokens / 1000).toFixed(0)}k tokens \u00b7 ${queenUsage.messageCount} messages`}>
                <Crown className="w-3 h-3 flex-shrink-0" style={{ color: "hsl(45,95%,58%)" }} />
                <div className="flex-1 h-1.5 rounded-full bg-muted/50 overflow-hidden min-w-[60px]">
                  <div
                    className="h-full rounded-full transition-all duration-500 ease-out"
                    style={{
                      width: `${Math.min(queenUsage.usagePct, 100)}%`,
                      backgroundColor: queenUsage.usagePct >= 90 ? "hsl(0,65%,55%)" : queenUsage.usagePct >= 70 ? "hsl(35,90%,55%)" : "hsl(45,95%,58%)",
                    }}
                  />
                </div>
                <span className="text-[10px] text-muted-foreground/70 flex-shrink-0 tabular-nums">
                  <span className="group-hover/ctx:hidden">{queenUsage.usagePct}%</span>
                  <span className="hidden group-hover/ctx:inline">{(queenUsage.estimatedTokens / 1000).toFixed(1)}k / {(queenUsage.maxTokens / 1000).toFixed(0)}k</span>
                </span>
              </div>
            )}
            {workerUsage && (
              <div className="flex items-center gap-2 flex-1 min-w-0" title={`Worker: ${(workerUsage.estimatedTokens / 1000).toFixed(1)}k / ${(workerUsage.maxTokens / 1000).toFixed(0)}k tokens \u00b7 ${workerUsage.messageCount} messages`}>
                <Cpu className="w-3 h-3 flex-shrink-0" style={{ color: "hsl(220,60%,55%)" }} />
                <div className="flex-1 h-1.5 rounded-full bg-muted/50 overflow-hidden min-w-[60px]">
                  <div
                    className="h-full rounded-full transition-all duration-500 ease-out"
                    style={{
                      width: `${Math.min(workerUsage.usagePct, 100)}%`,
                      backgroundColor: workerUsage.usagePct >= 90 ? "hsl(0,65%,55%)" : workerUsage.usagePct >= 70 ? "hsl(35,90%,55%)" : "hsl(220,60%,55%)",
                    }}
                  />
                </div>
                <span className="text-[10px] text-muted-foreground/70 flex-shrink-0 tabular-nums">
                  <span className="group-hover/ctx:hidden">{workerUsage.usagePct}%</span>
                  <span className="hidden group-hover/ctx:inline">{(workerUsage.estimatedTokens / 1000).toFixed(1)}k / {(workerUsage.maxTokens / 1000).toFixed(0)}k</span>
                </span>
              </div>
            )}
          </div>
        );
      })()}

      {/* Input area — question widget replaces textarea when a question is pending */}
      {pendingQuestions && pendingQuestions.length >= 2 && onMultiQuestionSubmit ? (
        <MultiQuestionWidget

@@ -28,6 +28,13 @@ export interface SubagentReport {
  status?: "running" | "complete" | "error";
}

interface ContextUsage {
  usagePct: number;
  messageCount: number;
  estimatedTokens: number;
  maxTokens: number;
}

interface NodeDetailPanelProps {
  node: GraphNode | null;
  nodeSpec?: NodeSpec | null;
@@ -38,6 +45,7 @@ interface NodeDetailPanelProps {
  workerSessionId?: string | null;
  nodeLogs?: string[];
  actionPlan?: string;
  contextUsage?: ContextUsage;
  onClose: () => void;
}

@@ -309,7 +317,7 @@ const tabs: { id: Tab; label: string; Icon: React.FC<{ className?: string }> }[]
  { id: "subagents", label: "Subagents", Icon: ({ className }) => <Bot className={className} /> },
];

export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagentReports, sessionId, graphId, workerSessionId, nodeLogs, actionPlan, onClose }: NodeDetailPanelProps) {
export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagentReports, sessionId, graphId, workerSessionId, nodeLogs, actionPlan, contextUsage, onClose }: NodeDetailPanelProps) {
  const [activeTab, setActiveTab] = useState<Tab>("overview");
  const [realTools, setRealTools] = useState<ToolInfo[] | null>(null);
  const [realCriteria, setRealCriteria] = useState<NodeCriteria | null>(null);
@@ -389,6 +397,43 @@ export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagent
        </div>
      )}

      {/* Context window usage */}
      {contextUsage && (
        <div className="px-4 py-2 border-b border-border/20 flex-shrink-0">
          <div className="flex items-center gap-2 mb-1">
            <span className="text-[10px] text-muted-foreground font-medium">Context</span>
            <span className="text-[10px] text-muted-foreground/70 ml-auto">
              {(contextUsage.estimatedTokens / 1000).toFixed(1)}k / {(contextUsage.maxTokens / 1000).toFixed(0)}k tokens
            </span>
          </div>
          <div className="w-full h-1.5 rounded-full bg-muted/50 overflow-hidden">
            <div
              className="h-full rounded-full transition-all duration-500 ease-out"
              style={{
                width: `${Math.min(contextUsage.usagePct, 100)}%`,
                backgroundColor: contextUsage.usagePct >= 90
                  ? "hsl(0,65%,55%)"
                  : contextUsage.usagePct >= 70
                    ? "hsl(35,90%,55%)"
                    : "hsl(45,95%,58%)",
              }}
            />
          </div>
          <div className="flex items-center gap-2 mt-1">
            <span className="text-[10px] text-muted-foreground/60">{contextUsage.messageCount} messages</span>
            <span className="text-[10px] font-medium ml-auto" style={{
              color: contextUsage.usagePct >= 90
                ? "hsl(0,65%,55%)"
                : contextUsage.usagePct >= 70
                  ? "hsl(35,90%,55%)"
                  : "hsl(45,95%,58%)",
            }}>
              {contextUsage.usagePct}%
            </span>
          </div>
        </div>
      )}

      {/* Tab bar */}
      <div className="flex border-b border-border/30 flex-shrink-0 px-2 pt-1 overflow-x-auto scrollbar-hide">
        {tabs.filter(t => t.id !== "subagents" || (nodeSpec?.sub_agents && nodeSpec.sub_agents.length > 0)).map(tab => (

@@ -0,0 +1,413 @@
import { memo, useState, useRef, useEffect } from "react";
import { ChevronDown, ChevronUp, Cpu } from "lucide-react";
import type { ChatMessage, ContextUsageEntry } from "@/components/ChatPanel";
import MarkdownContent from "@/components/MarkdownContent";

// ---------------------------------------------------------------------------
// Shared helpers
// ---------------------------------------------------------------------------

const workerColor = "hsl(220,60%,55%)";

const SUBAGENT_COLORS = [
  "hsl(220,60%,55%)",
  "hsl(260,50%,55%)",
  "hsl(180,50%,45%)",
  "hsl(30,70%,50%)",
  "hsl(340,55%,50%)",
  "hsl(150,45%,45%)",
  "hsl(45,80%,50%)",
  "hsl(290,45%,55%)",
];

function colorForIndex(i: number): string {
  return SUBAGENT_COLORS[i % SUBAGENT_COLORS.length];
}

function subagentLabel(nodeId: string): string {
  const parts = nodeId.split(":subagent:");
  const raw = parts.length >= 2 ? parts[1] : nodeId;
  return raw
    .replace(/:\d+$/, "") // strip instance suffix like ":3"
    .replace(/[_-]/g, " ")
    .replace(/\b\w/g, (c) => c.toUpperCase())
    .trim();
}
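// e.g. subagentLabel("build:subagent:data_fetcher:2") → "Data Fetcher"
// (hypothetical nodeId; real IDs come from the backend's ":subagent:" naming)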

function last<T>(arr: T[]): T | undefined {
  return arr[arr.length - 1];
}

export interface SubagentGroup {
  nodeId: string;
  messages: ChatMessage[];
  contextUsage?: ContextUsageEntry;
}

interface ParallelSubagentBubbleProps {
  groups: SubagentGroup[];
  groupId: string;
}

// ---------------------------------------------------------------------------
// Thermometer — vertical context gauge on right edge of each pane
// ---------------------------------------------------------------------------

// ---------------------------------------------------------------------------
// Tool overlay — shown when a tool_status message is active (not all done)
// ---------------------------------------------------------------------------

function ToolOverlay({
  toolName,
  color,
  visible,
}: {
  toolName: string;
  color: string;
  visible: boolean;
}) {
  return (
    <div
      className="absolute inset-0 top-[22px] flex items-center justify-center transition-opacity duration-200 z-10"
      style={{
        background: "rgba(8,8,14,0.82)",
        opacity: visible ? 1 : 0,
        pointerEvents: visible ? "auto" : "none",
      }}
    >
      <div className="text-center px-3 py-2 rounded-md border" style={{ borderColor: `${color}40` }}>
        <div className="text-[10px] font-medium" style={{ color }}>
          {toolName}
        </div>
        <div className="text-[11px] mt-0.5" style={{ color }}>
          {visible ? "..." : "\u2713"}
        </div>
      </div>
    </div>
  );
}

// ---------------------------------------------------------------------------
// Single tmux pane
// ---------------------------------------------------------------------------

function MuxPane({
  group,
  index,
  label,
  isFocused,
  isZoomed,
  onClickTitle,
}: {
  group: SubagentGroup;
  index: number;
  label: string;
  isFocused: boolean;
  isZoomed: boolean;
  onClickTitle: () => void;
}) {
  const bodyRef = useRef<HTMLDivElement>(null);
  const stickRef = useRef(true);
  const color = colorForIndex(index);
  const pct = group.contextUsage?.usagePct ?? 0;

  const streamMsgs = group.messages.filter((m) => m.type !== "tool_status");
  const latestContent = last(streamMsgs)?.content ?? "";
  const msgCount = streamMsgs.length;

  // Detect active tool and finished state from latest tool_status
  const latestTool = last(
    group.messages.filter((m) => m.type === "tool_status")
  );
  let activeToolName = "";
  let toolRunning = false;
  let isFinished = false;
  if (latestTool) {
    try {
      const parsed = JSON.parse(latestTool.content);
      const tools: { name: string; done: boolean }[] = parsed.tools || [];
      const allDone = parsed.allDone as boolean | undefined;
      const running = tools.find((t) => !t.done);
      if (running) {
        activeToolName = running.name;
        toolRunning = true;
      }
      // Finished when all tools are done and one of them is set_output
      // or report_to_parent (terminal tool calls)
      if (allDone && tools.length > 0) {
        const hasTerminal = tools.some(
          (t) =>
            t.done &&
            (t.name === "set_output" || t.name === "report_to_parent")
        );
        if (hasTerminal) isFinished = true;
      }
    } catch {
      /* ignore */
    }
  }

  // Auto-scroll
  useEffect(() => {
    if (stickRef.current && bodyRef.current) {
      bodyRef.current.scrollTop = bodyRef.current.scrollHeight;
    }
  }, [latestContent]);

  const handleScroll = () => {
    const el = bodyRef.current;
    if (!el) return;
    stickRef.current = el.scrollHeight - el.scrollTop - el.clientHeight < 30;
  };

  return (
    <div
      className="flex flex-col min-h-0 overflow-hidden relative transition-all duration-200"
      style={{
        borderWidth: 1,
        borderStyle: "solid",
        borderColor: isFocused && !isFinished ? `${color}60` : "transparent",
        opacity: isFinished ? 0.4 : isFocused || isZoomed ? 1 : 0.55,
        ...(isZoomed
          ? { gridColumn: "1 / -1", gridRow: "1 / -1", zIndex: 10 }
          : {}),
      }}
    >
      {/* Title bar */}
      <div
        className="flex items-center gap-1.5 px-2 py-[3px] flex-shrink-0 cursor-pointer select-none"
        style={{ background: "#0e0e16", borderBottom: "1px solid #1a1a2a" }}
        onClick={onClickTitle}
      >
        {isFinished ? (
          <span className="text-[8px] flex-shrink-0 leading-none" style={{ color: "#4a4" }}>✓</span>
        ) : (
          <div
            className="w-[6px] h-[6px] rounded-full flex-shrink-0"
            style={{ background: color }}
          />
        )}
        <span className="text-[9px] flex-shrink-0" style={{ color: isFinished ? "#555" : color }}>
          {label}
        </span>
        <span className="flex-1" />
        <span className="text-[8px] tabular-nums flex-shrink-0" style={{ color: "#555" }}>
          {msgCount}
        </span>
        <div
          className="w-[36px] h-[3px] rounded-full overflow-hidden flex-shrink-0"
          style={{ background: "#1a1a2a" }}
        >
          <div
            className="h-full rounded-full transition-all duration-500"
            style={{
              width: `${Math.min(pct, 100)}%`,
              backgroundColor:
                pct >= 80 ? "hsl(0,65%,55%)" : pct >= 50 ? "hsl(35,90%,55%)" : color,
            }}
          />
        </div>
        <span className="text-[8px] tabular-nums flex-shrink-0" style={{ color: "#555" }}>
          {pct}%
        </span>
      </div>

      {/* Body */}
      <div
        ref={bodyRef}
        onScroll={handleScroll}
        className="flex-1 min-h-0 overflow-y-auto px-2 py-1 text-[10px] leading-[1.7]"
        style={{ background: "#08080e", color: "#555", fontFamily: "monospace" }}
      >
        {latestContent ? (
          <div style={{ color: "#ccc" }}>
            <MarkdownContent content={latestContent} />
          </div>
        ) : (
          <span style={{ color: "#333" }}>waiting...</span>
        )}
        {/* Blinking cursor — hidden when finished */}
        {!isFinished && (
          <span
            className="inline-block w-[6px] h-[11px] align-middle ml-0.5"
            style={{
              background: color,
              animation: "cursorBlink 1s step-end infinite",
            }}
          />
        )}
      </div>

      {/* Tool overlay */}
      <ToolOverlay
        toolName={activeToolName}
        color={color}
        visible={toolRunning}
      />
    </div>
  );
}

// ---------------------------------------------------------------------------
// Main component
// ---------------------------------------------------------------------------

const ParallelSubagentBubble = memo(
  function ParallelSubagentBubble({ groups }: ParallelSubagentBubbleProps) {
    const [expanded, setExpanded] = useState(false);
    const [zoomedIdx, setZoomedIdx] = useState<number | null>(null);

    // Labels with instance numbers for duplicates
    const labels: string[] = (() => {
      const countByBase = new Map<string, number>();
      const bases = groups.map((g) => subagentLabel(g.nodeId));
      for (const b of bases)
        countByBase.set(b, (countByBase.get(b) ?? 0) + 1);
      const idxByBase = new Map<string, number>();
      return bases.map((b) => {
        if ((countByBase.get(b) ?? 1) <= 1) return b;
        const idx = (idxByBase.get(b) ?? 0) + 1;
        idxByBase.set(b, idx);
        return `${b} #${idx}`;
      });
    })();
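    // e.g. bases ["Researcher", "Researcher", "Writer"] label as
    // ["Researcher #1", "Researcher #2", "Writer"] (hypothetical names)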

    // Latest-active pane
    const latestIdx = groups.reduce<number>((best, g, i) => {
      const filtered = g.messages.filter((m) => m.type !== "tool_status");
      const lm = last(filtered);
      if (!lm) return best;
      if (best < 0) return i;
      const bm = last(
        groups[best].messages.filter((m) => m.type !== "tool_status")
      );
      if (!bm) return i;
      return (lm.createdAt ?? 0) >= (bm.createdAt ?? 0) ? i : best;
    }, -1);

    // Per-group finished detection (same logic as MuxPane)
    const finishedFlags = groups.map((g) => {
      const lt = last(g.messages.filter((m) => m.type === "tool_status"));
      if (!lt) return false;
      try {
        const p = JSON.parse(lt.content);
        const tools: { name: string; done: boolean }[] = p.tools || [];
        if (!p.allDone || tools.length === 0) return false;
        return tools.some(
          (t) => t.done && (t.name === "set_output" || t.name === "report_to_parent")
        );
      } catch { return false; }
    });
    const activeCount = finishedFlags.filter((f) => !f).length;

    if (groups.length === 0) return null;

    // Grid sizing: 2 columns, auto rows capped at a fixed height
    const rows = Math.ceil(groups.length / 2);
    const gridHeight = expanded
      ? Math.min(rows * 200, 480)
      : Math.min(rows * 100, 240);

    return (
      <div className="flex gap-3">
        {/* Left icon */}
        <div
          className="flex-shrink-0 w-7 h-7 rounded-xl flex items-center justify-center mt-1"
          style={{
            backgroundColor: `${workerColor}18`,
            border: `1.5px solid ${workerColor}35`,
          }}
        >
          <Cpu className="w-3.5 h-3.5" style={{ color: workerColor }} />
        </div>

        <div className="flex-1 min-w-0 max-w-[90%]">
          {/* Header */}
          <div className="flex items-center gap-2 mb-1">
            <span className="font-medium text-xs" style={{ color: workerColor }}>
              {groups.length === 1 ? "Sub-agent" : "Parallel Agents"}
            </span>
            <span className="text-[10px] font-medium px-1.5 py-0.5 rounded-md bg-muted text-muted-foreground">
              {activeCount > 0 ? `${activeCount} running` : `${groups.length} done`}
            </span>
            <button
              onClick={() => {
                setExpanded((v) => !v);
                setZoomedIdx(null);
              }}
              className="ml-auto text-muted-foreground/60 hover:text-muted-foreground transition-colors p-0.5 rounded"
              title={expanded ? "Collapse" : "Expand"}
            >
              {expanded ? (
                <ChevronUp className="w-3.5 h-3.5" />
              ) : (
                <ChevronDown className="w-3.5 h-3.5" />
              )}
            </button>
          </div>

          {/* Mux frame */}
          <div
            className="rounded-lg overflow-hidden"
            style={{
              border: "2px solid #1a1a2a",
              background: "#08080e",
            }}
          >
            {/* Grid */}
            <div
              className="grid gap-px"
              style={{
                gridTemplateColumns:
                  groups.length === 1 ? "1fr" : "1fr 1fr",
                gridTemplateRows: `repeat(${rows}, 1fr)`,
                height: gridHeight,
                background: "#111",
              }}
            >
              {groups.map((group, i) => (
                <MuxPane
                  key={group.nodeId}
                  group={group}
                  index={i}
                  label={labels[i]}
                  isFocused={latestIdx === i}
                  isZoomed={zoomedIdx === i}
                  onClickTitle={() =>
                    setZoomedIdx(zoomedIdx === i ? null : i)
                  }
                />
              ))}
            </div>
          </div>
        </div>
      </div>
    );
  },
  (prev, next) =>
    prev.groupId === next.groupId &&
    prev.groups.length === next.groups.length &&
    prev.groups.every(
      (g, i) =>
        g.nodeId === next.groups[i].nodeId &&
        g.messages.length === next.groups[i].messages.length &&
        last(g.messages)?.content === last(next.groups[i].messages)?.content &&
        g.contextUsage?.usagePct === next.groups[i].contextUsage?.usagePct
    )
);

export default ParallelSubagentBubble;

// Injected as a global style (keyframes can't be inline)
if (typeof document !== "undefined") {
  const id = "parallel-subagent-keyframes";
  if (!document.getElementById(id)) {
    const style = document.createElement("style");
    style.id = id;
    style.textContent = `
@keyframes cursorBlink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } }
@keyframes thermoPulse { 0%, 100% { opacity: 1; } 50% { opacity: 0.4; } }
`;
    document.head.appendChild(style);
  }
}
@@ -72,6 +72,8 @@ export function sseEventToChatMessage(
    role: "worker",
    thread,
    createdAt,
    nodeId: event.node_id || undefined,
    executionId: event.execution_id || undefined,
  };
}

@@ -110,6 +112,8 @@ export function sseEventToChatMessage(
    role: "worker",
    thread,
    createdAt,
    nodeId: event.node_id || undefined,
    executionId: event.execution_id || undefined,
  };
}

@@ -352,6 +352,8 @@ interface AgentBackendState {
  pendingQuestions: { id: string; prompt: string; options?: string[] }[] | null;
  /** Whether the pending question came from queen or worker */
  pendingQuestionSource: "queen" | "worker" | null;
  /** Per-node context window usage (from context_usage_updated events) */
  contextUsage: Record<string, { usagePct: number; messageCount: number; estimatedTokens: number; maxTokens: number }>;
}

function defaultAgentState(): AgentBackendState {
@@ -389,6 +391,7 @@ function defaultAgentState(): AgentBackendState {
    pendingOptions: null,
    pendingQuestions: null,
    pendingQuestionSource: null,
    contextUsage: {},
  };
}

@@ -630,6 +633,10 @@ export default function Workspace() {
  // it was created in (avoids stale-closure when phase change and message
  // events arrive in the same React batch).
  const queenPhaseRef = useRef<Record<string, string>>({});
  // Accumulated queen text across inner_turns within the same iteration.
  // Key: `${agentType}:${execution_id}:${iteration}`, value: { [inner_turn]: snapshot }.
  // This lets us merge all inner_turn text into one chat bubble per iteration.
  const queenIterTextRef = useRef<Record<string, Record<number, string>>>({});
  // Timestamp when designingDraft was set — used to enforce minimum spinner duration.
  const designingDraftSinceRef = useRef<Record<string, number>>({});
  const designingDraftTimerRef = useRef<Record<string, ReturnType<typeof setTimeout>>>({});
@@ -1707,14 +1714,29 @@ export default function Workspace() {
        if (isQueen) console.log('[QUEEN] chatMsg:', chatMsg?.id, chatMsg?.content?.slice(0, 50), 'turn:', currentTurn);
        if (chatMsg && !suppressQueenMessages) {
          // Queen emits multiple client_output_delta / llm_text_delta snapshots
          // across iterations and inner tool-loop turns. Build a stable ID that
          // groups streaming deltas for the *same* output (same execution +
          // iteration + inner_turn) into one bubble, while keeping distinct
          // outputs as separate bubbles so earlier text isn't overwritten.
          // across iterations and inner tool-loop turns. Merge all inner_turns
          // within the same iteration into ONE bubble so the queen's multi-step
          // tool loop (text → tool → text → tool → text) appears as one cohesive
          // message rather than many small fragments.
          if (isQueen && (event.type === "client_output_delta" || event.type === "llm_text_delta") && event.execution_id) {
            const iter = event.data?.iteration ?? 0;
            const inner = event.data?.inner_turn ?? 0;
            chatMsg.id = `queen-stream-${event.execution_id}-${iter}-${inner}`;
            const inner = (event.data?.inner_turn as number) ?? 0;
            const iterKey = `${agentType}:${event.execution_id}:${iter}`;

            // Store the latest snapshot for this inner_turn
            if (!queenIterTextRef.current[iterKey]) {
              queenIterTextRef.current[iterKey] = {};
            }
            const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
            queenIterTextRef.current[iterKey][inner] = snapshot;

            // Concatenate all inner_turn snapshots in order
            const parts = queenIterTextRef.current[iterKey];
            const sortedInners = Object.keys(parts).map(Number).sort((a, b) => a - b);
            chatMsg.content = sortedInners.map(k => parts[k]).join("\n");

            // Single ID per iteration — no inner_turn in the ID
            chatMsg.id = `queen-stream-${event.execution_id}-${iter}`;
          }
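          // Worked example (hypothetical events): inner_turn 0 snapshot "Plan:" and
          // inner_turn 2 snapshot "Done." in iteration 1 merge into one bubble with
          // id `queen-stream-<execution_id>-1` and content "Plan:\nDone.".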
          if (isQueen) {
            chatMsg.role = role;
@@ -1989,6 +2011,8 @@ export default function Workspace() {
            role,
            thread: agentType,
            createdAt: eventCreatedAt,
            nodeId: event.node_id || undefined,
            executionId: event.execution_id || undefined,
          });
          return {
            ...prev,
@@ -2060,6 +2084,8 @@ export default function Workspace() {
            role,
            thread: agentType,
            createdAt: eventCreatedAt,
            nodeId: event.node_id || undefined,
            executionId: event.execution_id || undefined,
          });
          return {
            ...prev,
@@ -2136,6 +2162,29 @@ export default function Workspace() {
          }
          break;

case "context_usage_updated": {
|
||||
const streamKey = isQueen ? "__queen__" : (event.node_id || streamId);
|
||||
const usagePct = (event.data?.usage_pct as number) ?? 0;
|
||||
const messageCount = (event.data?.message_count as number) ?? 0;
|
||||
const estimatedTokens = (event.data?.estimated_tokens as number) ?? 0;
|
||||
const maxTokens = (event.data?.max_context_tokens as number) ?? 0;
|
||||
setAgentStates(prev => {
|
||||
const state = prev[agentType];
|
||||
if (!state) return prev;
|
||||
return {
|
||||
...prev,
|
||||
[agentType]: {
|
||||
...state,
|
||||
contextUsage: {
|
||||
...state.contextUsage,
|
||||
[streamKey]: { usagePct, messageCount, estimatedTokens, maxTokens },
|
||||
},
|
||||
},
|
||||
};
|
||||
});
|
||||
}
|
||||
break;
|
||||
|
||||
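        // Payload sketch (field names as read above; values hypothetical):
        // { type: "context_usage_updated", node_id: "build", data: { usage_pct: 42,
        //   message_count: 18, estimated_tokens: 84000, max_context_tokens: 200000 } }
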
case "node_action_plan":
|
||||
if (!isQueen && event.node_id) {
|
||||
const plan = (event.data?.plan as string) || "";
|
||||
@@ -3174,6 +3223,7 @@ export default function Workspace() {
|
||||
}
|
||||
onMultiQuestionSubmit={handleMultiQuestionAnswer}
|
||||
onQuestionDismiss={handleQuestionDismiss}
|
||||
contextUsage={activeAgentState?.contextUsage}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
@@ -3377,6 +3427,7 @@ export default function Workspace() {
|
||||
workerSessionId={null}
|
||||
nodeLogs={activeAgentState?.nodeLogs[resolvedSelectedNode.id] || []}
|
||||
actionPlan={activeAgentState?.nodeActionPlans[resolvedSelectedNode.id]}
|
||||
contextUsage={activeAgentState?.contextUsage[resolvedSelectedNode.id]}
|
||||
onClose={() => setSelectedNode(null)}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -62,6 +62,10 @@ lint.isort.section-order = [
    "first-party",
    "local-folder",
]
[tool.pytest.ini_options]
filterwarnings = [
    "ignore::DeprecationWarning:litellm.*"
]

[dependency-groups]
dev = [

@@ -0,0 +1,172 @@
"""Integration test: Run a real EventLoopNode against the Antigravity backend.

Run: .venv/bin/python core/tests/test_antigravity_eventloop.py

Requires:
- ~/.hive/antigravity-accounts.json with valid credentials
  (run 'uv run python core/antigravity_auth.py auth account add' to authenticate)
"""

import asyncio
import logging
import sys
from unittest.mock import MagicMock

sys.path.insert(0, "core")

logging.basicConfig(level=logging.WARNING, format="%(levelname)s %(name)s: %(message)s")
# Show our provider's retry/stream logs
logging.getLogger("framework.llm.litellm").setLevel(logging.DEBUG)

from framework.config import RuntimeConfig  # noqa: E402
from framework.graph.event_loop_node import EventLoopNode, LoopConfig  # noqa: E402
from framework.graph.node import NodeContext, NodeResult, NodeSpec, SharedMemory  # noqa: E402
from framework.llm.litellm import LiteLLMProvider  # noqa: E402


def make_provider() -> LiteLLMProvider:
    cfg = RuntimeConfig()
    if not cfg.api_key:
        print("ERROR: No Antigravity token found.")
        print("  1. Run 'antigravity-auth accounts add' to authenticate.")
        print("  2. Run 'antigravity-auth serve' to start the local proxy.")
        print("  3. Configure Hive: run quickstart.sh and select option 7 (Antigravity).")
        sys.exit(1)
    print(f"Model       : {cfg.model}")
    print(f"Base        : {cfg.api_base}")
    print(f"Antigravity : {'localhost:8069' in (cfg.api_base or '')}")
    return LiteLLMProvider(
        model=cfg.model,
        api_key=cfg.api_key,
        api_base=cfg.api_base,
        **cfg.extra_kwargs,
    )


def make_context(
    llm: LiteLLMProvider,
    *,
    node_id: str = "test",
    system_prompt: str = "You are a helpful assistant.",
    output_keys: list[str] | None = None,
) -> NodeContext:
    if output_keys is None:
        output_keys = ["answer"]

    spec = NodeSpec(
        id=node_id,
        name="Test Node",
        description="Integration test node",
        node_type="event_loop",
        output_keys=output_keys,
        system_prompt=system_prompt,
    )

    runtime = MagicMock()
    runtime.start_run = MagicMock(return_value="run-1")
    runtime.decide = MagicMock(return_value="dec-1")
    runtime.record_outcome = MagicMock()
    runtime.end_run = MagicMock()

    memory = SharedMemory()

    return NodeContext(
        runtime=runtime,
        node_id=node_id,
        node_spec=spec,
        memory=memory,
        input_data={},
        llm=llm,
        available_tools=[],
        max_tokens=4096,
    )


async def run_test(
    name: str, llm: LiteLLMProvider, system: str, output_keys: list[str]
) -> NodeResult:
    print(f"\n{'=' * 60}")
    print(f"TEST: {name}")
    print(f"{'=' * 60}")

    ctx = make_context(llm, system_prompt=system, output_keys=output_keys)
    node = EventLoopNode(config=LoopConfig(max_iterations=3))

    try:
        result = await node.execute(ctx)
        print(f"  Success : {result.success}")
        print(f"  Output  : {result.output}")
        if result.error:
            print(f"  Error   : {result.error}")
        return result
    except Exception as e:
        print(f"  EXCEPTION: {type(e).__name__}: {e}")
        import traceback

        traceback.print_exc()
        return NodeResult(success=False, error=str(e))


async def main():
    llm = make_provider()
    print()

    # Test 1: Simple text output — the node should call set_output to fill "answer"
    r1 = await run_test(
        name="Simple text generation",
        llm=llm,
        system=(
            "You are a helpful assistant. When asked a question, use the "
            "set_output tool to store your answer in the 'answer' key. "
            "Keep answers short (1-2 sentences)."
        ),
        output_keys=["answer"],
    )

    # Test 2: If test 1 failed, try bare stream() to isolate the issue
    if not r1.success:
        print(f"\n{'=' * 60}")
        print("FALLBACK: Testing bare provider.stream() directly")
        print(f"{'=' * 60}")
        try:
            from framework.llm.stream_events import (
                FinishEvent,
                StreamErrorEvent,
                TextDeltaEvent,
                ToolCallEvent,
            )

            text = ""
            events = []
            async for event in llm.stream(
                messages=[{"role": "user", "content": "Say hello in 3 words."}],
            ):
                events.append(type(event).__name__)
                if isinstance(event, TextDeltaEvent):
                    text = event.snapshot
                elif isinstance(event, FinishEvent):
                    print(
                        f"  Finish: stop={event.stop_reason}"
                        f" in={event.input_tokens}"
                        f" out={event.output_tokens}"
                    )
                elif isinstance(event, StreamErrorEvent):
                    print(f"  StreamError: {event.error} (recoverable={event.recoverable})")
                elif isinstance(event, ToolCallEvent):
                    print(f"  ToolCall: {event.tool_name}")
            print(f"  Text   : {text!r}")
            print(f"  Events : {events}")
            print(f"  RESULT : {'OK' if text else 'EMPTY'}")
        except Exception as e:
            print(f"  EXCEPTION: {type(e).__name__}: {e}")
            import traceback

            traceback.print_exc()

    print(f"\n{'=' * 60}")
    print("DONE")
    print(f"{'=' * 60}")


if __name__ == "__main__":
    asyncio.run(main())
@@ -0,0 +1,228 @@
"""Unit tests for MCP client transport and reconnect behavior."""

from types import SimpleNamespace

import httpx
import pytest

from framework.runner import mcp_client as mcp_client_module
from framework.runner.mcp_client import MCPClient, MCPServerConfig, MCPTool


class _FakeResponse:
    def __init__(self, payload=None):
        self._payload = payload or {}

    def raise_for_status(self) -> None:
        """Pretend the request succeeded."""

    def json(self):
        return self._payload


class _FakeHttpClient:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.get_calls: list[str] = []
        self.closed = False

    def get(self, path: str) -> _FakeResponse:
        self.get_calls.append(path)
        return _FakeResponse()

    def close(self) -> None:
        self.closed = True


def test_connect_unix_transport_uses_socket_path(monkeypatch):
    created = {}

    class FakeHTTPTransport:
        def __init__(self, *, uds: str):
            created["uds"] = uds
            self.uds = uds

    def fake_client_factory(**kwargs):
        client = _FakeHttpClient(**kwargs)
        created["client"] = client
        return client

    monkeypatch.setattr(mcp_client_module.httpx, "HTTPTransport", FakeHTTPTransport)
    monkeypatch.setattr(mcp_client_module.httpx, "Client", fake_client_factory)
    monkeypatch.setattr(MCPClient, "_discover_tools", lambda self: None)

    client = MCPClient(
        MCPServerConfig(
            name="unix-server",
            transport="unix",
            url="http://localhost",
            socket_path="/tmp/test.sock",
        )
    )

    client.connect()

    assert created["uds"] == "/tmp/test.sock"
    assert client._http_client is created["client"]  # noqa: SLF001 - direct unit test
    assert created["client"].kwargs["base_url"] == "http://localhost"
    assert created["client"].get_calls == ["/health"]

    client.disconnect()
    assert created["client"].closed is True
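The unix-transport test above pins down what connect() must do for socket servers: build an httpx transport bound to the socket path, wrap it in a client keyed to the configured base URL, then probe /health. A minimal sketch of that wiring, assuming the private names the assertions touch (_http_client, _connected, _discover_tools):

    def connect(self) -> None:
        # Route all requests over the Unix domain socket; base_url only names the virtual host.
        transport = httpx.HTTPTransport(uds=self.config.socket_path)
        self._http_client = httpx.Client(transport=transport, base_url=self.config.url)
        self._http_client.get("/health").raise_for_status()  # fail fast if the server is down
        self._connected = True
        self._discover_tools()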
def test_connect_sse_and_list_tools(monkeypatch):
    pytest.importorskip("mcp")
    sse_module = pytest.importorskip("mcp.client.sse")
    import mcp

    contexts = []

    class FakeSSEContext:
        def __init__(self, url: str, headers: dict[str, str] | None, timeout: float):
            self.url = url
            self.headers = headers
            self.timeout = timeout
            self.exited = False

        async def __aenter__(self):
            return "read-stream", "write-stream"

        async def __aexit__(self, exc_type, exc, tb):
            self.exited = True

    class FakeSession:
        def __init__(self, read_stream, write_stream):
            self.read_stream = read_stream
            self.write_stream = write_stream
            self.closed = False

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            self.closed = True

        async def initialize(self):
            """Pretend session initialization succeeded."""

        async def list_tools(self):
            return SimpleNamespace(
                tools=[
                    SimpleNamespace(
                        name="search",
                        description="Search docs",
                        inputSchema={"type": "object"},
                    )
                ]
            )

    def fake_sse_client(url: str, headers=None, timeout=5, **_kwargs):
        context = FakeSSEContext(url=url, headers=headers, timeout=timeout)
        contexts.append(context)
        return context

    monkeypatch.setattr(sse_module, "sse_client", fake_sse_client)
    monkeypatch.setattr(mcp, "ClientSession", FakeSession)

    client = MCPClient(
        MCPServerConfig(
            name="sse-server",
            transport="sse",
            url="http://localhost/sse",
            headers={"Authorization": "Bearer token"},
        )
    )

    client.connect()
    tools = client.list_tools()

    assert [tool.name for tool in tools] == ["search"]
    assert tools[0].description == "Search docs"
    assert contexts[0].url == "http://localhost/sse"
    assert contexts[0].headers == {"Authorization": "Bearer token"}
    assert contexts[0].timeout == 30.0

    client.disconnect()
    assert contexts[0].exited is True


def test_call_tool_retries_once_on_connect_error_for_unix(monkeypatch):
    client = MCPClient(MCPServerConfig(name="unix-server", transport="unix"))
    client._connected = True  # noqa: SLF001 - direct unit test
    client._tools = {  # noqa: SLF001 - direct unit test
        "ping": MCPTool("ping", "Ping tool", {}, "unix-server")
    }

    first_error = httpx.ConnectError("first failure")
    calls = {"count": 0}
    reconnects = []

    def fake_call_tool_http(tool_name, arguments):
        calls["count"] += 1
        if calls["count"] == 1:
            raise first_error
        return [{"type": "text", "text": f"{tool_name}:{arguments['value']}"}]

    monkeypatch.setattr(client, "_call_tool_http", fake_call_tool_http)
    monkeypatch.setattr(client, "_reconnect", lambda: reconnects.append("reconnected"))

    result = client.call_tool("ping", {"value": "ok"})

    assert result == [{"type": "text", "text": "ping:ok"}]
    assert calls["count"] == 2
    assert reconnects == ["reconnected"]


def test_call_tool_retry_exhausted_raises_original_error_for_unix(monkeypatch):
    client = MCPClient(MCPServerConfig(name="unix-server", transport="unix"))
    client._connected = True  # noqa: SLF001 - direct unit test
    client._tools = {  # noqa: SLF001 - direct unit test
        "ping": MCPTool("ping", "Ping tool", {}, "unix-server")
    }

    first_error = httpx.ConnectError("first failure")
    second_error = httpx.ConnectError("second failure")
    calls = {"count": 0}
    reconnects = []

    def fake_call_tool_http(_tool_name, _arguments):
        calls["count"] += 1
        if calls["count"] == 1:
            raise first_error
        raise second_error

    monkeypatch.setattr(client, "_call_tool_http", fake_call_tool_http)
    monkeypatch.setattr(client, "_reconnect", lambda: reconnects.append("reconnected"))

    with pytest.raises(httpx.ConnectError) as exc_info:
        client.call_tool("ping", {"value": "ok"})

    assert exc_info.value is first_error
    assert calls["count"] == 2
    assert reconnects == ["reconnected"]
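Together, these two tests fix the retry contract for socket transports: exactly one reconnect-and-retry on httpx.ConnectError, and if the retry also fails, the original error (not the second one) is what propagates. A sketch of call_tool satisfying both, built on the private helpers the tests monkeypatch:

    def call_tool(self, tool_name: str, arguments: dict):
        try:
            return self._call_tool_http(tool_name, arguments)
        except httpx.ConnectError as first_error:
            self._reconnect()  # one reconnect, then one retry
            try:
                return self._call_tool_http(tool_name, arguments)
            except httpx.ConnectError:
                raise first_error  # surface the original failure, not the retry's

Note this is consistent with the http-transport test that follows: there, _call_tool_http wraps the failure in a RuntimeError, which this except clause does not catch, so no reconnect fires.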
def test_call_tool_http_preserves_runtime_error_wrapping(monkeypatch):
    client = MCPClient(MCPServerConfig(name="http-server", transport="http"))
    client._connected = True  # noqa: SLF001 - direct unit test
    client._tools = {  # noqa: SLF001 - direct unit test
        "ping": MCPTool("ping", "Ping tool", {}, "http-server")
    }

    connect_error = httpx.ConnectError("first failure")

    class FailingHttpClient:
        def post(self, _path, json):
            raise connect_error

    client._http_client = FailingHttpClient()  # noqa: SLF001 - direct unit test
    reconnects = []
    monkeypatch.setattr(client, "_reconnect", lambda: reconnects.append("reconnected"))

    with pytest.raises(RuntimeError) as exc_info:
        client.call_tool("ping", {"value": "ok"})

    assert "Failed to call tool via HTTP" in str(exc_info.value)
    assert exc_info.value.__cause__ is connect_error
    assert reconnects == []
@@ -0,0 +1,172 @@
"""Tests for the shared MCP connection manager."""

import threading

import httpx
import pytest

from framework.runner.mcp_client import MCPServerConfig, MCPTool
from framework.runner.mcp_connection_manager import MCPConnectionManager


class FakeMCPClient:
    """Minimal fake MCP client for connection manager tests."""

    instances: list["FakeMCPClient"] = []

    def __init__(self, config: MCPServerConfig):
        self.config = config
        self._connected = False
        self.connect_calls = 0
        self.disconnect_calls = 0
        self.list_tools_calls = 0
        self.list_tools_error: Exception | None = None
        FakeMCPClient.instances.append(self)

    def connect(self) -> None:
        self.connect_calls += 1
        self._connected = True

    def disconnect(self) -> None:
        self.disconnect_calls += 1
        self._connected = False

    def list_tools(self) -> list[MCPTool]:
        self.list_tools_calls += 1
        if self.list_tools_error is not None:
            raise self.list_tools_error
        return [MCPTool("ping", "Ping", {"type": "object"}, self.config.name)]


@pytest.fixture
def manager(monkeypatch):
    monkeypatch.setattr("framework.runner.mcp_connection_manager.MCPClient", FakeMCPClient)
    monkeypatch.setattr(MCPConnectionManager, "_instance", None)
    FakeMCPClient.instances.clear()
    manager = MCPConnectionManager.get_instance()
    yield manager
    manager.cleanup_all()
    monkeypatch.setattr(MCPConnectionManager, "_instance", None)
    FakeMCPClient.instances.clear()


def test_acquire_returns_same_client_for_same_server_name(manager):
    config = MCPServerConfig(name="shared", transport="stdio", command="echo")

    client_one = manager.acquire(config)
    client_two = manager.acquire(config)

    assert client_one is client_two
    assert manager._refcounts["shared"] == 2  # noqa: SLF001 - state assertion for unit test
    assert len(FakeMCPClient.instances) == 1


def test_release_with_refcount_above_one_keeps_connection_open(manager):
    config = MCPServerConfig(name="shared", transport="stdio", command="echo")
    client = manager.acquire(config)
    manager.acquire(config)

    manager.release("shared")

    assert client.disconnect_calls == 0
    assert manager._pool["shared"] is client  # noqa: SLF001 - state assertion for unit test
    assert manager._refcounts["shared"] == 1  # noqa: SLF001 - state assertion for unit test


def test_release_last_reference_disconnects_and_removes_from_pool(manager):
    config = MCPServerConfig(name="shared", transport="stdio", command="echo")
    client = manager.acquire(config)

    manager.release("shared")

    assert client.disconnect_calls == 1
    assert "shared" not in manager._pool  # noqa: SLF001 - state assertion for unit test
    assert "shared" not in manager._refcounts  # noqa: SLF001 - state assertion for unit test
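The three tests above specify a plain reference-counted pool: the first acquire connects and caches, later acquires only bump the count, and release disconnects solely when the count hits zero. A sketch of that core, assuming the _pool/_refcounts/_configs/_lock attributes the assertions inspect:

    def acquire(self, config: MCPServerConfig):
        with self._lock:
            if config.name not in self._pool:
                client = MCPClient(config)  # patched to FakeMCPClient in these tests
                client.connect()
                self._pool[config.name] = client
                self._refcounts[config.name] = 0
                self._configs[config.name] = config
            self._refcounts[config.name] += 1
            return self._pool[config.name]

    def release(self, server_name: str) -> None:
        with self._lock:
            self._refcounts[server_name] -= 1
            if self._refcounts[server_name] == 0:
                self._pool.pop(server_name).disconnect()
                del self._refcounts[server_name]
                del self._configs[server_name]

The single lock around both paths is what the eight-thread barrier test below is exercising: without it, two threads could both see an empty pool and create two clients.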
def test_concurrent_acquire_and_release_keeps_state_consistent(manager):
    config = MCPServerConfig(name="shared", transport="stdio", command="echo")
    worker_count = 8
    acquire_barrier = threading.Barrier(worker_count + 1)
    release_barrier = threading.Barrier(worker_count)
    acquired_clients: list[FakeMCPClient] = []
    acquired_lock = threading.Lock()

    def worker() -> None:
        acquire_barrier.wait()
        client = manager.acquire(config)
        with acquired_lock:
            acquired_clients.append(client)
        release_barrier.wait()
        manager.release("shared")

    threads = [threading.Thread(target=worker) for _ in range(worker_count)]
    for thread in threads:
        thread.start()

    acquire_barrier.wait()

    for thread in threads:
        thread.join()

    assert len({id(client) for client in acquired_clients}) == 1
    assert len(FakeMCPClient.instances) == 1
    assert FakeMCPClient.instances[0].disconnect_calls == 1
    assert manager._pool == {}  # noqa: SLF001 - state assertion for unit test
    assert manager._refcounts == {}  # noqa: SLF001 - state assertion for unit test


def test_cleanup_all_disconnects_every_pooled_client(manager):
    manager.acquire(MCPServerConfig(name="one", transport="stdio", command="echo"))
    manager.acquire(MCPServerConfig(name="two", transport="stdio", command="echo"))

    manager.cleanup_all()

    assert len(FakeMCPClient.instances) == 2
    assert all(client.disconnect_calls == 1 for client in FakeMCPClient.instances)
    assert manager._pool == {}  # noqa: SLF001 - state assertion for unit test
    assert manager._refcounts == {}  # noqa: SLF001 - state assertion for unit test
    assert manager._configs == {}  # noqa: SLF001 - state assertion for unit test


def test_reconnect_replaces_client_even_with_existing_refcount(manager):
    config = MCPServerConfig(name="shared", transport="stdio", command="echo")
    original_client = manager.acquire(config)
    manager.acquire(config)

    replacement = manager.reconnect("shared")

    assert replacement is not original_client
    assert original_client.disconnect_calls == 1
    assert manager._pool["shared"] is replacement  # noqa: SLF001 - state assertion for unit test
    assert manager._refcounts["shared"] == 2  # noqa: SLF001 - state assertion for unit test


def test_health_check_returns_false_when_server_is_unreachable(manager, monkeypatch):
    config = MCPServerConfig(name="shared", transport="http", url="http://localhost:9")
    manager.acquire(config)

    class FailingHttpClient:
        def __init__(self, **_kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def get(self, _path: str):
            raise httpx.ConnectError("unreachable")

    monkeypatch.setattr("framework.runner.mcp_connection_manager.httpx.Client", FailingHttpClient)

    assert manager.health_check("shared") is False


def test_health_check_for_stdio_returns_false_on_tools_list_error(manager):
    config = MCPServerConfig(name="shared", transport="stdio", command="echo")
    client = manager.acquire(config)
    client.list_tools_error = RuntimeError("broken")

    assert manager.health_check("shared") is False
@@ -0,0 +1,142 @@
"""Tests for AS-9: Skill directory allowlisting in file-read tool interception."""

from unittest.mock import MagicMock

import pytest

from framework.llm.provider import ToolResult


def _make_tool_call_event(tool_name: str, path: str):
    """Build a minimal ToolCallEvent-like object."""
    tc = MagicMock()
    tc.tool_use_id = "tc-1"
    tc.tool_name = tool_name
    tc.tool_input = {"path": path}
    return tc


def _make_node(skill_dirs: list[str]):
    """Build a minimal EventLoopNode with skill_dirs set."""
    from framework.graph.event_loop_node import EventLoopNode

    mock_result = ToolResult(tool_use_id="tc-1", content="from-executor")
    node = EventLoopNode(tool_executor=MagicMock(return_value=mock_result))
    node._skill_dirs = skill_dirs
    return node


class TestSkillFileReadInterception:
    @pytest.mark.asyncio
    async def test_reads_file_in_skill_dir(self, tmp_path):
        """File under a skill dir is read directly, bypassing the executor."""
        skill_dir = tmp_path / "my-skill"
        skill_dir.mkdir()
        script = skill_dir / "scripts" / "run.py"
        script.parent.mkdir()
        script.write_text("print('hello')")

        node = _make_node([str(skill_dir)])
        tc = _make_tool_call_event("view_file", str(script))

        result = await node._execute_tool(tc)

        assert result.content == "print('hello')"
        assert not result.is_error
        node._tool_executor.assert_not_called()

    @pytest.mark.asyncio
    async def test_skill_md_read_marked_as_skill_content(self, tmp_path):
        """Reading SKILL.md sets is_skill_content=True for AS-10 protection."""
        skill_dir = tmp_path / "my-skill"
        skill_dir.mkdir()
        skill_md = skill_dir / "SKILL.md"
        skill_md.write_text("---\nname: my-skill\ndescription: Test\n---\nInstructions.")

        node = _make_node([str(skill_dir)])
        tc = _make_tool_call_event("view_file", str(skill_md))

        result = await node._execute_tool(tc)

        assert result.is_skill_content is True
        assert not result.is_error

    @pytest.mark.asyncio
    async def test_non_skill_md_resource_not_marked(self, tmp_path):
        """Bundled resource (not SKILL.md) is NOT marked as skill_content."""
        skill_dir = tmp_path / "my-skill"
        skill_dir.mkdir()
        ref = skill_dir / "references" / "api.md"
        ref.parent.mkdir()
        ref.write_text("# API Reference")

        node = _make_node([str(skill_dir)])
        tc = _make_tool_call_event("load_data", str(ref))

        result = await node._execute_tool(tc)

        assert result.is_skill_content is False
        assert not result.is_error

    @pytest.mark.asyncio
    async def test_path_outside_skill_dir_goes_to_executor(self, tmp_path):
        """Path outside skill dirs is passed through to the executor unchanged."""
        skill_dir = tmp_path / "my-skill"
        skill_dir.mkdir()
        other_file = tmp_path / "other" / "file.txt"
        other_file.parent.mkdir()
        other_file.write_text("other content")

        node = _make_node([str(skill_dir)])
        tc = _make_tool_call_event("view_file", str(other_file))

        result = await node._execute_tool(tc)

        assert result.content == "from-executor"
        node._tool_executor.assert_called_once()

    @pytest.mark.asyncio
    async def test_no_skill_dirs_goes_to_executor(self, tmp_path):
        """When skill_dirs is empty, all tool calls go to executor."""
        skill_dir = tmp_path / "my-skill"
        skill_dir.mkdir()
        script = skill_dir / "scripts" / "run.py"
        script.parent.mkdir()
        script.write_text("print('hello')")

        node = _make_node([])
        tc = _make_tool_call_event("view_file", str(script))

        result = await node._execute_tool(tc)

        assert result.content == "from-executor"
        node._tool_executor.assert_called_once()

    @pytest.mark.asyncio
    async def test_missing_file_returns_error(self, tmp_path):
        """Non-existent file under skill dir returns is_error=True."""
        skill_dir = tmp_path / "my-skill"
        skill_dir.mkdir()
        missing = skill_dir / "scripts" / "missing.py"

        node = _make_node([str(skill_dir)])
        tc = _make_tool_call_event("view_file", str(missing))

        result = await node._execute_tool(tc)

        assert result.is_error is True
        assert "Could not read skill resource" in result.content

    @pytest.mark.asyncio
    async def test_non_file_read_tool_goes_to_executor(self, tmp_path):
        """Non file-read tools (e.g. web_search) bypass the interceptor."""
        skill_dir = tmp_path / "my-skill"
        skill_dir.mkdir()

        node = _make_node([str(skill_dir)])
        tc = _make_tool_call_event("web_search", str(skill_dir / "SKILL.md"))

        result = await node._execute_tool(tc)

        assert result.content == "from-executor"
        node._tool_executor.assert_called_once()
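Taken together, this class pins the interception rule: only file-read tools are intercepted, only for paths inside an allowlisted skill directory, with SKILL.md reads flagged for AS-10 protection and everything else falling through to the executor. A sketch of that predicate; the tool-name set and the helper name are assumptions, the error text and flags come from the assertions:

    from pathlib import Path

    _FILE_READ_TOOLS = {"view_file", "load_data"}  # assumed set; web_search etc. pass through

    def _try_skill_read(self, tc) -> ToolResult | None:
        path_str = tc.tool_input.get("path", "")
        if tc.tool_name not in _FILE_READ_TOOLS or not self._skill_dirs or not path_str:
            return None  # hand off to the normal tool executor
        path = Path(path_str).resolve()
        for base in self._skill_dirs:
            if path.is_relative_to(Path(base).resolve()):
                try:
                    content = path.read_text()
                except OSError as e:
                    return ToolResult(
                        tool_use_id=tc.tool_use_id,
                        content=f"Could not read skill resource: {e}",
                        is_error=True,
                    )
                # SKILL.md bodies get AS-10 pruning protection
                return ToolResult(
                    tool_use_id=tc.tool_use_id,
                    content=content,
                    is_skill_content=(path.name == "SKILL.md"),
                )
        return None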
@@ -69,7 +69,13 @@ class TestSkillCatalog:

    def test_to_prompt_xml_generation(self):
        skills = [
            _make_skill("alpha", "Alpha skill", "project", location="/p/alpha/SKILL.md"),
            _make_skill(
                "alpha",
                "Alpha skill",
                "project",
                location="/p/alpha/SKILL.md",
                base_dir="/p/alpha",
            ),
            _make_skill("beta", "Beta skill", "user", location="/u/beta/SKILL.md"),
        ]
        catalog = SkillCatalog(skills)
@@ -81,6 +87,7 @@ class TestSkillCatalog:
        assert "<name>beta</name>" in prompt
        assert "<description>Alpha skill</description>" in prompt
        assert "<location>/p/alpha/SKILL.md</location>" in prompt
        assert "<base_dir>/p/alpha</base_dir>" in prompt

    def test_to_prompt_sorted_by_name(self):
        skills = [
@@ -0,0 +1,90 @@
"""Tests for AS-10: Activated skill content protected from context pruning."""

import pytest

from framework.graph.conversation import Message, NodeConversation


def _make_conversation() -> NodeConversation:
    conv = NodeConversation.__new__(NodeConversation)
    conv._messages = []
    conv._next_seq = 0
    conv._current_phase = None
    conv._store = None
    return conv


async def _add_tool_msg(conv: NodeConversation, content: str, **kwargs) -> Message:
    return await conv.add_tool_result(
        tool_use_id=f"tc-{conv._next_seq}",
        content=content,
        **kwargs,
    )


class TestSkillContentProtection:
    @pytest.mark.asyncio
    async def test_is_skill_content_flag_persists(self):
        """Message created with is_skill_content=True retains the flag."""
        conv = _make_conversation()
        msg = await _add_tool_msg(conv, "skill instructions", is_skill_content=True)
        assert msg.is_skill_content is True

    @pytest.mark.asyncio
    async def test_regular_message_not_marked(self):
        """Normal tool result messages are not marked as skill content."""
        conv = _make_conversation()
        msg = await _add_tool_msg(conv, "some tool output")
        assert msg.is_skill_content is False

    @pytest.mark.asyncio
    async def test_skill_content_survives_prune(self):
        """Skill content messages are skipped by prune_old_tool_results."""
        conv = _make_conversation()

        # Add many regular tool results to push over prune threshold
        for _ in range(30):
            await _add_tool_msg(conv, "x" * 500)  # ~125 tokens each

        # Add a skill content message
        skill_msg = await _add_tool_msg(
            conv,
            "## Deep Research\n" + "instructions " * 200,
            is_skill_content=True,
        )

        pruned = await conv.prune_old_tool_results(protect_tokens=500, min_prune_tokens=100)

        assert pruned > 0, "Expected some messages to be pruned"
        # Find the skill message — it must not be pruned
        matching = [m for m in conv._messages if m.seq == skill_msg.seq]
        assert matching, "Skill content message was removed"
        assert not matching[0].content.startswith("[Pruned tool result")

    @pytest.mark.asyncio
    async def test_regular_content_can_be_pruned(self):
        """Regular tool results are still pruned when over threshold."""
        conv = _make_conversation()

        for _ in range(20):
            await _add_tool_msg(conv, "regular tool output " * 50)

        pruned = await conv.prune_old_tool_results(protect_tokens=500, min_prune_tokens=100)

        assert pruned > 0, "Expected regular messages to be pruned"

    @pytest.mark.asyncio
    async def test_error_messages_also_protected(self):
        """Existing is_error protection still works alongside is_skill_content."""
        conv = _make_conversation()

        for _ in range(20):
            await _add_tool_msg(conv, "output " * 100)

        err_msg = await _add_tool_msg(conv, "tool failed", is_error=True)

        await conv.prune_old_tool_results(protect_tokens=200, min_prune_tokens=50)

        matching = [m for m in conv._messages if m.seq == err_msg.seq]
        assert matching
        assert not matching[0].content.startswith("[Pruned tool result")
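What these tests demand of prune_old_tool_results is a sweep over old tool results that stubs out large bodies while skipping anything flagged is_error or is_skill_content. A rough sketch of that loop, assuming a rough_token_count helper (roughly chars/4) and the message fields asserted above:

    async def prune_old_tool_results(self, *, protect_tokens: int, min_prune_tokens: int) -> int:
        pruned = 0
        budget = protect_tokens  # the most recent messages stay untouched
        for msg in reversed(self._messages):
            if msg.role != "tool":
                continue
            tokens = rough_token_count(msg.content)  # assumed helper
            if budget > 0:
                budget -= tokens
                continue
            if msg.is_error or msg.is_skill_content:
                continue  # AS-10: never prune skill bodies; errors stay visible
            if tokens >= min_prune_tokens:
                msg.content = f"[Pruned tool result ({tokens} tokens)]"
                pruned += 1
        return pruned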
@@ -0,0 +1,151 @@
"""Tests for skill system structured error codes and diagnostics."""

from __future__ import annotations

import logging

from framework.skills.skill_errors import (
    SkillError,
    SkillErrorCode,
    log_skill_error,
)


class TestSkillErrorCode:
    def test_all_codes_defined(self):
        codes = {e.value for e in SkillErrorCode}
        assert "SKILL_NOT_FOUND" in codes
        assert "SKILL_PARSE_ERROR" in codes
        assert "SKILL_ACTIVATION_FAILED" in codes
        assert "SKILL_MISSING_DESCRIPTION" in codes
        assert "SKILL_YAML_FIXUP" in codes
        assert "SKILL_NAME_MISMATCH" in codes
        assert "SKILL_COLLISION" in codes


class TestSkillError:
    def test_code_stored(self):
        err = SkillError(
            code=SkillErrorCode.SKILL_NOT_FOUND,
            what="Skill 'my-skill' not found",
            why="Not in catalog",
            fix="Check discovery paths",
        )
        assert err.code == SkillErrorCode.SKILL_NOT_FOUND

    def test_message_format(self):
        err = SkillError(
            code=SkillErrorCode.SKILL_MISSING_DESCRIPTION,
            what="Missing description in '/path/SKILL.md'",
            why="The description field is absent",
            fix="Add a description field to the frontmatter",
        )
        expected = (
            "[SKILL_MISSING_DESCRIPTION]\n"
            "What failed: Missing description in '/path/SKILL.md'\n"
            "Why: The description field is absent\n"
            "Fix: Add a description field to the frontmatter"
        )
        assert str(err) == expected

    def test_is_exception(self):
        err = SkillError(
            code=SkillErrorCode.SKILL_PARSE_ERROR,
            what="Parse failed",
            why="Invalid YAML",
            fix="Fix the YAML",
        )
        assert isinstance(err, Exception)

    def test_what_why_fix_attributes(self):
        err = SkillError(
            code=SkillErrorCode.SKILL_COLLISION,
            what="Name collision",
            why="Two skills share the same name",
            fix="Rename one skill directory",
        )
        assert err.what == "Name collision"
        assert err.why == "Two skills share the same name"
        assert err.fix == "Rename one skill directory"
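test_message_format fixes the exact rendering, so the exception itself can stay tiny. A sketch consistent with all four tests above; the enum values and field names come straight from the assertions:

    import enum

    class SkillErrorCode(str, enum.Enum):
        SKILL_NOT_FOUND = "SKILL_NOT_FOUND"
        SKILL_PARSE_ERROR = "SKILL_PARSE_ERROR"
        # ...remaining members as listed in test_all_codes_defined

    class SkillError(Exception):
        def __init__(self, *, code: SkillErrorCode, what: str, why: str, fix: str):
            self.code, self.what, self.why, self.fix = code, what, why, fix
            super().__init__(what)

        def __str__(self) -> str:
            return (
                f"[{self.code.value}]\n"
                f"What failed: {self.what}\n"
                f"Why: {self.why}\n"
                f"Fix: {self.fix}"
            )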
class TestLogSkillError:
    def test_emits_log(self, caplog):
        test_logger = logging.getLogger("test_skill")
        with caplog.at_level(logging.ERROR, logger="test_skill"):
            log_skill_error(
                test_logger,
                "error",
                SkillErrorCode.SKILL_PARSE_ERROR,
                what="Invalid SKILL.md at '/path'",
                why="Empty file",
                fix="Add content",
            )
        assert "SKILL_PARSE_ERROR" in caplog.text

    def test_warning_level(self, caplog):
        test_logger = logging.getLogger("test_skill_warn")
        with caplog.at_level(logging.WARNING, logger="test_skill_warn"):
            log_skill_error(
                test_logger,
                "warning",
                SkillErrorCode.SKILL_YAML_FIXUP,
                what="Auto-fixed YAML",
                why="Unquoted colons",
                fix="Quote values",
            )
        assert "SKILL_YAML_FIXUP" in caplog.text

    def test_message_contains_all_parts(self, caplog):
        test_logger = logging.getLogger("test_skill_parts")
        with caplog.at_level(logging.ERROR, logger="test_skill_parts"):
            log_skill_error(
                test_logger,
                "error",
                SkillErrorCode.SKILL_NOT_FOUND,
                what="Skill not found",
                why="Not discovered",
                fix="Check paths",
            )
        assert "Skill not found" in caplog.text
        assert "Not discovered" in caplog.text
        assert "Check paths" in caplog.text


class TestSkillErrorInParser:
    def test_missing_description_returns_none(self, tmp_path):
        from framework.skills.parser import parse_skill_md

        skill_dir = tmp_path / "no-desc"
        skill_dir.mkdir()
        (skill_dir / "SKILL.md").write_text("---\nname: no-desc\n---\nBody.\n", encoding="utf-8")
        result = parse_skill_md(skill_dir / "SKILL.md")
        assert result is None

    def test_empty_file_returns_none(self, tmp_path):
        from framework.skills.parser import parse_skill_md

        skill_dir = tmp_path / "empty"
        skill_dir.mkdir()
        (skill_dir / "SKILL.md").write_text("", encoding="utf-8")
        result = parse_skill_md(skill_dir / "SKILL.md")
        assert result is None

    def test_nonexistent_returns_none(self, tmp_path):
        from framework.skills.parser import parse_skill_md

        result = parse_skill_md(tmp_path / "ghost" / "SKILL.md")
        assert result is None

    def test_yaml_fixup_still_parses(self, tmp_path):
        from framework.skills.parser import parse_skill_md

        skill_dir = tmp_path / "colon-test"
        skill_dir.mkdir()
        (skill_dir / "SKILL.md").write_text(
            "---\nname: colon-test\ndescription: Use for: research\n---\nBody.\n",
            encoding="utf-8",
        )
        result = parse_skill_md(skill_dir / "SKILL.md")
        assert result is not None
        assert "research" in result.description
@@ -0,0 +1,92 @@
"""Tests for AS-6 skill resource loading support.

Covers:
- <base_dir> element in catalog XML
- allowlisted_dirs property reflects trusted skill base directories
- skill_dirs propagation to NodeContext
"""

from framework.skills.catalog import SkillCatalog
from framework.skills.parser import ParsedSkill


def _make_skill(
    name: str,
    base_dir: str,
    source_scope: str = "project",
) -> ParsedSkill:
    return ParsedSkill(
        name=name,
        description=f"Skill {name}",
        location=f"{base_dir}/SKILL.md",
        base_dir=base_dir,
        source_scope=source_scope,
        body="Instructions.",
    )


class TestSkillResourceBaseDir:
    def test_base_dir_in_xml(self):
        """Each community skill entry should expose its base_dir in the catalog XML."""
        skill = _make_skill("deploy", "/project/.hive/skills/deploy")
        catalog = SkillCatalog([skill])
        prompt = catalog.to_prompt()

        assert "<base_dir>/project/.hive/skills/deploy</base_dir>" in prompt

    def test_base_dir_xml_escaped(self):
        """base_dir with XML-special chars should be escaped."""
        skill = _make_skill("s", "/path/with <&> chars")
        catalog = SkillCatalog([skill])
        prompt = catalog.to_prompt()

        assert "<base_dir>/path/with &lt;&amp;&gt; chars</base_dir>" in prompt

    def test_base_dir_absent_for_framework_skills(self):
        """Framework-scope skills are filtered from the catalog, so no base_dir either."""
        skill = _make_skill("fw", "/hive/_default_skills/fw", source_scope="framework")
        catalog = SkillCatalog([skill])
        assert catalog.to_prompt() == ""

    def test_allowlisted_dirs_matches_skills(self):
        """allowlisted_dirs returns all skill base_dirs including framework ones."""
        skills = [
            _make_skill("a", "/skills/a", "project"),
            _make_skill("b", "/skills/b", "user"),
            _make_skill("c", "/skills/c", "framework"),
        ]
        catalog = SkillCatalog(skills)
        dirs = catalog.allowlisted_dirs

        assert "/skills/a" in dirs
        assert "/skills/b" in dirs
        assert "/skills/c" in dirs

    def test_allowlisted_dirs_empty_catalog(self):
        assert SkillCatalog().allowlisted_dirs == []


class TestSkillDirsPropagation:
    def _make_ctx(self, **kwargs):
        from unittest.mock import MagicMock

        from framework.graph.node import NodeContext

        return NodeContext(
            runtime=MagicMock(),
            node_id="n",
            node_spec=MagicMock(),
            memory={},
            **kwargs,
        )

    def test_node_context_skill_dirs_default(self):
        """NodeContext.skill_dirs defaults to empty list."""
        ctx = self._make_ctx()
        assert ctx.skill_dirs == []

    def test_node_context_skill_dirs_set(self):
        """NodeContext.skill_dirs can be populated."""
        dirs = ["/skills/a", "/skills/b"]
        ctx = self._make_ctx(skill_dirs=dirs)
        assert ctx.skill_dirs == dirs
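The escaping test above implies to_prompt passes every field through an XML escaper before wrapping it in elements. A minimal sketch of one catalog entry, assuming xml.sax.saxutils.escape (which covers &, <, >) and the element names these tests assert:

    from xml.sax.saxutils import escape

    def _skill_entry_xml(skill) -> str:
        # framework-scope skills are filtered out before entries are rendered
        return (
            "<skill>"
            f"<name>{escape(skill.name)}</name>"
            f"<description>{escape(skill.description)}</description>"
            f"<location>{escape(skill.location)}</location>"
            f"<base_dir>{escape(skill.base_dir)}</base_dir>"
            "</skill>"
        )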
@@ -0,0 +1,471 @@
"""Tests for skill trust gating (AS-13)."""

from __future__ import annotations

import json
from unittest.mock import MagicMock, patch

from framework.skills.parser import ParsedSkill
from framework.skills.trust import (
    ProjectTrustClassification,
    ProjectTrustDetector,
    TrustedRepoStore,
    TrustGate,
    _is_localhost_remote,
    _normalize_remote_url,
)

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def make_skill(name: str = "test-skill", scope: str = "project") -> ParsedSkill:
    return ParsedSkill(
        name=name,
        description="Test skill",
        location=f"/fake/{name}/SKILL.md",
        base_dir=f"/fake/{name}",
        source_scope=scope,
        body="Test skill instructions.",
    )


# ---------------------------------------------------------------------------
# _normalize_remote_url
# ---------------------------------------------------------------------------


class TestNormalizeRemoteUrl:
    def test_ssh_scp_format(self):
        assert _normalize_remote_url("git@github.com:org/repo.git") == "github.com/org/repo"

    def test_https_format(self):
        assert _normalize_remote_url("https://github.com/org/repo.git") == "github.com/org/repo"

    def test_https_no_dot_git(self):
        assert _normalize_remote_url("https://github.com/org/repo") == "github.com/org/repo"

    def test_ssh_url_format(self):
        assert _normalize_remote_url("ssh://git@github.com/org/repo.git") == "github.com/org/repo"

    def test_lowercased(self):
        assert _normalize_remote_url("git@GitHub.COM:Org/Repo.git") == "github.com/org/repo"

    def test_trailing_slash_stripped(self):
        assert _normalize_remote_url("https://github.com/org/repo/") == "github.com/org/repo"

    def test_gitlab(self):
        assert _normalize_remote_url("git@gitlab.com:team/project.git") == "gitlab.com/team/project"
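The seven cases above fully specify the normalizer: drop the scheme and any user@ prefix, turn the scp-style colon into a slash, lowercase, and strip a trailing slash and .git suffix. One sketch that satisfies all of them (edge cases beyond these tests, such as ports, are not handled here):

    def _normalize_remote_url(url: str) -> str:
        url = url.strip().lower()
        if "@" in url and "://" not in url:
            # scp-style: git@host:org/repo.git -> host/org/repo.git
            url = url.split("@", 1)[1].replace(":", "/", 1)
        else:
            url = url.split("://", 1)[-1]  # drop http(s):// or ssh://
            host, sep, rest = url.partition("/")
            if "@" in host:
                url = host.split("@", 1)[1] + sep + rest  # drop user@ in ssh:// URLs
        url = url.rstrip("/")
        return url[:-4] if url.endswith(".git") else url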
# ---------------------------------------------------------------------------
# _is_localhost_remote
# ---------------------------------------------------------------------------


class TestIsLocalhostRemote:
    def test_localhost_https(self):
        assert _is_localhost_remote("http://localhost/org/repo")

    def test_127_0_0_1(self):
        assert _is_localhost_remote("https://127.0.0.1/repo")

    def test_github_not_local(self):
        assert not _is_localhost_remote("https://github.com/org/repo")

    def test_scp_localhost(self):
        assert _is_localhost_remote("git@localhost:org/repo")


# ---------------------------------------------------------------------------
# TrustedRepoStore
# ---------------------------------------------------------------------------


class TestTrustedRepoStore:
    def test_empty_store_is_not_trusted(self, tmp_path):
        store = TrustedRepoStore(tmp_path / "trusted.json")
        assert not store.is_trusted("github.com/org/repo")

    def test_trust_and_lookup(self, tmp_path):
        store = TrustedRepoStore(tmp_path / "trusted.json")
        store.trust("github.com/org/repo", project_path="/some/path")
        assert store.is_trusted("github.com/org/repo")

    def test_revoke(self, tmp_path):
        store = TrustedRepoStore(tmp_path / "trusted.json")
        store.trust("github.com/org/repo")
        assert store.revoke("github.com/org/repo")
        assert not store.is_trusted("github.com/org/repo")

    def test_revoke_nonexistent_returns_false(self, tmp_path):
        store = TrustedRepoStore(tmp_path / "trusted.json")
        assert not store.revoke("github.com/nobody/nowhere")

    def test_persists_across_instances(self, tmp_path):
        path = tmp_path / "trusted.json"
        store1 = TrustedRepoStore(path)
        store1.trust("github.com/org/repo")

        store2 = TrustedRepoStore(path)
        assert store2.is_trusted("github.com/org/repo")

    def test_atomic_write(self, tmp_path):
        """Save must not leave a .tmp file behind."""
        path = tmp_path / "trusted.json"
        store = TrustedRepoStore(path)
        store.trust("github.com/org/repo")
        assert not (tmp_path / "trusted.tmp").exists()
        assert path.exists()

    def test_corrupted_json_recovers_gracefully(self, tmp_path):
        path = tmp_path / "trusted.json"
        path.write_text("{not valid json{{", encoding="utf-8")
        store = TrustedRepoStore(path)
        assert not store.is_trusted("github.com/any/repo")  # no crash

    def test_json_schema(self, tmp_path):
        path = tmp_path / "trusted.json"
        store = TrustedRepoStore(path)
        store.trust("github.com/org/repo", project_path="/work/repo")
        data = json.loads(path.read_text())
        assert data["version"] == 1
        assert data["entries"][0]["repo_key"] == "github.com/org/repo"
        assert "added_at" in data["entries"][0]

    def test_list_entries(self, tmp_path):
        store = TrustedRepoStore(tmp_path / "t.json")
        store.trust("github.com/a/b")
        store.trust("github.com/c/d")
        entries = store.list_entries()
        assert len(entries) == 2
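test_json_schema and test_atomic_write together dictate the on-disk format and the write discipline: a versioned JSON document, written through a temp file plus an atomic rename. A sketch of the save path under those constraints; field names come from the schema test, while the temp-file naming and timestamp format are assumptions:

    import json
    import os
    import time

    def trust(self, repo_key: str, project_path: str | None = None) -> None:
        self._entries.append(
            {"repo_key": repo_key, "project_path": project_path, "added_at": time.time()}
        )
        self._save()

    def _save(self) -> None:
        payload = {"version": 1, "entries": self._entries}
        tmp = self._path.with_name(self._path.name + ".tmp")
        tmp.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        os.replace(tmp, self._path)  # atomic rename: readers never see a half-written file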
# ---------------------------------------------------------------------------
# ProjectTrustDetector
# ---------------------------------------------------------------------------


class TestProjectTrustDetector:
    def test_none_project_dir_always_trusted(self, tmp_path):
        store = TrustedRepoStore(tmp_path / "t.json")
        det = ProjectTrustDetector(store)
        cls, _ = det.classify(None)
        assert cls == ProjectTrustClassification.ALWAYS_TRUSTED

    def test_nonexistent_dir_always_trusted(self, tmp_path):
        store = TrustedRepoStore(tmp_path / "t.json")
        det = ProjectTrustDetector(store)
        cls, _ = det.classify(tmp_path / "nonexistent")
        assert cls == ProjectTrustClassification.ALWAYS_TRUSTED

    def test_no_git_dir_always_trusted(self, tmp_path):
        store = TrustedRepoStore(tmp_path / "t.json")
        det = ProjectTrustDetector(store)
        cls, _ = det.classify(tmp_path)
        assert cls == ProjectTrustClassification.ALWAYS_TRUSTED

    def test_no_remote_always_trusted(self, tmp_path):
        (tmp_path / ".git").mkdir()
        store = TrustedRepoStore(tmp_path / "t.json")
        det = ProjectTrustDetector(store)
        # git command returns non-zero (no remote)
        with patch("subprocess.run") as mock_run:
            mock_run.return_value = MagicMock(returncode=1, stdout="")
            cls, _ = det.classify(tmp_path)
        assert cls == ProjectTrustClassification.ALWAYS_TRUSTED

    def test_localhost_remote_always_trusted(self, tmp_path):
        (tmp_path / ".git").mkdir()
        store = TrustedRepoStore(tmp_path / "t.json")
        det = ProjectTrustDetector(store)
        with patch("subprocess.run") as mock_run:
            mock_run.return_value = MagicMock(
                returncode=0, stdout="http://localhost/org/repo.git\n"
            )
            cls, _ = det.classify(tmp_path)
        assert cls == ProjectTrustClassification.ALWAYS_TRUSTED

    def test_trusted_by_store(self, tmp_path):
        (tmp_path / ".git").mkdir()
        store = TrustedRepoStore(tmp_path / "t.json")
        store.trust("github.com/trusted/repo")
        det = ProjectTrustDetector(store)
        with patch("subprocess.run") as mock_run:
            mock_run.return_value = MagicMock(
                returncode=0, stdout="git@github.com:trusted/repo.git\n"
            )
            cls, key = det.classify(tmp_path)
        assert cls == ProjectTrustClassification.TRUSTED_BY_USER
        assert key == "github.com/trusted/repo"

    def test_unknown_remote_untrusted(self, tmp_path):
        (tmp_path / ".git").mkdir()
        store = TrustedRepoStore(tmp_path / "t.json")
        det = ProjectTrustDetector(store)
        with patch("subprocess.run") as mock_run:
            mock_run.return_value = MagicMock(
                returncode=0, stdout="https://github.com/stranger/repo.git\n"
            )
            cls, key = det.classify(tmp_path)
        assert cls == ProjectTrustClassification.UNTRUSTED
        assert key == "github.com/stranger/repo"

    def test_own_remotes_env_var(self, tmp_path, monkeypatch):
        (tmp_path / ".git").mkdir()
        store = TrustedRepoStore(tmp_path / "t.json")
        monkeypatch.setenv("HIVE_OWN_REMOTES", "github.com/myorg/*")
        det = ProjectTrustDetector(store)
        with patch("subprocess.run") as mock_run:
            mock_run.return_value = MagicMock(
                returncode=0, stdout="git@github.com:myorg/myrepo.git\n"
            )
            cls, _ = det.classify(tmp_path)
        assert cls == ProjectTrustClassification.ALWAYS_TRUSTED

    def test_git_timeout_treated_as_trusted(self, tmp_path):
        import subprocess

        (tmp_path / ".git").mkdir()
        store = TrustedRepoStore(tmp_path / "t.json")
        det = ProjectTrustDetector(store)
        with patch("subprocess.run", side_effect=subprocess.TimeoutExpired("git", 3)):
            cls, _ = det.classify(tmp_path)
        assert cls == ProjectTrustClassification.ALWAYS_TRUSTED

    def test_git_not_found_treated_as_trusted(self, tmp_path):
        (tmp_path / ".git").mkdir()
        store = TrustedRepoStore(tmp_path / "t.json")
        det = ProjectTrustDetector(store)
        with patch("subprocess.run", side_effect=FileNotFoundError("git not found")):
            cls, _ = det.classify(tmp_path)
        assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
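Read top to bottom, these tests give classify() a clear decision ladder: anything that is not verifiably a foreign git repo defaults to ALWAYS_TRUSTED, and only a reachable, unknown remote lands in UNTRUSTED. A sketch of that ladder; the exact git invocation and the fnmatch handling of HIVE_OWN_REMOTES are assumptions, the outcomes are from the tests:

    import fnmatch
    import os
    import subprocess
    from pathlib import Path

    def classify(self, project_dir):
        if project_dir is None or not (Path(project_dir) / ".git").is_dir():
            return ProjectTrustClassification.ALWAYS_TRUSTED, None
        try:
            proc = subprocess.run(
                ["git", "-C", str(project_dir), "remote", "get-url", "origin"],  # assumed invocation
                capture_output=True, text=True, timeout=3,
            )
        except (subprocess.TimeoutExpired, FileNotFoundError):
            return ProjectTrustClassification.ALWAYS_TRUSTED, None  # cannot tell: fail open
        if proc.returncode != 0 or not proc.stdout.strip():
            return ProjectTrustClassification.ALWAYS_TRUSTED, None  # no remote configured
        remote = proc.stdout.strip()
        if _is_localhost_remote(remote):
            return ProjectTrustClassification.ALWAYS_TRUSTED, None
        key = _normalize_remote_url(remote)
        own = os.environ.get("HIVE_OWN_REMOTES", "")
        if any(fnmatch.fnmatch(key, pat) for pat in own.split(",") if pat):
            return ProjectTrustClassification.ALWAYS_TRUSTED, key
        if self._store.is_trusted(key):
            return ProjectTrustClassification.TRUSTED_BY_USER, key
        return ProjectTrustClassification.UNTRUSTED, key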
# ---------------------------------------------------------------------------
# TrustGate
# ---------------------------------------------------------------------------


class TestTrustGate:
    def test_framework_scope_always_passes(self, tmp_path):
        skill = make_skill("fw-skill", "framework")
        gate = TrustGate(store=TrustedRepoStore(tmp_path / "t.json"), interactive=False)
        result = gate.filter_and_gate([skill], project_dir=None)
        assert any(s.name == "fw-skill" for s in result)

    def test_user_scope_always_passes(self, tmp_path):
        skill = make_skill("user-skill", "user")
        gate = TrustGate(store=TrustedRepoStore(tmp_path / "t.json"), interactive=False)
        result = gate.filter_and_gate([skill], project_dir=None)
        assert any(s.name == "user-skill" for s in result)

    def test_no_project_skills_returns_early(self, tmp_path):
        """When there are no project-scope skills, trust detection is skipped."""
        fw = make_skill("fw", "framework")
        gate = TrustGate(store=TrustedRepoStore(tmp_path / "t.json"), interactive=False)
        result = gate.filter_and_gate([fw], project_dir=tmp_path)
        assert result == [fw]

    def test_trusted_project_skills_pass(self, tmp_path):
        """Project skills from a trusted repo pass through."""
        (tmp_path / ".git").mkdir()
        store = TrustedRepoStore(tmp_path / "t.json")
        store.trust("github.com/trusted/repo")
        skill = make_skill("proj-skill", "project")
        gate = TrustGate(store=store, interactive=False)
        with patch("subprocess.run") as m:
            m.return_value = MagicMock(returncode=0, stdout="git@github.com:trusted/repo.git\n")
            result = gate.filter_and_gate([skill], project_dir=tmp_path)
        assert any(s.name == "proj-skill" for s in result)

    def test_untrusted_headless_skips_and_logs(self, tmp_path, caplog):
        """In non-interactive mode, untrusted project skills are skipped."""
        import logging

        (tmp_path / ".git").mkdir()
        store = TrustedRepoStore(tmp_path / "t.json")
        skill = make_skill("evil-skill", "project")
        gate = TrustGate(store=store, interactive=False)
        with patch("subprocess.run") as m:
            m.return_value = MagicMock(
                returncode=0, stdout="https://github.com/stranger/evil.git\n"
            )
            with caplog.at_level(logging.WARNING):
                result = gate.filter_and_gate([skill], project_dir=tmp_path)
        assert not any(s.name == "evil-skill" for s in result)
        assert "untrusted" in caplog.text.lower() or "skipping" in caplog.text.lower()

    def test_interactive_consent_session_only(self, tmp_path):
        """Option 1 (session only) includes skills without writing to store."""
        (tmp_path / ".git").mkdir()
        store = TrustedRepoStore(tmp_path / "t.json")
        skill = make_skill("session-skill", "project")
        outputs = []
        gate = TrustGate(
            store=store,
            interactive=True,
            print_fn=outputs.append,
            input_fn=lambda _: "1",  # trust this session
        )
        with (
            patch("sys.stdin.isatty", return_value=True),
            patch("sys.stdout.isatty", return_value=True),
            patch("subprocess.run") as m,
        ):
            m.return_value = MagicMock(
                returncode=0, stdout="https://github.com/stranger/repo.git\n"
            )
            result = gate.filter_and_gate([skill], project_dir=tmp_path)
        assert any(s.name == "session-skill" for s in result)
        # Must NOT persist to trusted store
        assert not store.is_trusted("github.com/stranger/repo")

    def test_interactive_consent_permanent(self, tmp_path):
        """Option 2 (permanent) includes skills and persists to trusted store."""
        (tmp_path / ".git").mkdir()
        store = TrustedRepoStore(tmp_path / "t.json")
        skill = make_skill("perm-skill", "project")
        gate = TrustGate(
            store=store,
            interactive=True,
            print_fn=lambda _: None,
            input_fn=lambda _: "2",  # trust permanently
        )
        with (
            patch("sys.stdin.isatty", return_value=True),
            patch("sys.stdout.isatty", return_value=True),
            patch("subprocess.run") as m,
        ):
            m.return_value = MagicMock(
                returncode=0, stdout="https://github.com/stranger/repo.git\n"
            )
            result = gate.filter_and_gate([skill], project_dir=tmp_path)
        assert any(s.name == "perm-skill" for s in result)
        assert store.is_trusted("github.com/stranger/repo")

    def test_interactive_consent_deny(self, tmp_path):
        """Option 3 (deny) excludes project skills."""
        (tmp_path / ".git").mkdir()
        store = TrustedRepoStore(tmp_path / "t.json")
        skill = make_skill("bad-skill", "project")
        gate = TrustGate(
            store=store,
            interactive=True,
            print_fn=lambda _: None,
            input_fn=lambda _: "3",  # deny
        )
        with (
            patch("sys.stdin.isatty", return_value=True),
            patch("sys.stdout.isatty", return_value=True),
            patch("subprocess.run") as m,
        ):
            m.return_value = MagicMock(
                returncode=0, stdout="https://github.com/stranger/repo.git\n"
            )
            result = gate.filter_and_gate([skill], project_dir=tmp_path)
        assert not any(s.name == "bad-skill" for s in result)

    def test_env_var_override_trusts_all(self, tmp_path, monkeypatch):
        """HIVE_TRUST_PROJECT_SKILLS=1 bypasses gating entirely."""
        monkeypatch.setenv("HIVE_TRUST_PROJECT_SKILLS", "1")
        store = TrustedRepoStore(tmp_path / "t.json")
        skill = make_skill("env-skill", "project")
        gate = TrustGate(store=store, interactive=False)
        result = gate.filter_and_gate([skill], project_dir=tmp_path)
        assert any(s.name == "env-skill" for s in result)

    def test_keyboard_interrupt_treated_as_deny(self, tmp_path):
        """Ctrl-C during consent prompt should deny cleanly."""
        (tmp_path / ".git").mkdir()
        store = TrustedRepoStore(tmp_path / "t.json")
        skill = make_skill("interrupted-skill", "project")
        gate = TrustGate(
            store=store,
            interactive=True,
            print_fn=lambda _: None,
            input_fn=lambda _: (_ for _ in ()).throw(KeyboardInterrupt()),
        )
        with (
            patch("sys.stdin.isatty", return_value=True),
            patch("sys.stdout.isatty", return_value=True),
            patch("subprocess.run") as m,
        ):
            m.return_value = MagicMock(
                returncode=0, stdout="https://github.com/stranger/repo.git\n"
            )
            result = gate.filter_and_gate([skill], project_dir=tmp_path)
        assert not any(s.name == "interrupted-skill" for s in result)

    def test_security_notice_shown_once(self, tmp_path, monkeypatch):
        """Security notice (NFR-5) should be shown the first time only."""
        # Use a temp sentinel path
        sentinel = tmp_path / ".skill_trust_notice_shown"
        monkeypatch.setattr("framework.skills.trust._NOTICE_SENTINEL_PATH", sentinel)
        assert not sentinel.exists()

        (tmp_path / ".git").mkdir()
        store = TrustedRepoStore(tmp_path / "t.json")
        skill = make_skill("notice-skill", "project")
        output_lines: list[str] = []
        gate = TrustGate(
            store=store,
            interactive=True,
            print_fn=output_lines.append,
            input_fn=lambda _: "3",
        )
        with (
            patch("sys.stdin.isatty", return_value=True),
            patch("sys.stdout.isatty", return_value=True),
            patch("subprocess.run") as m,
        ):
            m.return_value = MagicMock(
                returncode=0, stdout="https://github.com/stranger/repo.git\n"
            )
            gate.filter_and_gate([skill], project_dir=tmp_path)

        assert sentinel.exists()
        assert any("Security notice" in line for line in output_lines)

        # Second run should NOT show the notice again
        output_lines.clear()
        skill2 = make_skill("notice-skill-2", "project")
        with (
            patch("sys.stdin.isatty", return_value=True),
            patch("sys.stdout.isatty", return_value=True),
            patch("subprocess.run") as m,
        ):
            m.return_value = MagicMock(
                returncode=0, stdout="https://github.com/stranger/repo.git\n"
            )
            gate.filter_and_gate([skill2], project_dir=tmp_path)

        assert not any("Security notice" in line for line in output_lines)

    def test_mixed_scopes_only_project_gated(self, tmp_path, monkeypatch):
        """Framework and user skills should pass through even if project skills are denied."""
        (tmp_path / ".git").mkdir()
        store = TrustedRepoStore(tmp_path / "t.json")
        fw_skill = make_skill("fw", "framework")
        user_skill = make_skill("usr", "user")
        proj_skill = make_skill("proj", "project")
        gate = TrustGate(
            store=store,
            interactive=True,
            print_fn=lambda _: None,
            input_fn=lambda _: "3",  # deny project skills
        )
        with (
            patch("sys.stdin.isatty", return_value=True),
            patch("sys.stdout.isatty", return_value=True),
            patch("subprocess.run") as m,
        ):
            m.return_value = MagicMock(
                returncode=0, stdout="https://github.com/stranger/repo.git\n"
            )
            result = gate.filter_and_gate([fw_skill, user_skill, proj_skill], project_dir=tmp_path)
        names = {s.name for s in result}
        assert "fw" in names
        assert "usr" in names
        assert "proj" not in names
@@ -8,6 +8,7 @@ could cause a json.JSONDecodeError and crash execution.

import textwrap
from pathlib import Path
from types import SimpleNamespace

from framework.runner.tool_registry import ToolRegistry

@@ -91,3 +92,117 @@ def test_discover_from_module_handles_empty_content(tmp_path):
    result = registered.executor({})
    assert isinstance(result, dict)
    assert result == {}


class _RegistryFakeClient:
    def __init__(self, config):
        self.config = config
        self.connect_calls = 0
        self.disconnect_calls = 0

    def connect(self) -> None:
        self.connect_calls += 1

    def disconnect(self) -> None:
        self.disconnect_calls += 1

    def list_tools(self):
        return [
            SimpleNamespace(
                name="pooled_tool",
                description="Tool from MCP",
                input_schema={"type": "object", "properties": {}, "required": []},
            )
        ]

    def call_tool(self, tool_name, arguments):
        return [{"text": f"{tool_name}:{arguments}"}]


def test_register_mcp_server_uses_connection_manager_when_enabled(monkeypatch):
    registry = ToolRegistry()
    client = _RegistryFakeClient(SimpleNamespace(name="shared"))
    manager_calls: list[tuple[str, str]] = []

    class FakeManager:
        def acquire(self, config):
            manager_calls.append(("acquire", config.name))
            client.config = config
            return client

        def release(self, server_name: str) -> None:
            manager_calls.append(("release", server_name))

    monkeypatch.setattr(
        "framework.runner.mcp_connection_manager.MCPConnectionManager.get_instance",
        lambda: FakeManager(),
    )

    count = registry.register_mcp_server(
        {"name": "shared", "transport": "stdio", "command": "echo"},
        use_connection_manager=True,
    )

    assert count == 1
    assert manager_calls == [("acquire", "shared")]

    registry.cleanup()

    assert manager_calls == [("acquire", "shared"), ("release", "shared")]
    assert client.disconnect_calls == 0


def test_register_mcp_server_defaults_to_connection_manager(monkeypatch):
    """Default behavior uses the connection manager (reuse enabled by default)."""
    registry = ToolRegistry()
    created_clients: list[_RegistryFakeClient] = []

    def fake_client_factory(config):
        client = _RegistryFakeClient(config)
        created_clients.append(client)
        return client

    class FakeManager:
        def acquire(self, config):
            return fake_client_factory(config)

        def release(self, server_name):
            pass

    monkeypatch.setattr(
        "framework.runner.mcp_connection_manager.MCPConnectionManager.get_instance",
        lambda: FakeManager(),
    )

    count = registry.register_mcp_server(
        {"name": "direct", "transport": "stdio", "command": "echo"},
    )

    assert count == 1
    assert len(created_clients) == 1


def test_register_mcp_server_direct_client_when_manager_disabled(monkeypatch):
    """When use_connection_manager=False, a direct MCPClient is created."""
    registry = ToolRegistry()
    created_clients: list[_RegistryFakeClient] = []

    def fake_client_factory(config):
        client = _RegistryFakeClient(config)
        created_clients.append(client)
        return client

    monkeypatch.setattr("framework.runner.mcp_client.MCPClient", fake_client_factory)

    count = registry.register_mcp_server(
        {"name": "direct", "transport": "stdio", "command": "echo"},
        use_connection_manager=False,
    )

    assert count == 1
    assert len(created_clients) == 1
    assert created_clients[0].connect_calls == 1

    registry.cleanup()

    assert created_clients[0].disconnect_calls == 1
@@ -157,7 +157,7 @@ All bounty types open in parallel. Contributors self-select. Daily progress upda
PR merged with bounty:* label
  → GitHub Action runs bounty-tracker.ts
  → Calculates points from label
  → Resolves GitHub → Discord ID via contributors.yml
  → Resolves GitHub → Discord ID via MongoDB (hive.contributors)
  → Pushes XP to Lurkr API
  → Posts notification to #integrations-announcements
```

@@ -166,7 +166,7 @@ See the [Setup Guide](setup-guide.md) for full configuration (Lurkr, webhooks, s

### Identity Linking

Contributors link GitHub ↔ Discord by opening a [Link Discord Account](https://github.com/aden-hive/hive/issues/new?template=link-discord.yml) issue. A GitHub Action auto-adds them to `contributors.yml` and closes the issue.
Contributors link GitHub ↔ Discord by running `/link-github` in Discord. The bot verifies ownership via a public gist, then stores the mapping in MongoDB.

Without this link, bounties are still tracked but Lurkr can't push XP to your Discord account.

@@ -181,7 +181,7 @@ Without this link, bounties are still tracked but Lurkr can't push XP to your Di
| Agent Builder role | Lurkr bot | Auto-assigned at level 5 |
| OSS Contributor role | Lurkr bot | Auto-assigned at level 15 |
| Core Contributor role | Maintainer | Manual (involves money) |
| Identity linking | contributors.yml | PR-based, reviewed by maintainers |
| Identity linking | Discord bot → MongoDB | `/link-github` command with gist verification |

## Guides

@@ -203,4 +203,4 @@ Without this link, bounties are still tracked but Lurkr can't push XP to your Di
- `.github/workflows/weekly-leaderboard.yml` — Monday leaderboard post
- `scripts/bounty-tracker.ts` — Point calculation, Lurkr API, Discord formatting
- `scripts/setup-bounty-labels.sh` — One-time label setup
- `contributors.yml` — GitHub ↔ Discord identity mapping
- MongoDB `hive.contributors` — GitHub ↔ Discord identity mapping (managed by Discord bot)
@@ -6,9 +6,7 @@ Earn XP, Discord roles, and eventually real money by contributing to the Aden ag

### 1. Link your GitHub and Discord

Open a [Link Discord Account](https://github.com/aden-hive/hive/issues/new?template=link-discord.yml) issue — just paste your Discord ID and submit. A GitHub Action will automatically add you to `contributors.yml` and close the issue.

To find your Discord ID: Discord Settings > Advanced > Enable **Developer Mode**, then right-click your name > **Copy User ID**.
Run `/link-github your-github-username` in Discord. The bot will give you a verification code — create a public gist with that code, then run `/verify`. Done.

Without this link, Lurkr can't push XP to your Discord account.

@@ -154,7 +152,7 @@ A: Yes. Most services have free tiers. The bounty issue links to where you get t
A: Contribute consistently across different bounty types for 4+ weeks. Maintainers will nominate you.

**Q: What if I haven't linked my Discord yet?**
A: You'll still get credit in GitHub, but no Lurkr XP or Discord roles. Add yourself to `contributors.yml`.
A: You'll still get credit in GitHub, but no Lurkr XP or Discord roles. Run `/link-github` in Discord.

## Quick Reference

@@ -104,6 +104,8 @@ Repo Settings > Secrets and variables > Actions:
| `DISCORD_BOUNTY_WEBHOOK_URL` | Webhook URL from Step 5 |
| `LURKR_API_KEY` | Lurkr API key from Step 4f |
| `LURKR_GUILD_ID` | Your Discord server ID\* |
| `BOT_API_URL` | Discord bot API URL |
| `BOT_API_KEY` | Discord bot API key |

\*Enable Developer Mode in Discord, right-click server name > Copy Server ID.

@@ -146,12 +148,12 @@ powerbi, redis
- [ ] All 3 GitHub secrets added
- [ ] Both workflows enabled (`bounty-completed.yml`, `weekly-leaderboard.yml`)
- [ ] Test PR + merge triggers Discord notification
- [ ] `contributors.yml` exists at repo root
- [ ] MongoDB `hive.contributors` collection accessible

## Troubleshooting

**No Discord message:** Check `DISCORD_BOUNTY_WEBHOOK_URL` secret and Action logs.

**Lurkr XP not awarded:** Confirm API key is Read/Write, contributor is in `contributors.yml`, check Action logs for `Lurkr XP push failed`.
**Lurkr XP not awarded:** Confirm API key is Read/Write, contributor has run `/link-github` in Discord, check Action logs for `Lurkr XP push failed`.

**Role not assigned:** Verify role rewards in the Lurkr dashboard or via `/config set`. Lurkr's role must be above the roles it assigns in server hierarchy.

@@ -0,0 +1,290 @@
# Agent Skills User Guide

This guide covers how to use, create, and manage Agent Skills in the Hive framework. Agent Skills follow the open [Agent Skills standard](https://agentskills.io) — skills written for Claude Code, Cursor, or other compatible agents work in Hive unchanged.

## What are skills?

Skills are folders containing a `SKILL.md` file that teaches an agent how to perform a specific task. They can also bundle scripts, templates, and reference materials. Skills are loaded on demand — the agent sees a lightweight catalog at startup and pulls in full instructions only when relevant.

## Quick start

### Install a skill

Drop a skill folder into one of the discovery directories:

```bash
# Project-level (shared with the repo)
mkdir -p .hive/skills/my-skill
cat > .hive/skills/my-skill/SKILL.md << 'EOF'
---
name: my-skill
description: Does X when the user asks about Y.
---

# My Skill

Step-by-step instructions for the agent...
EOF
```

The agent will discover it automatically on the next session.

### List discovered skills

```bash
hive skill list
```

Output groups skills by scope:

```
PROJECT SKILLS
────────────────────────────────────
  • my-skill
      Does X when the user asks about Y.
      /home/user/project/.hive/skills/my-skill/SKILL.md

USER SKILLS
────────────────────────────────────
  • deep-research
      Multi-step web research with source verification.
      /home/user/.hive/skills/deep-research/SKILL.md
```

## Where to put skills

Hive scans five skill locations at startup (four directories plus the built-in framework set), in this precedence order:

| Scope | Path | Use case |
|-------|------|----------|
| Project (Hive) | `<project>/.hive/skills/` | Skills specific to this repo |
| Project (cross-client) | `<project>/.agents/skills/` | Skills shared across Claude Code, Cursor, etc. |
| User (Hive) | `~/.hive/skills/` | Personal skills available in all projects |
| User (cross-client) | `~/.agents/skills/` | Personal cross-client skills |
| Framework | *(built-in)* | Default operational skills shipped with Hive |

**Precedence**: If two skills share the same name, the higher-precedence location wins. A project-level `code-review` skill overrides a user-level one with the same name.
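
To make the precedence rule concrete, here is a minimal sketch of name-based resolution. The scan list and function name are illustrative assumptions, not Hive's actual discovery code:

```python
from pathlib import Path

# Hypothetical scan order, highest precedence first.
SCAN_ORDER = [
    Path(".hive/skills"),
    Path(".agents/skills"),
    Path.home() / ".hive" / "skills",
    Path.home() / ".agents" / "skills",
]

def discover(scan_order=SCAN_ORDER) -> dict[str, Path]:
    """First location to claim a skill name wins."""
    skills: dict[str, Path] = {}
    for root in scan_order:
        if not root.is_dir():
            continue
        for manifest in sorted(root.glob("*/SKILL.md")):
            # setdefault keeps the earlier, higher-precedence entry
            skills.setdefault(manifest.parent.name, manifest)
    return skills
```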

**Cross-client paths**: The `.agents/skills/` directories are a convention shared across compatible agents. A skill installed at `~/.agents/skills/pdf-processing/` is visible to Hive, Claude Code, Cursor, and other compatible tools simultaneously.

## Creating a skill

### Directory structure

```
my-skill/
├── SKILL.md            # Required — metadata + instructions
├── scripts/            # Optional — executable code
│   └── run.py
├── references/         # Optional — supplementary docs
│   └── api-reference.md
└── assets/             # Optional — templates, data files
    └── template.json
```

### SKILL.md format

Every skill needs a `SKILL.md` with YAML frontmatter and a markdown body:

```markdown
---
name: my-skill
description: Extract and summarize PDF documents. Use when the user mentions PDFs or document extraction.
---

# PDF Processing

## When to use
Use this skill when the user needs to extract text from PDFs or merge documents.

## Steps
1. Check if pdfplumber is available...
2. Extract text using...

## Edge cases
- Scanned PDFs need OCR first...
```

### Frontmatter fields

| Field | Required | Description |
|-------|----------|-------------|
| `name` | Yes | Lowercase letters, numbers, hyphens. Must match the parent directory name. Max 64 chars. |
| `description` | Yes | What the skill does and when to use it. Max 1024 chars. Include keywords that help the agent match tasks. |
| `license` | No | License name or reference to a bundled LICENSE file. |
| `compatibility` | No | Environment requirements (e.g., "Requires git, docker"). |
| `metadata` | No | Arbitrary key-value pairs (author, version, etc.). |
| `allowed-tools` | No | Space-delimited list of pre-approved tools. |
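
The `name` and length constraints above are mechanical enough to check before shipping a skill. A minimal validation sketch — the helper and its error messages are hypothetical, not Hive's actual validator:

```python
import re

NAME_RE = re.compile(r"[a-z0-9]+(-[a-z0-9]+)*")

def validate_frontmatter(fm: dict, dir_name: str) -> list[str]:
    """Return a list of problems; an empty list means the frontmatter passes."""
    errors = []
    name = fm.get("name", "")
    if not NAME_RE.fullmatch(name) or len(name) > 64:
        errors.append("name: lowercase letters, numbers, hyphens; max 64 chars")
    if name != dir_name:
        errors.append(f"name {name!r} must match directory {dir_name!r}")
    description = fm.get("description", "")
    if not description or len(description) > 1024:
        errors.append("description: required, max 1024 chars")
    return errors
```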
### Writing good descriptions

The description is critical — it's what the agent uses to decide whether to activate a skill. Be specific:

```yaml
# Good — tells the agent what and when
description: Extract text and tables from PDF files, fill PDF forms, and merge multiple PDFs. Use when working with PDF documents or when the user mentions PDFs, forms, or document extraction.

# Bad — too vague for the agent to match
description: Helps with PDFs.
```

### Writing good instructions

The markdown body is loaded into the agent's context when the skill is activated. Tips:

- **Be procedural**: Step-by-step instructions work better than abstract descriptions.
- **Keep it focused**: Stay under 500 lines / 5000 tokens. Move detailed reference material to `references/`.
- **Use relative paths**: Reference bundled files with relative paths (`scripts/run.py`, `references/guide.md`).
- **Include examples**: Show sample inputs and expected outputs.
- **Cover edge cases**: Tell the agent what to do when things go wrong.

## How skills are activated

Skills use **progressive disclosure** — three tiers that keep context usage efficient:

### Tier 1: Catalog (always loaded)

At session start, the agent sees a compact catalog of all available skills (name + description only, ~50-100 tokens each). This is how it knows what skills exist.

### Tier 2: Instructions (on demand)

When the agent determines a skill is relevant to the current task, it reads the full `SKILL.md` body into context. This happens automatically — the agent matches the task against skill descriptions and activates the best fit.

### Tier 3: Resources (on demand)

When skill instructions reference supporting files (`scripts/extract.py`, `references/api-docs.md`), the agent reads those individually as needed.

### Pre-activated skills

Some agents are configured to load specific skills at session start (skipping the catalog phase). This is set in the agent's configuration:

```python
# In agent definition
skills = ["code-review", "deep-research"]
```

Pre-activated skills have their full instructions loaded from the start, without waiting for the agent to decide they're relevant.

## Trust and security

### Why trust gating exists

Project-level skills come from the repository being worked on. If you clone an untrusted repo that contains a `.hive/skills/` directory, those skills could inject instructions into the agent's system prompt. Trust gating prevents this.

**User-level and framework skills are always trusted.** Only project-scope skills go through trust gating.

### What happens with untrusted project skills

When Hive encounters project-level skills from a repo you haven't trusted before, it shows a consent prompt:

```
============================================================
  SKILL TRUST REQUIRED
============================================================

The project at /home/user/new-project wants to load 2 skill(s)
that will inject instructions into the agent's system prompt.
Source: github.com/org/new-project

Skills requesting access:
  • deploy-pipeline
      "Automated deployment workflow for this project."
      /home/user/new-project/.hive/skills/deploy-pipeline/SKILL.md
  • code-standards
      "Project-specific coding standards and review checklist."
      /home/user/new-project/.hive/skills/code-standards/SKILL.md

Options:
  1) Trust this session only
  2) Trust permanently — remember for future runs
  3) Deny — skip all project-scope skills from this repo
────────────────────────────────────────────────────────────
Select option (1-3):
```

### Trust a repo via CLI

To trust a repo permanently without the interactive prompt:

```bash
hive skill trust /path/to/project
```

This stores the trust decision in `~/.hive/trusted_repos.json`, keyed by the normalized git remote URL (e.g., `github.com/org/repo`).
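
For intuition, normalization along these lines (a sketch only — Hive's exact rules may differ) maps HTTPS and SSH remotes for the same repo to one key, so a repo trusted once stays trusted regardless of how it was cloned:

```python
import re

def normalize_remote(url: str) -> str:
    """Reduce a git remote URL to host/org/repo (illustrative, not Hive's code)."""
    url = url.strip()
    # SSH form: git@github.com:org/repo.git -> github.com/org/repo.git
    ssh = re.match(r"^[\w.-]+@([\w.-]+):(.+)$", url)
    if ssh:
        url = f"{ssh.group(1)}/{ssh.group(2)}"
    # Drop the scheme: https://github.com/org/repo.git -> github.com/org/repo.git
    url = re.sub(r"^[a-z+]+://", "", url)
    if url.endswith(".git"):
        url = url[:-4]
    return url.lower()

assert normalize_remote("git@github.com:Org/Repo.git") == "github.com/org/repo"
assert normalize_remote("https://github.com/org/repo.git") == "github.com/org/repo"
```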

### Automatic trust

Some repos are trusted automatically (a sketch of the combined check follows the list):

- **No git repo**: Directories without `.git/` are always trusted.
- **No remote**: Local-only git repos (no `origin` remote) are always trusted.
- **Localhost remotes**: Repos with `localhost`/`127.0.0.1` remotes are always trusted.
- **Own-remote patterns**: Repos matching patterns in `~/.hive/own_remotes` or the `HIVE_OWN_REMOTES` env var are always trusted.
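
Condensing those rules into one predicate (names are hypothetical; this sketches the logic rather than Hive's implementation):

```python
from fnmatch import fnmatch

def is_auto_trusted(remote: str | None, own_patterns: list[str]) -> bool:
    """remote is the normalized URL, or None when there is no repo/remote."""
    if remote is None:  # no .git/ directory, or no 'origin' remote
        return True
    if remote.startswith(("localhost", "127.0.0.1")):
        return True
    return any(fnmatch(remote, pattern) for pattern in own_patterns)
```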

### Configure own-remote patterns

If you trust all repos from your organization:

```bash
# Via file (one pattern per line)
echo "github.com/my-org/*" >> ~/.hive/own_remotes
echo "gitlab.com/my-team/*" >> ~/.hive/own_remotes

# Via environment variable (comma-separated)
export HIVE_OWN_REMOTES="github.com/my-org/*,github.com/my-corp/*"
```

### CI / headless environments

In non-interactive environments, untrusted project skills are silently skipped. To trust them explicitly:

```bash
export HIVE_TRUST_PROJECT_SKILLS=1
hive run my-agent
```

## Default skills

Hive ships with six built-in operational skills that provide runtime resilience. These are always loaded (unless disabled) and appear as "Operational Protocols" in the agent's system prompt.

| Skill | Purpose |
|-------|---------|
| `hive.note-taking` | Structured working notes in shared memory |
| `hive.batch-ledger` | Track per-item status in batch operations |
| `hive.context-preservation` | Save context before context window pruning |
| `hive.quality-monitor` | Self-assess output quality periodically |
| `hive.error-recovery` | Structured error classification and recovery |
| `hive.task-decomposition` | Break complex tasks into subtasks |

### Disable default skills

In your agent configuration:

```python
# Disable a specific default skill
default_skills = {
    "hive.quality-monitor": {"enabled": False},
}

# Disable all default skills
default_skills = {
    "_all": {"enabled": False},
}
```

## Environment variables

| Variable | Description |
|----------|-------------|
| `HIVE_TRUST_PROJECT_SKILLS=1` | Bypass trust gating for all project-level skills (CI override) |
| `HIVE_OWN_REMOTES` | Comma-separated glob patterns for auto-trusted remotes (e.g., `github.com/myorg/*`) |

## Compatibility with other agents

Skills written for any Agent Skills-compatible agent work in Hive:

- Place them in `.agents/skills/` (cross-client) or `.hive/skills/` (Hive-specific).
- The `SKILL.md` format is identical across Claude Code, Cursor, Gemini CLI, and others.
- Skills installed at `~/.agents/skills/` are visible to all compatible agents on your machine.

See the [Agent Skills specification](https://agentskills.io/specification) for the full format reference.
@@ -1880,6 +1880,9 @@ if ($SelectedProviderId) {
        Write-Host " -> " -NoNewline
        Write-Color -Text $SelectedModel -Color DarkGray
    }
    Write-Color -Text " To use a different model for worker agents, run:" -Color DarkGray
    Write-Host " " -NoNewline
    Write-Color -Text ".\scripts\setup_worker_model.ps1" -Color Cyan
    Write-Host ""
}

@@ -1905,69 +1908,20 @@ if ($CodexAvailable) {
    Write-Host ""
}

# Setup-only mode: show manual instructions
# Final instructions and auto-launch
Write-Host "API keys saved as User environment variables. New terminals pick them up automatically." -ForegroundColor DarkGray
Write-Host "Launch anytime with " -NoNewline -ForegroundColor DarkGray
Write-Color -Text "hive open" -Color Cyan -NoNewline
Write-Host ". Run .\quickstart.ps1 again to reconfigure." -ForegroundColor DarkGray
Write-Host ""

if ($FrontendBuilt) {
    Write-Color -Text "═══════════════════════════════════════════════════════" -Color Yellow
    Write-Host ""
    Write-Color -Text " IMPORTANT: Restart your terminal now!" -Color Yellow
    Write-Host ""
    Write-Color -Text "═══════════════════════════════════════════════════════" -Color Yellow
    Write-Host ""
    Write-Host 'Environment variables (uv, API keys) are now configured, but you need to'
    Write-Host 'restart your terminal for them to take effect in new sessions.'
    Write-Host ""

    Write-Color -Text "Run an Agent:" -Color White
    Write-Host ""
    Write-Host " Quickstart only sets things up. Launch the dashboard when you're ready:"
    Write-Color -Text " hive open" -Color Cyan
    Write-Host ""

    if ($SelectedProviderId -or $credKey) {
        Write-Color -Text "Note:" -Color White
        Write-Host "- uv has been added to your User PATH"
        if ($SelectedProviderId -and $SelectedEnvVar) {
            Write-Host "- $SelectedEnvVar is set for LLM access"
        }
        if ($credKey) {
            Write-Host "- HIVE_CREDENTIAL_KEY is set for credential encryption"
        }
        Write-Host "- All variables will persist across reboots"
        Write-Host ""
    }

    Write-Color -Text 'Run .\quickstart.ps1 again to reconfigure.' -Color DarkGray
    Write-Color -Text "Launching dashboard..." -Color White
    Write-Host ""
    & hive open
} else {
    Write-Color -Text "═══════════════════════════════════════════════════════" -Color Yellow
    Write-Host ""
    Write-Color -Text " IMPORTANT: Restart your terminal now!" -Color Yellow
    Write-Host ""
    Write-Color -Text "═══════════════════════════════════════════════════════" -Color Yellow
    Write-Host ""
    Write-Host 'Environment variables (uv, API keys) are now configured, but you need to'
    Write-Host 'restart your terminal for them to take effect in new sessions.'
    Write-Host ""

    Write-Color -Text "Run an Agent:" -Color White
    Write-Host ""
    Write-Host " Frontend build was skipped or failed. Once the dashboard is available, launch it with:"
    Write-Color -Text "Frontend build was skipped or failed." -Color Yellow -NoNewline
    Write-Host " Launch manually when ready:"
    Write-Color -Text " hive open" -Color Cyan
    Write-Host ""

    if ($SelectedProviderId -or $credKey) {
        Write-Color -Text "Note:" -Color White
        Write-Host "- uv has been added to your User PATH"
        if ($SelectedProviderId -and $SelectedEnvVar) {
            Write-Host "- $SelectedEnvVar is set for LLM access"
        }
        if ($credKey) {
            Write-Host "- HIVE_CREDENTIAL_KEY is set for credential encryption"
        }
        Write-Host "- All variables will persist across reboots"
        Write-Host ""
    }

    Write-Color -Text 'Run .\quickstart.ps1 again to reconfigure.' -Color DarkGray
    Write-Host ""
}

@@ -847,7 +847,7 @@ prompt_model_selection() {
}

# Function to save configuration
# Args: provider_id env_var model max_tokens max_context_tokens [use_claude_code_sub] [api_base] [use_codex_sub]
# Args: provider_id env_var model max_tokens max_context_tokens [use_claude_code_sub] [api_base] [use_codex_sub] [use_antigravity_sub]
save_configuration() {
    local provider_id="$1"
    local env_var="$2"
@@ -857,6 +857,7 @@ save_configuration() {
    local use_claude_code_sub="${6:-}"
    local api_base="${7:-}"
    local use_codex_sub="${8:-}"
    local use_antigravity_sub="${9:-}"

    # Fallbacks if not provided
    if [ -z "$model" ]; then
@@ -878,6 +879,7 @@ save_configuration() {
        "$use_claude_code_sub" \
        "$api_base" \
        "$use_codex_sub" \
        "$use_antigravity_sub" \
        "$(date -u +"%Y-%m-%dT%H:%M:%S+00:00")" 2>/dev/null <<'PY'
import json
import sys
@@ -892,8 +894,9 @@ from pathlib import Path
    use_claude_code_sub,
    api_base,
    use_codex_sub,
    use_antigravity_sub,
    created_at,
) = sys.argv[1:10]
) = sys.argv[1:11]

cfg_path = Path.home() / ".hive" / "configuration.json"
cfg_path.parent.mkdir(parents=True, exist_ok=True)
@@ -925,6 +928,23 @@ if use_codex_sub == "true":
else:
    config["llm"].pop("use_codex_subscription", None)

if use_antigravity_sub == "true":
    config["llm"]["use_antigravity_subscription"] = True
    config["llm"].pop("api_key_env_var", None)
    # Store the Antigravity OAuth client secret so token refresh works
    # without hardcoding it in source code (read at runtime via config.py).
    import os as _os
    _secret = _os.environ.get("ANTIGRAVITY_CLIENT_SECRET") or ""
    if _secret:
        config["llm"]["antigravity_client_secret"] = _secret
    _client_id = _os.environ.get("ANTIGRAVITY_CLIENT_ID") or ""
    if _client_id:
        config["llm"]["antigravity_client_id"] = _client_id
else:
    config["llm"].pop("use_antigravity_subscription", None)
    config["llm"].pop("antigravity_client_secret", None)
    config["llm"].pop("antigravity_client_id", None)

if api_base:
    config["llm"]["api_base"] = api_base
else:
@@ -993,6 +1013,17 @@ if [ -n "${HIVE_API_KEY:-}" ]; then
    HIVE_CRED_DETECTED=true
fi

ANTIGRAVITY_CRED_DETECTED=false
# Check native Antigravity IDE (macOS/Linux) SQLite state DB first
if [ -f "$HOME/Library/Application Support/Antigravity/User/globalStorage/state.vscdb" ]; then
    ANTIGRAVITY_CRED_DETECTED=true
elif [ -f "$HOME/.config/Antigravity/User/globalStorage/state.vscdb" ]; then
    ANTIGRAVITY_CRED_DETECTED=true
# Native OAuth credentials
elif [ -f "$HOME/.hive/antigravity-accounts.json" ]; then
    ANTIGRAVITY_CRED_DETECTED=true
fi

# Detect API key providers
if [ "$USE_ASSOC_ARRAYS" = true ]; then
    for env_var in "${!PROVIDER_NAMES[@]}"; do
@@ -1035,6 +1066,8 @@ try:
    sub = "codex"
elif llm.get("use_kimi_code_subscription"):
    sub = "kimi_code"
elif llm.get("use_antigravity_subscription"):
    sub = "antigravity"
elif llm.get("provider", "") == "minimax" or "api.minimax.io" in llm.get("api_base", ""):
    sub = "minimax_code"
elif llm.get("provider", "") == "hive" or "adenhq.com" in llm.get("api_base", ""):
@@ -1058,6 +1091,7 @@ if [ -n "$PREV_SUB_MODE" ] || [ -n "$PREV_PROVIDER" ]; then
        codex) [ "$CODEX_CRED_DETECTED" = true ] && PREV_CRED_VALID=true ;;
        kimi_code) [ "$KIMI_CRED_DETECTED" = true ] && PREV_CRED_VALID=true ;;
        hive_llm) [ "$HIVE_CRED_DETECTED" = true ] && PREV_CRED_VALID=true ;;
        antigravity) [ "$ANTIGRAVITY_CRED_DETECTED" = true ] && PREV_CRED_VALID=true ;;
        *)
            # API key provider — check if the env var is set
            if [ -n "$PREV_ENV_VAR" ] && [ -n "${!PREV_ENV_VAR}" ]; then
@@ -1074,15 +1108,16 @@ if [ -n "$PREV_SUB_MODE" ] || [ -n "$PREV_PROVIDER" ]; then
        minimax_code) DEFAULT_CHOICE=4 ;;
        kimi_code) DEFAULT_CHOICE=5 ;;
        hive_llm) DEFAULT_CHOICE=6 ;;
        antigravity) DEFAULT_CHOICE=7 ;;
    esac
    if [ -z "$DEFAULT_CHOICE" ]; then
        case "$PREV_PROVIDER" in
            anthropic) DEFAULT_CHOICE=7 ;;
            openai) DEFAULT_CHOICE=8 ;;
            gemini) DEFAULT_CHOICE=9 ;;
            groq) DEFAULT_CHOICE=10 ;;
            cerebras) DEFAULT_CHOICE=11 ;;
            openrouter) DEFAULT_CHOICE=12 ;;
            anthropic) DEFAULT_CHOICE=8 ;;
            openai) DEFAULT_CHOICE=9 ;;
            gemini) DEFAULT_CHOICE=10 ;;
            groq) DEFAULT_CHOICE=11 ;;
            cerebras) DEFAULT_CHOICE=12 ;;
            openrouter) DEFAULT_CHOICE=13 ;;
            minimax) DEFAULT_CHOICE=4 ;;
            kimi) DEFAULT_CHOICE=5 ;;
            hive) DEFAULT_CHOICE=6 ;;
@@ -1138,14 +1173,21 @@ else
    echo -e " ${CYAN}6)${NC} Hive LLM ${DIM}(use your Hive API key)${NC}"
fi

# 7) Antigravity
if [ "$ANTIGRAVITY_CRED_DETECTED" = true ]; then
    echo -e " ${CYAN}7)${NC} Antigravity Subscription ${DIM}(use your Google/Gemini plan)${NC} ${GREEN}(credential detected)${NC}"
else
    echo -e " ${CYAN}7)${NC} Antigravity Subscription ${DIM}(use your Google/Gemini plan)${NC}"
fi

echo ""
echo -e " ${CYAN}${BOLD}API key providers:${NC}"

# 7-12) API key providers — show (credential detected) if key already set
# 8-13) API key providers — show (credential detected) if key already set
PROVIDER_MENU_ENVS=(ANTHROPIC_API_KEY OPENAI_API_KEY GEMINI_API_KEY GROQ_API_KEY CEREBRAS_API_KEY OPENROUTER_API_KEY)
PROVIDER_MENU_NAMES=("Anthropic (Claude) - Recommended" "OpenAI (GPT)" "Google Gemini - Free tier available" "Groq - Fast, free tier" "Cerebras - Fast, free tier" "OpenRouter - Bring any OpenRouter model")
for idx in "${!PROVIDER_MENU_ENVS[@]}"; do
    num=$((idx + 7))
    num=$((idx + 8))
    env_var="${PROVIDER_MENU_ENVS[$idx]}"
    if [ -n "${!env_var}" ]; then
        echo -e " ${CYAN}$num)${NC} ${PROVIDER_MENU_NAMES[$idx]} ${GREEN}(credential detected)${NC}"
@@ -1154,7 +1196,7 @@ for idx in "${!PROVIDER_MENU_ENVS[@]}"; do
    fi
done

SKIP_CHOICE=$((7 + ${#PROVIDER_MENU_ENVS[@]}))
SKIP_CHOICE=$((8 + ${#PROVIDER_MENU_ENVS[@]}))
echo -e " ${CYAN}$SKIP_CHOICE)${NC} Skip for now"
echo ""

@@ -1297,36 +1339,75 @@ case $choice in
        echo -e " ${DIM}Model: $SELECTED_MODEL | API: ${HIVE_LLM_ENDPOINT}${NC}"
        ;;
    7)
        # Antigravity Subscription
        if [ "$ANTIGRAVITY_CRED_DETECTED" = false ]; then
            echo ""
            echo -e "${CYAN} Setting up Antigravity authentication...${NC}"
            echo ""
            echo -e " ${YELLOW}A browser window will open for Google OAuth.${NC}"
            echo -e " Sign in with your Google account that has Antigravity access."
            echo ""

            # Run native OAuth flow
            if uv run python "$SCRIPT_DIR/core/antigravity_auth.py" auth account add; then
                # Re-detect credentials
                if [ -f "$HOME/.hive/antigravity-accounts.json" ]; then
                    ANTIGRAVITY_CRED_DETECTED=true
                fi
            fi

            if [ "$ANTIGRAVITY_CRED_DETECTED" = false ]; then
                echo ""
                echo -e "${RED} Authentication failed or was cancelled.${NC}"
                echo ""
                SELECTED_PROVIDER_ID=""
            fi
        fi

        if [ "$ANTIGRAVITY_CRED_DETECTED" = true ]; then
            SUBSCRIPTION_MODE="antigravity"
            SELECTED_PROVIDER_ID="openai"
            SELECTED_MODEL="gemini-3-flash"
            SELECTED_MAX_TOKENS=32768
            SELECTED_MAX_CONTEXT_TOKENS=1000000  # Gemini 3 Flash — 1M context window
            echo ""
echo -e "${YELLOW} ⚠ Using Antigravity can technically cause your account suspension. Please use at your own risk.${NC}"
            echo ""
            echo -e "${GREEN}⬢${NC} Using Antigravity subscription"
            echo -e " ${DIM}Model: gemini-3-flash | Direct OAuth (no proxy required)${NC}"
        fi
        ;;
    8)
        SELECTED_ENV_VAR="ANTHROPIC_API_KEY"
        SELECTED_PROVIDER_ID="anthropic"
        PROVIDER_NAME="Anthropic"
        SIGNUP_URL="https://console.anthropic.com/settings/keys"
        ;;
    8)
    9)
        SELECTED_ENV_VAR="OPENAI_API_KEY"
        SELECTED_PROVIDER_ID="openai"
        PROVIDER_NAME="OpenAI"
        SIGNUP_URL="https://platform.openai.com/api-keys"
        ;;
    9)
    10)
        SELECTED_ENV_VAR="GEMINI_API_KEY"
        SELECTED_PROVIDER_ID="gemini"
        PROVIDER_NAME="Google Gemini"
        SIGNUP_URL="https://aistudio.google.com/apikey"
        ;;
    10)
    11)
        SELECTED_ENV_VAR="GROQ_API_KEY"
        SELECTED_PROVIDER_ID="groq"
        PROVIDER_NAME="Groq"
        SIGNUP_URL="https://console.groq.com/keys"
        ;;
    11)
    12)
        SELECTED_ENV_VAR="CEREBRAS_API_KEY"
        SELECTED_PROVIDER_ID="cerebras"
        PROVIDER_NAME="Cerebras"
        SIGNUP_URL="https://cloud.cerebras.ai/"
        ;;
    12)
    13)
        SELECTED_ENV_VAR="OPENROUTER_API_KEY"
        SELECTED_PROVIDER_ID="openrouter"
        SELECTED_API_BASE="https://openrouter.ai/api/v1"
@@ -1491,6 +1572,8 @@ if [ -n "$SELECTED_PROVIDER_ID" ]; then
        save_configuration "$SELECTED_PROVIDER_ID" "" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "true" "" > /dev/null || SAVE_OK=false
    elif [ "$SUBSCRIPTION_MODE" = "codex" ]; then
        save_configuration "$SELECTED_PROVIDER_ID" "" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "" "true" > /dev/null || SAVE_OK=false
    elif [ "$SUBSCRIPTION_MODE" = "antigravity" ]; then
        save_configuration "$SELECTED_PROVIDER_ID" "" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "" "" "true" > /dev/null || SAVE_OK=false
    elif [ "$SUBSCRIPTION_MODE" = "zai_code" ]; then
        save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "https://api.z.ai/api/coding/paas/v4" > /dev/null || SAVE_OK=false
    elif [ "$SUBSCRIPTION_MODE" = "minimax_code" ]; then
@@ -1767,6 +1850,8 @@ if [ -n "$SELECTED_PROVIDER_ID" ]; then
    else
        echo -e " ${CYAN}$SELECTED_PROVIDER_ID${NC} → ${DIM}$SELECTED_MODEL${NC}"
    fi
    echo -e " ${DIM}To use a different model for worker agents, run:${NC}"
    echo -e " ${CYAN}./scripts/setup_worker_model.sh${NC}"
    echo ""
fi

@@ -1808,29 +1893,16 @@ if [ "$CODEX_AVAILABLE" = true ]; then
    echo ""
fi

echo -e "${YELLOW}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
echo -e "${BOLD}IMPORTANT: Load your new configuration${NC}"
echo -e "${YELLOW}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
echo ""
echo -e " Your API keys have been saved to ${CYAN}$SHELL_RC_FILE${NC}"
echo -e " To use them, either:"
echo ""
echo -e " ${GREEN}Option 1:${NC} Source your shell config now:"
echo -e " ${CYAN}source $SHELL_RC_FILE${NC}"
echo ""
echo -e " ${GREEN}Option 2:${NC} Open a new terminal window"
echo ""
echo -e "${YELLOW}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
echo -e "${DIM}API keys saved to ${CYAN}$SHELL_RC_FILE${NC}${DIM}. New terminals pick them up automatically.${NC}"
echo -e "${DIM}Launch anytime with ${CYAN}hive open${NC}${DIM}. Run ./quickstart.sh again to reconfigure.${NC}"
echo ""

echo -e "${BOLD}Run an Agent:${NC}"
echo ""
if [ "$FRONTEND_BUILT" = true ]; then
    echo -e " Quickstart only sets things up. Launch the dashboard when you're ready:"
    echo -e "${BOLD}Launching dashboard...${NC}"
    echo ""
    hive open
else
    echo -e " Frontend build was skipped or failed. Once the dashboard is available, launch it with:"
    echo -e "${YELLOW}Frontend build was skipped or failed.${NC} Launch manually when ready:"
    echo -e " ${CYAN}hive open${NC}"
    echo ""
fi
echo -e " ${CYAN}hive open${NC}"
echo ""
echo -e "${DIM}Run ./quickstart.sh again to reconfigure.${NC}"
echo ""

@@ -12,13 +12,12 @@
 * GITHUB_REPOSITORY_OWNER — e.g. "adenhq"
 * GITHUB_REPOSITORY_NAME — e.g. "hive"
 * DISCORD_WEBHOOK_URL — Discord webhook for #integrations-announcements
 * MONGODB_URI — MongoDB connection string (contributors collection)
 * LURKR_API_KEY — Lurkr Read/Write API key (for XP push)
 * LURKR_GUILD_ID — Discord server ID where Lurkr is installed
 * PR_NUMBER — (notify mode) The merged PR number
 */

import { readFileSync } from "fs";
import { join } from "path";

// ---------------------------------------------------------------------------
// Types
@@ -151,59 +150,38 @@ async function getMergedBountyPRs(
}

// ---------------------------------------------------------------------------
// Identity resolution
// Identity resolution (via bot API)
// ---------------------------------------------------------------------------

// Parse contributors.yml without a YAML dependency.
// The format is simple enough to parse with regex:
// contributors:
//   - github: jane-doe
//     discord: "123456789012345678"
//     name: Jane Doe
function parseContributorsYaml(raw: string): Contributor[] {
  const contributors: Contributor[] = [];
  let current: Partial<Contributor> | null = null;

  for (const line of raw.split("\n")) {
    const trimmed = line.trim();

    if (trimmed.startsWith("- github:")) {
      if (current?.github && current?.discord) {
        contributors.push(current as Contributor);
      }
      current = { github: trimmed.replace("- github:", "").trim() };
    } else if (trimmed.startsWith("discord:") && current) {
      current.discord = trimmed.replace("discord:", "").trim().replace(/^["']|["']$/g, "");
    } else if (trimmed.startsWith("name:") && current) {
      current.name = trimmed.replace("name:", "").trim();
    }
  }

  // Don't forget the last entry
  if (current?.github && current?.discord) {
    contributors.push(current as Contributor);
  }

  return contributors;
}

function loadContributors(): Map<string, Contributor> {
async function loadContributors(): Promise<Map<string, Contributor>> {
  const map = new Map<string, Contributor>();

  try {
    // Resolve path relative to the script location (scripts/ dir → repo root)
    const scriptDir = new URL(".", import.meta.url).pathname;
    const raw = readFileSync(
      join(scriptDir, "..", "contributors.yml"),
      "utf-8"
    );
    const entries = parseContributorsYaml(raw);
  const apiUrl = process.env.BOT_API_URL;
  if (!apiUrl) {
    console.warn("Warning: BOT_API_URL not set, contributor lookups disabled");
    return map;
  }

    for (const c of entries) {
      map.set(c.github.toLowerCase(), c);
  try {
    const headers: Record<string, string> = {};
    const apiKey = process.env.BOT_API_KEY;
    if (apiKey) {
      headers.Authorization = `Bearer ${apiKey}`;
    }
  } catch {
    console.warn("Warning: could not load contributors.yml");

    const res = await fetch(`${apiUrl}/api/contributors`, { headers });
    if (!res.ok) {
      throw new Error(`${res.status} ${res.statusText}`);
    }

    const docs = (await res.json()) as Contributor[];
    for (const doc of docs) {
      map.set(doc.github.toLowerCase(), doc);
    }

    console.log(`Loaded ${map.size} contributors from bot API`);
  } catch (err) {
    console.warn(`Warning: could not load contributors from bot API: ${err}`);
  }

  return map;
@@ -295,7 +273,7 @@ function formatBountyNotification(bounty: BountyResult): string {
  msg += `PR: ${bounty.pr.html_url}\n`;

  if (!bounty.discordId) {
    msg += `\n_\u{1F517} @${bounty.contributor}: link your Discord in \`contributors.yml\` to get pinged!_`;
    msg += `\n_\u{1F517} @${bounty.contributor}: use \`/link-github\` in Discord to get pinged!_`;
  }

  return msg;
@@ -464,7 +442,7 @@ async function main() {
    process.exit(1);
  }

  const contributors = loadContributors();
  const contributors = await loadContributors();

  if (mode === "notify") {
    // Single bounty notification

@@ -33,8 +33,8 @@ OPENROUTER_SEPARATOR_TRANSLATION = str.maketrans(
        "\u2212": "-",
        "\u2044": "/",
        "\u2215": "/",
        "\u29F8": "/",
        "\uFF0F": "/",
        "\u29f8": "/",
        "\uff0f": "/",
    }
)

@@ -66,9 +66,7 @@ def _sanitize_openrouter_model_id(value: str) -> str:
    """Sanitize pasted OpenRouter model IDs into a comparable slug."""
    normalized = unicodedata.normalize("NFKC", value or "")
    normalized = "".join(
        ch
        for ch in normalized
        if unicodedata.category(ch) not in {"Cc", "Cf"}
        ch for ch in normalized if unicodedata.category(ch) not in {"Cc", "Cf"}
    )
    normalized = normalized.translate(OPENROUTER_SEPARATOR_TRANSLATION)
    normalized = re.sub(r"\s+", "", normalized)
@@ -183,7 +181,10 @@ def check_openrouter(
        return {"valid": False, "message": "Invalid OpenRouter API key"}
    if r.status_code == 403:
        return {"valid": False, "message": "OpenRouter API key lacks permissions"}
    return {"valid": False, "message": f"OpenRouter API returned status {r.status_code}"}
    return {
        "valid": False,
        "message": f"OpenRouter API returned status {r.status_code}",
    }


def check_openrouter_model(

@@ -1,17 +1,21 @@
#!/usr/bin/env python3
"""Open a browser-based viewer for Hive LLM debug JSONL sessions.

Starts a local HTTP server and loads session data on demand (one at a time).

Usage:
    uv run --no-project scripts/llm_debug_log_visualizer.py
    uv run --no-project scripts/llm_debug_log_visualizer.py --no-open
    uv run --no-project scripts/llm_debug_log_visualizer.py --session <execution_id>
    uv run --no-project scripts/llm_debug_log_visualizer.py --port 8080
    uv run --no-project scripts/llm_debug_log_visualizer.py --output debug.html
"""

from __future__ import annotations

import argparse
import http.server
import json
import tempfile
import urllib.parse
import webbrowser
from collections import defaultdict
from dataclasses import dataclass
@@ -55,10 +59,21 @@ def _parse_args() -> argparse.Namespace:
        default=200,
        help="Maximum number of newest log files to scan.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=0,
        help="Port for the local server (0 = auto-pick a free port).",
    )
    parser.add_argument(
        "--no-open",
        action="store_true",
        help="Generate the HTML but do not open a browser.",
        help="Start the server but do not open a browser.",
    )
    parser.add_argument(
        "--include-tests",
        action="store_true",
        help="Show test/mock sessions (hidden by default).",
    )
    return parser.parse_args()

@@ -117,8 +132,29 @@ def _format_timestamp(raw: str) -> str:
    return raw


def _is_test_session(execution_id: str, records: list[dict[str, Any]]) -> bool:
    """Return True for sessions that look like test artifacts."""
    if execution_id.startswith("<MagicMock"):
        return True
    models = {
        str(r.get("token_counts", {}).get("model", ""))
        for r in records
        if isinstance(r.get("token_counts"), dict)
    }
    models.discard("")
    # Sessions that only used the mock LLM provider.
    if models and models <= {"mock"}:
        return True
    # Sessions with no real model at all (empty string or missing).
    if not models:
        return True
    return False


def _group_sessions(
    records: list[dict[str, Any]],
    *,
    include_tests: bool = False,
) -> tuple[list[SessionSummary], dict[str, list[dict[str, Any]]]]:
    by_session: dict[str, list[dict[str, Any]]] = defaultdict(list)
    for record in records:
@@ -126,6 +162,13 @@ def _group_sessions(
        if execution_id:
            by_session[execution_id].append(record)

    if not include_tests:
        by_session = {
            eid: recs
            for eid, recs in by_session.items()
            if not _is_test_session(eid, recs)
        }

    summaries: list[SessionSummary] = []
    for execution_id, session_records in by_session.items():
        session_records.sort(
@@ -174,7 +217,6 @@ def _group_sessions(

def _render_html(
    summaries: list[SessionSummary],
    sessions: dict[str, list[dict[str, Any]]],
    initial_session_id: str,
) -> str:
    summaries_data = [
@@ -193,16 +235,6 @@ def _render_html(
        for summary in summaries
    ]

    sessions_data = {
        execution_id: sorted(
            records,
            key=lambda record: (
                str(record.get("timestamp", "")),
                record.get("iteration", 0),
            ),
        )
        for execution_id, records in sessions.items()
    }
    initial = initial_session_id or (summaries[0].execution_id if summaries else "")
    return f"""<!DOCTYPE html>
<html lang="en">
@@ -579,10 +611,9 @@ def _render_html(
</div>

<script id="session-summaries" type="application/json">{json.dumps(summaries_data, ensure_ascii=False)}</script>
<script id="session-records" type="application/json">{json.dumps(sessions_data, ensure_ascii=False)}</script>
<script>
const summaries = JSON.parse(document.getElementById("session-summaries").textContent);
const recordsBySession = JSON.parse(document.getElementById("session-records").textContent);
const recordCache = {{}};
const initialSessionId = {json.dumps(initial, ensure_ascii=False)};

const sessionSearch = document.getElementById("sessionSearch");
@@ -746,10 +777,18 @@ def _render_html(
    `;
}}

function renderSession(sessionId) {{
async function fetchSession(sessionId) {{
    if (recordCache[sessionId]) return recordCache[sessionId];
    const resp = await fetch(`/api/session/${{encodeURIComponent(sessionId)}}`);
    if (!resp.ok) return [];
    const data = await resp.json();
    recordCache[sessionId] = data;
    return data;
}}

async function renderSession(sessionId) {{
    activeSessionId = sessionId;
    const summary = summaries.find((entry) => entry.execution_id === sessionId);
    const records = recordsBySession[sessionId] || [];

    renderSessionChooser();

@@ -773,6 +812,9 @@ def _render_html(
        renderMetaCard("Source file", summary.log_file),
    ].join("");

    turnsEl.innerHTML = '<div class="empty">Loading session\u2026</div>';
    const records = await fetchSession(sessionId);
    if (activeSessionId !== sessionId) return;
    turnsEl.innerHTML = records.length
        ? records.map((record) => renderTurn(record)).join("")
        : '<div class="empty">This session has no turn records.</div>';
@@ -804,7 +846,8 @@ def _render_html(
}});

const hashSession = decodeURIComponent(window.location.hash.replace(/^#/, ""));
const bootSession = recordsBySession[hashSession] ? hashSession : activeSessionId;
const knownIds = new Set(summaries.map((s) => s.execution_id));
const bootSession = knownIds.has(hashSession) ? hashSession : activeSessionId;
renderSessionChooser();
renderSession(bootSession);
</script>
@@ -813,28 +856,68 @@ def _render_html(
"""


def _write_report(html_report: str, output: Path | None) -> Path:
    if output is not None:
        output.parent.mkdir(parents=True, exist_ok=True)
        output.write_text(html_report, encoding="utf-8")
        return output
def _sort_records(records: list[dict[str, Any]]) -> list[dict[str, Any]]:
    return sorted(
        records,
        key=lambda r: (str(r.get("timestamp", "")), r.get("iteration", 0)),
    )

    with tempfile.NamedTemporaryFile(
        mode="w",
        encoding="utf-8",
        prefix="hive_llm_debug_",
        suffix=".html",
        delete=False,
        dir="/tmp",
    ) as handle:
        handle.write(html_report)
    return Path(handle.name)

def _run_server(
    html: str,
    sessions: dict[str, list[dict[str, Any]]],
    port: int,
    no_open: bool,
) -> None:
    html_bytes = html.encode("utf-8")

    class Handler(http.server.BaseHTTPRequestHandler):
        def do_GET(self) -> None:
            if self.path == "/":
                self._respond(200, "text/html; charset=utf-8", html_bytes)
            elif self.path.startswith("/api/session/"):
                sid = urllib.parse.unquote(self.path[len("/api/session/") :])
                records = sessions.get(sid)
                if records is None:
                    self._respond(404, "application/json", b"[]")
                else:
                    body = json.dumps(
                        _sort_records(records), ensure_ascii=False
                    ).encode("utf-8")
                    self._respond(200, "application/json", body)
            else:
                self.send_error(404)

        def _respond(self, code: int, content_type: str, body: bytes) -> None:
            self.send_response(code)
            self.send_header("Content-Type", content_type)
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

        def log_message(self, format: str, *args: object) -> None:
            pass  # silence per-request logs

    server = http.server.HTTPServer(("127.0.0.1", port), Handler)
    actual_port = server.server_address[1]
    url = f"http://127.0.0.1:{actual_port}"
    print(f"Serving at {url} (Ctrl+C to stop)")

    if not no_open:
        webbrowser.open(url)

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        print("\nStopped.")
    finally:
        server.server_close()


def main() -> int:
    args = _parse_args()
    records = _discover_records(args.logs_dir.expanduser(), args.limit_files)
    summaries, sessions = _group_sessions(records)
    summaries, sessions = _group_sessions(records, include_tests=args.include_tests)

    initial_session_id = args.session or (
        summaries[0].execution_id if summaries else ""
@@ -843,13 +926,15 @@ def main() -> int:
        print(f"session not found: {initial_session_id}")
        return 1

    html_report = _render_html(summaries, sessions, initial_session_id)
    output_path = _write_report(html_report, args.output)
    print(output_path)
    html_report = _render_html(summaries, initial_session_id)

    if not args.no_open:
        webbrowser.open(output_path.resolve().as_uri())
    if args.output:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        args.output.write_text(html_report, encoding="utf-8")
        print(args.output)
        return 0

    _run_server(html_report, sessions, args.port, args.no_open)
    return 0

@@ -0,0 +1,940 @@
|
||||
#Requires -Version 5.1
|
||||
<#
|
||||
.SYNOPSIS
|
||||
setup_worker_model.ps1 - Configure a separate LLM model for worker agents
|
||||
|
||||
.DESCRIPTION
|
||||
Worker agents can use a different (e.g. cheaper/faster) model than the
|
||||
queen agent. This script writes a "worker_llm" section to
|
||||
~/.hive/configuration.json. If no worker model is configured, workers
|
||||
fall back to the default (queen) model.
|
||||
|
||||
.NOTES
|
||||
Run from the project root: .\scripts\setup_worker_model.ps1
|
||||
#>

$ErrorActionPreference = "Continue"
$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Definition
$ProjectDir = Split-Path -Parent $ScriptDir
$UvHelperPath = Join-Path $ScriptDir "uv-discovery.ps1"
$HiveConfigDir = Join-Path $env:USERPROFILE ".hive"
$HiveConfigFile = Join-Path $HiveConfigDir "configuration.json"
$HiveLlmEndpoint = "https://api.adenhq.com"

. $UvHelperPath

# ============================================================
# Colors / helpers
# ============================================================

function Write-Color {
    param(
        [string]$Text,
        [ConsoleColor]$Color = [ConsoleColor]::White,
        [switch]$NoNewline
    )
    $prev = $Host.UI.RawUI.ForegroundColor
    $Host.UI.RawUI.ForegroundColor = $Color
    if ($NoNewline) { Write-Host $Text -NoNewline }
    else { Write-Host $Text }
    $Host.UI.RawUI.ForegroundColor = $prev
}

function Write-Ok {
    param([string]$Text)
    Write-Color -Text "$([char]0x2B22) $Text" -Color Green
}

function Write-Warn {
    param([string]$Text)
    Write-Color -Text "$([char]0x2B22) $Text" -Color Yellow
}

function Write-Fail {
    param([string]$Text)
    Write-Color -Text " X $Text" -Color Red
}

# ============================================================
# Provider / model data
# ============================================================

$ProviderMap = [ordered]@{
    ANTHROPIC_API_KEY = @{ Name = "Anthropic (Claude)"; Id = "anthropic" }
    OPENAI_API_KEY = @{ Name = "OpenAI (GPT)"; Id = "openai" }
    GEMINI_API_KEY = @{ Name = "Google Gemini"; Id = "gemini" }
    GOOGLE_API_KEY = @{ Name = "Google AI"; Id = "google" }
    GROQ_API_KEY = @{ Name = "Groq"; Id = "groq" }
    CEREBRAS_API_KEY = @{ Name = "Cerebras"; Id = "cerebras" }
    OPENROUTER_API_KEY = @{ Name = "OpenRouter"; Id = "openrouter" }
    MISTRAL_API_KEY = @{ Name = "Mistral"; Id = "mistral" }
    TOGETHER_API_KEY = @{ Name = "Together AI"; Id = "together" }
    DEEPSEEK_API_KEY = @{ Name = "DeepSeek"; Id = "deepseek" }
}

$DefaultModels = @{
    anthropic = "claude-haiku-4-5-20251001"
    openai = "gpt-5-mini"
    gemini = "gemini-3-flash-preview"
    groq = "moonshotai/kimi-k2-instruct-0905"
    cerebras = "zai-glm-4.7"
    mistral = "mistral-large-latest"
    together_ai = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
    deepseek = "deepseek-chat"
}

# Model choices: array of hashtables per provider
$ModelChoices = @{
    anthropic = @(
        @{ Id = "claude-haiku-4-5-20251001"; Label = "Haiku 4.5 - Fast + cheap (recommended)"; MaxTokens = 8192; MaxContextTokens = 180000 },
        @{ Id = "claude-sonnet-4-20250514"; Label = "Sonnet 4 - Fast + capable"; MaxTokens = 8192; MaxContextTokens = 180000 },
        @{ Id = "claude-sonnet-4-5-20250929"; Label = "Sonnet 4.5 - Best balance"; MaxTokens = 16384; MaxContextTokens = 180000 },
        @{ Id = "claude-opus-4-6"; Label = "Opus 4.6 - Most capable"; MaxTokens = 32768; MaxContextTokens = 180000 }
    )
    openai = @(
        @{ Id = "gpt-5-mini"; Label = "GPT-5 Mini - Fast + cheap (recommended)"; MaxTokens = 16384; MaxContextTokens = 120000 },
        @{ Id = "gpt-5.2"; Label = "GPT-5.2 - Most capable"; MaxTokens = 16384; MaxContextTokens = 120000 }
    )
    gemini = @(
        @{ Id = "gemini-3-flash-preview"; Label = "Gemini 3 Flash - Fast (recommended)"; MaxTokens = 8192; MaxContextTokens = 900000 },
        @{ Id = "gemini-3.1-pro-preview"; Label = "Gemini 3.1 Pro - Best quality"; MaxTokens = 8192; MaxContextTokens = 900000 }
    )
    groq = @(
        @{ Id = "moonshotai/kimi-k2-instruct-0905"; Label = "Kimi K2 - Best quality (recommended)"; MaxTokens = 8192; MaxContextTokens = 120000 },
        @{ Id = "openai/gpt-oss-120b"; Label = "GPT-OSS 120B - Fast reasoning"; MaxTokens = 8192; MaxContextTokens = 120000 }
    )
    cerebras = @(
        @{ Id = "zai-glm-4.7"; Label = "ZAI-GLM 4.7 - Best quality (recommended)"; MaxTokens = 8192; MaxContextTokens = 120000 },
        @{ Id = "qwen3-235b-a22b-instruct-2507"; Label = "Qwen3 235B - Frontier reasoning"; MaxTokens = 8192; MaxContextTokens = 120000 }
    )
}

function Normalize-OpenRouterModelId {
    param([string]$ModelId)
    $normalized = if ($ModelId) { $ModelId.Trim() } else { "" }
    if ($normalized -match '(?i)^openrouter/(.+)$') {
        $normalized = $matches[1]
    }
    return $normalized
}

function Get-ModelSelection {
    param([string]$ProviderId)

    if ($ProviderId -eq "openrouter") {
        $defaultModel = ""
        if ($PrevModel -and $PrevProvider -eq $ProviderId) {
            $defaultModel = Normalize-OpenRouterModelId $PrevModel
        }
        Write-Host ""
        Write-Color -Text "Enter your OpenRouter model id:" -Color White
        Write-Color -Text " Paste from openrouter.ai (example: x-ai/grok-4.20-beta)" -Color DarkGray
        Write-Color -Text " If calls fail with guardrail/privacy errors: openrouter.ai/settings/privacy" -Color DarkGray
        Write-Host ""
        while ($true) {
            if ($defaultModel) {
                $rawModel = Read-Host "Model id [$defaultModel]"
                if ([string]::IsNullOrWhiteSpace($rawModel)) { $rawModel = $defaultModel }
            } else {
                $rawModel = Read-Host "Model id"
            }
            $normalizedModel = Normalize-OpenRouterModelId $rawModel
            if (-not [string]::IsNullOrWhiteSpace($normalizedModel)) {
                $openrouterKey = $null
                if ($SelectedEnvVar) {
                    $openrouterKey = [System.Environment]::GetEnvironmentVariable($SelectedEnvVar, "Process")
                    if (-not $openrouterKey) {
                        $openrouterKey = [System.Environment]::GetEnvironmentVariable($SelectedEnvVar, "User")
                    }
                }

                if ($openrouterKey) {
                    Write-Host " Verifying model id... " -NoNewline
                    try {
                        $modelApiBase = if ($SelectedApiBase) { $SelectedApiBase } else { "https://openrouter.ai/api/v1" }
                        Push-Location $ProjectDir
                        $hcResult = & $UvCmd run python (Join-Path $ProjectDir "scripts/check_llm_key.py") "openrouter" $openrouterKey $modelApiBase $normalizedModel 2>$null
                        Pop-Location
                        $hcJson = $hcResult | ConvertFrom-Json
                        if ($hcJson.valid -eq $true) {
                            if ($hcJson.model) {
                                $normalizedModel = [string]$hcJson.model
                            }
                            Write-Color -Text "ok" -Color Green
                        } elseif ($hcJson.valid -eq $false) {
                            Write-Color -Text "failed" -Color Red
                            Write-Warn $hcJson.message
                            Write-Host ""
                            continue
                        } else {
                            Write-Color -Text "--" -Color Yellow
                            Write-Color -Text " Could not verify model id (network issue). Continuing with your selection." -Color DarkGray
                        }
                    } catch {
                        Pop-Location
                        Write-Color -Text "--" -Color Yellow
                        Write-Color -Text " Could not verify model id (network issue). Continuing with your selection." -Color DarkGray
                    }
                } else {
                    Write-Color -Text " Skipping model verification (OpenRouter key not available in current shell)." -Color DarkGray
                }

                Write-Host ""
                Write-Ok "Model: $normalizedModel"
                return @{ Model = $normalizedModel; MaxTokens = 8192; MaxContextTokens = 120000 }
            }
            Write-Color -Text "Model id cannot be empty." -Color Red
        }
    }

    $choices = $ModelChoices[$ProviderId]
    if (-not $choices -or $choices.Count -eq 0) {
        return @{ Model = $DefaultModels[$ProviderId]; MaxTokens = 8192; MaxContextTokens = 120000 }
    }
    if ($choices.Count -eq 1) {
        return @{ Model = $choices[0].Id; MaxTokens = $choices[0].MaxTokens; MaxContextTokens = $choices[0].MaxContextTokens }
    }

    # Find default index from previous model (if same provider)
    $defaultIdx = "1"
    if ($PrevModel -and $PrevProvider -eq $ProviderId) {
        for ($j = 0; $j -lt $choices.Count; $j++) {
            if ($choices[$j].Id -eq $PrevModel) {
                $defaultIdx = [string]($j + 1)
                break
            }
        }
    }

    Write-Host ""
    Write-Color -Text "Select a model:" -Color White
    Write-Host ""
    for ($i = 0; $i -lt $choices.Count; $i++) {
        Write-Color -Text " $($i + 1)" -Color Cyan -NoNewline
        Write-Host ") $($choices[$i].Label) " -NoNewline
        Write-Color -Text "($($choices[$i].Id))" -Color DarkGray
    }
    Write-Host ""

    while ($true) {
        $raw = Read-Host "Enter choice [$defaultIdx]"
        if ([string]::IsNullOrWhiteSpace($raw)) { $raw = $defaultIdx }
        if ($raw -match '^\d+$') {
            $num = [int]$raw
            if ($num -ge 1 -and $num -le $choices.Count) {
                $sel = $choices[$num - 1]
                Write-Host ""
                Write-Ok "Model: $($sel.Id)"
                return @{ Model = $sel.Id; MaxTokens = $sel.MaxTokens; MaxContextTokens = $sel.MaxContextTokens }
            }
        }
        Write-Color -Text "Invalid choice. Please enter 1-$($choices.Count)" -Color Red
    }
}

# ============================================================
# Main
# ============================================================

$uvInfo = Find-Uv
if (-not $uvInfo) {
    Write-Color -Text "uv not found. Run quickstart.ps1 first." -Color Red
    exit 1
}
$UvCmd = $uvInfo.Path

Write-Host ""
Write-Color -Text "$([char]0x2B22) Worker Model Setup" -Color Yellow
Write-Host ""
Write-Color -Text "Configure a separate LLM model for worker agents." -Color DarkGray
Write-Color -Text "Worker agents will use this model instead of the default queen model." -Color DarkGray
Write-Host ""

# Show current configuration
if (Test-Path $HiveConfigFile) {
    try {
        Push-Location $ProjectDir
        $currentConfig = & $UvCmd run python -c "
from framework.config import get_preferred_model, get_preferred_worker_model
print(f'Queen: {get_preferred_model()}')
wm = get_preferred_worker_model()
print(f'Worker: {wm if wm else chr(34) + ""(same as queen)"" + chr(34)}')
" 2>$null
        Pop-Location
        if ($currentConfig) {
            Write-Color -Text "Current configuration:" -Color White
            foreach ($line in $currentConfig) {
                Write-Color -Text " $line" -Color DarkGray
            }
            Write-Host ""
        }
    } catch {
        Pop-Location
    }
}

# ============================================================
# Configure Worker LLM Provider
# ============================================================

$SelectedProviderId = ""
$SelectedEnvVar = ""
$SelectedModel = ""
$SelectedMaxTokens = 8192
$SelectedMaxContextTokens = 120000
$SelectedApiBase = ""
$SubscriptionMode = ""

# -- Credential detection (silent -- just set flags) ----------
$ClaudeCredDetected = $false
$claudeCredPath = Join-Path $env:USERPROFILE ".claude\.credentials.json"
if (Test-Path $claudeCredPath) { $ClaudeCredDetected = $true }

$CodexCredDetected = $false
$codexAuthPath = Join-Path $env:USERPROFILE ".codex\auth.json"
if (Test-Path $codexAuthPath) { $CodexCredDetected = $true }

$ZaiCredDetected = $false
$zaiKey = [System.Environment]::GetEnvironmentVariable("ZAI_API_KEY", "User")
if (-not $zaiKey) { $zaiKey = $env:ZAI_API_KEY }
if ($zaiKey) { $ZaiCredDetected = $true }

$KimiCredDetected = $false
$kimiConfigPath = Join-Path $env:USERPROFILE ".kimi\config.toml"
if (Test-Path $kimiConfigPath) { $KimiCredDetected = $true }
$kimiKey = [System.Environment]::GetEnvironmentVariable("KIMI_API_KEY", "User")
if (-not $kimiKey) { $kimiKey = $env:KIMI_API_KEY }
if ($kimiKey) { $KimiCredDetected = $true }

$HiveCredDetected = $false
$hiveKey = [System.Environment]::GetEnvironmentVariable("HIVE_API_KEY", "User")
if (-not $hiveKey) { $hiveKey = $env:HIVE_API_KEY }
if ($hiveKey) { $HiveCredDetected = $true }

# Detect API key providers
$ProviderMenuEnvVars = @("ANTHROPIC_API_KEY", "OPENAI_API_KEY", "GEMINI_API_KEY", "GROQ_API_KEY", "CEREBRAS_API_KEY", "OPENROUTER_API_KEY")
$ProviderMenuNames = @("Anthropic (Claude) - Recommended", "OpenAI (GPT)", "Google Gemini - Free tier available", "Groq - Fast, free tier", "Cerebras - Fast, free tier", "OpenRouter - Bring any OpenRouter model")
$ProviderMenuIds = @("anthropic", "openai", "gemini", "groq", "cerebras", "openrouter")
$ProviderMenuUrls = @(
    "https://console.anthropic.com/settings/keys",
    "https://platform.openai.com/api-keys",
    "https://aistudio.google.com/apikey",
    "https://console.groq.com/keys",
    "https://cloud.cerebras.ai/",
    "https://openrouter.ai/keys"
)

# -- Read previous worker_llm configuration (if any) ---------
$PrevProvider = ""
$PrevModel = ""
$PrevEnvVar = ""
$PrevSubMode = ""
if (Test-Path $HiveConfigFile) {
    try {
        $prevConfig = Get-Content -Path $HiveConfigFile -Raw | ConvertFrom-Json
        $prevLlm = $prevConfig.worker_llm
        if ($prevLlm) {
            $PrevProvider = if ($prevLlm.provider) { $prevLlm.provider } else { "" }
            $PrevModel = if ($prevLlm.model) { $prevLlm.model } else { "" }
            $PrevEnvVar = if ($prevLlm.api_key_env_var) { $prevLlm.api_key_env_var } else { "" }
            if ($prevLlm.use_claude_code_subscription) { $PrevSubMode = "claude_code" }
            elseif ($prevLlm.use_codex_subscription) { $PrevSubMode = "codex" }
            elseif ($prevLlm.use_kimi_code_subscription) { $PrevSubMode = "kimi_code" }
            elseif ($prevLlm.api_base -and $prevLlm.api_base -like "*api.z.ai*") { $PrevSubMode = "zai_code" }
            elseif ($prevLlm.api_base -and $prevLlm.api_base -like "*api.kimi.com*") { $PrevSubMode = "kimi_code" }
            elseif ($prevLlm.provider -eq "hive" -or ($prevLlm.api_base -and $prevLlm.api_base -like "*adenhq.com*")) { $PrevSubMode = "hive_llm" }
        }
    } catch { }
}

# Compute default menu number (only if credential is still valid)
$DefaultChoice = ""
if ($PrevSubMode -or $PrevProvider) {
    $prevCredValid = $false
    switch ($PrevSubMode) {
        "claude_code" { if ($ClaudeCredDetected) { $prevCredValid = $true } }
        "zai_code" { if ($ZaiCredDetected) { $prevCredValid = $true } }
        "codex" { if ($CodexCredDetected) { $prevCredValid = $true } }
        "kimi_code" { if ($KimiCredDetected) { $prevCredValid = $true } }
        "hive_llm" { if ($HiveCredDetected) { $prevCredValid = $true } }
        default {
            if ($PrevEnvVar) {
                $envVal = [System.Environment]::GetEnvironmentVariable($PrevEnvVar, "Process")
                if (-not $envVal) { $envVal = [System.Environment]::GetEnvironmentVariable($PrevEnvVar, "User") }
                if ($envVal) { $prevCredValid = $true }
            }
        }
    }
    if ($prevCredValid) {
        switch ($PrevSubMode) {
            "claude_code" { $DefaultChoice = "1" }
            "zai_code" { $DefaultChoice = "2" }
            "codex" { $DefaultChoice = "3" }
            "kimi_code" { $DefaultChoice = "4" }
            "hive_llm" { $DefaultChoice = "5" }
        }
        if (-not $DefaultChoice) {
            switch ($PrevProvider) {
                "anthropic" { $DefaultChoice = "6" }
                "openai" { $DefaultChoice = "7" }
                "gemini" { $DefaultChoice = "8" }
                "groq" { $DefaultChoice = "9" }
                "cerebras" { $DefaultChoice = "10" }
                "openrouter" { $DefaultChoice = "11" }
                "kimi" { $DefaultChoice = "4" }
            }
        }
    }
}

# -- Show unified provider selection menu ---------------------
Write-Color -Text "Select your worker LLM provider:" -Color White
Write-Host ""
Write-Color -Text " Subscription modes (no API key purchase needed):" -Color Cyan

# 1) Claude Code
Write-Host " " -NoNewline
Write-Color -Text "1" -Color Cyan -NoNewline
Write-Host ") Claude Code Subscription " -NoNewline
Write-Color -Text "(use your Claude Max/Pro plan)" -Color DarkGray -NoNewline
if ($ClaudeCredDetected) { Write-Color -Text " (credential detected)" -Color Green } else { Write-Host "" }

# 2) ZAI Code
Write-Host " " -NoNewline
Write-Color -Text "2" -Color Cyan -NoNewline
Write-Host ") ZAI Code Subscription " -NoNewline
Write-Color -Text "(use your ZAI Code plan)" -Color DarkGray -NoNewline
if ($ZaiCredDetected) { Write-Color -Text " (credential detected)" -Color Green } else { Write-Host "" }

# 3) Codex
Write-Host " " -NoNewline
Write-Color -Text "3" -Color Cyan -NoNewline
Write-Host ") OpenAI Codex Subscription " -NoNewline
Write-Color -Text "(use your Codex/ChatGPT Plus plan)" -Color DarkGray -NoNewline
if ($CodexCredDetected) { Write-Color -Text " (credential detected)" -Color Green } else { Write-Host "" }

# 4) Kimi Code
Write-Host " " -NoNewline
Write-Color -Text "4" -Color Cyan -NoNewline
Write-Host ") Kimi Code Subscription " -NoNewline
Write-Color -Text "(use your Kimi Code plan)" -Color DarkGray -NoNewline
if ($KimiCredDetected) { Write-Color -Text " (credential detected)" -Color Green } else { Write-Host "" }

# 5) Hive LLM
Write-Host " " -NoNewline
Write-Color -Text "5" -Color Cyan -NoNewline
Write-Host ") Hive LLM " -NoNewline
Write-Color -Text "(use your Hive API key)" -Color DarkGray -NoNewline
if ($HiveCredDetected) { Write-Color -Text " (credential detected)" -Color Green } else { Write-Host "" }

Write-Host ""
Write-Color -Text " API key providers:" -Color Cyan

# 6-11) API key providers
for ($idx = 0; $idx -lt $ProviderMenuEnvVars.Count; $idx++) {
    $num = $idx + 6
    $envVal = [System.Environment]::GetEnvironmentVariable($ProviderMenuEnvVars[$idx], "Process")
    if (-not $envVal) { $envVal = [System.Environment]::GetEnvironmentVariable($ProviderMenuEnvVars[$idx], "User") }
    Write-Host " " -NoNewline
    Write-Color -Text "$num" -Color Cyan -NoNewline
    Write-Host ") $($ProviderMenuNames[$idx])" -NoNewline
    if ($envVal) { Write-Color -Text " (credential detected)" -Color Green } else { Write-Host "" }
}

$SkipChoice = 6 + $ProviderMenuEnvVars.Count
Write-Host " " -NoNewline
Write-Color -Text "$SkipChoice" -Color Cyan -NoNewline
Write-Host ") Skip for now"
Write-Host ""

if ($DefaultChoice) {
    Write-Color -Text " Previously configured: $PrevProvider/$PrevModel. Press Enter to keep." -Color DarkGray
    Write-Host ""
}

while ($true) {
    if ($DefaultChoice) {
        $raw = Read-Host "Enter choice (1-$SkipChoice) [$DefaultChoice]"
        if ([string]::IsNullOrWhiteSpace($raw)) { $raw = $DefaultChoice }
    } else {
        $raw = Read-Host "Enter choice (1-$SkipChoice)"
    }
    if ($raw -match '^\d+$') {
        $num = [int]$raw
        if ($num -ge 1 -and $num -le $SkipChoice) { break }
    }
    Write-Color -Text "Invalid choice. Please enter 1-$SkipChoice" -Color Red
}

switch ($num) {
    1 {
        # Claude Code Subscription
        if (-not $ClaudeCredDetected) {
            Write-Host ""
            Write-Warn "~/.claude/.credentials.json not found."
            Write-Host " Run 'claude' first to authenticate with your Claude subscription,"
            Write-Host " then run this script again."
            Write-Host ""
            exit 1
        }
        $SubscriptionMode = "claude_code"
        $SelectedProviderId = "anthropic"
        $SelectedModel = "claude-opus-4-6"
        $SelectedMaxTokens = 32768
        $SelectedMaxContextTokens = 180000
        Write-Host ""
        Write-Ok "Using Claude Code subscription"
    }
    2 {
        # ZAI Code Subscription
        $SubscriptionMode = "zai_code"
        $SelectedProviderId = "openai"
        $SelectedEnvVar = "ZAI_API_KEY"
        $SelectedModel = "glm-5"
        $SelectedMaxTokens = 32768
        $SelectedMaxContextTokens = 120000
        Write-Host ""
        Write-Ok "Using ZAI Code subscription"
        Write-Color -Text " Model: glm-5 | API: api.z.ai" -Color DarkGray
    }
    3 {
        # OpenAI Codex Subscription
        if (-not $CodexCredDetected) {
            Write-Host ""
            Write-Warn "Codex credentials not found. Starting OAuth login..."
            Write-Host ""
            try {
                Push-Location $ProjectDir
                & $UvCmd run python (Join-Path $ProjectDir "core\codex_oauth.py") 2>&1
                Pop-Location
                if ($LASTEXITCODE -eq 0) {
                    $CodexCredDetected = $true
                } else {
                    Write-Host ""
                    Write-Fail "OAuth login failed or was cancelled."
                    Write-Host ""
                    Write-Host " Or run 'codex' to authenticate, then run this script again."
                    Write-Host ""
                    $SelectedProviderId = ""
                }
            } catch {
                Pop-Location
                Write-Fail "OAuth login failed: $($_.Exception.Message)"
                $SelectedProviderId = ""
            }
        }
        if ($CodexCredDetected) {
            $SubscriptionMode = "codex"
            $SelectedProviderId = "openai"
            $SelectedModel = "gpt-5.3-codex"
            $SelectedMaxTokens = 16384
            $SelectedMaxContextTokens = 120000
            Write-Host ""
            Write-Ok "Using OpenAI Codex subscription"
        }
    }
    4 {
        # Kimi Code Subscription
        $SubscriptionMode = "kimi_code"
        $SelectedProviderId = "kimi"
        $SelectedEnvVar = "KIMI_API_KEY"
        $SelectedModel = "kimi-k2.5"
        $SelectedMaxTokens = 32768
        $SelectedMaxContextTokens = 120000
        Write-Host ""
        Write-Ok "Using Kimi Code subscription"
        Write-Color -Text " Model: kimi-k2.5 | API: api.kimi.com/coding" -Color DarkGray
    }
    5 {
        # Hive LLM
        $SubscriptionMode = "hive_llm"
        $SelectedProviderId = "hive"
        $SelectedEnvVar = "HIVE_API_KEY"
        $SelectedMaxTokens = 32768
        $SelectedMaxContextTokens = 120000
        Write-Host ""
        Write-Ok "Using Hive LLM"
        Write-Host ""
        Write-Host " Select a model:"
        Write-Host " " -NoNewline; Write-Color -Text "1)" -Color Cyan -NoNewline; Write-Host " queen " -NoNewline; Write-Color -Text "(default - Hive flagship)" -Color DarkGray
        Write-Host " " -NoNewline; Write-Color -Text "2)" -Color Cyan -NoNewline; Write-Host " kimi-2.5"
        Write-Host " " -NoNewline; Write-Color -Text "3)" -Color Cyan -NoNewline; Write-Host " GLM-5"
        Write-Host ""
        $hiveModelChoice = Read-Host " Enter model choice (1-3) [1]"
        if (-not $hiveModelChoice) { $hiveModelChoice = "1" }
        switch ($hiveModelChoice) {
            "2" { $SelectedModel = "kimi-2.5" }
            "3" { $SelectedModel = "GLM-5" }
            default { $SelectedModel = "queen" }
        }
        Write-Color -Text " Model: $SelectedModel | API: $HiveLlmEndpoint" -Color DarkGray
    }
    { $_ -ge 6 -and $_ -le 11 } {
        # API key providers
        $provIdx = $num - 6
        $SelectedEnvVar = $ProviderMenuEnvVars[$provIdx]
        $SelectedProviderId = $ProviderMenuIds[$provIdx]
        $providerName = $ProviderMenuNames[$provIdx] -replace ' - .*', '' # strip description
        $signupUrl = $ProviderMenuUrls[$provIdx]
        if ($SelectedProviderId -eq "openrouter") {
            $SelectedApiBase = "https://openrouter.ai/api/v1"
        } else {
            $SelectedApiBase = ""
        }

        # Prompt for key (allow replacement if already set) with verification + retry
        while ($true) {
            $existingKey = [System.Environment]::GetEnvironmentVariable($SelectedEnvVar, "User")
            if (-not $existingKey) { $existingKey = [System.Environment]::GetEnvironmentVariable($SelectedEnvVar, "Process") }

            if ($existingKey) {
                $masked = $existingKey.Substring(0, [Math]::Min(4, $existingKey.Length)) + "..." + $existingKey.Substring([Math]::Max(0, $existingKey.Length - 4))
                Write-Host ""
                Write-Color -Text " $([char]0x2B22) Current key: $masked" -Color Green
                $apiKey = Read-Host " Press Enter to keep, or paste a new key to replace"
            } else {
                Write-Host ""
                Write-Host "Get your API key from: " -NoNewline
                Write-Color -Text $signupUrl -Color Cyan
                Write-Host ""
                $apiKey = Read-Host "Paste your $providerName API key (or press Enter to skip)"
            }

            if ($apiKey) {
                [System.Environment]::SetEnvironmentVariable($SelectedEnvVar, $apiKey, "User")
                Set-Item -Path "Env:\$SelectedEnvVar" -Value $apiKey
                Write-Host ""
                Write-Ok "API key saved as User environment variable: $SelectedEnvVar"

                # Health check the new key
                Write-Host " Verifying API key... " -NoNewline
                try {
                    Push-Location $ProjectDir
                    if ($SelectedApiBase) {
                        $hcResult = & $UvCmd run python (Join-Path $ProjectDir "scripts/check_llm_key.py") $SelectedProviderId $apiKey $SelectedApiBase 2>$null
                    } else {
                        $hcResult = & $UvCmd run python (Join-Path $ProjectDir "scripts/check_llm_key.py") $SelectedProviderId $apiKey 2>$null
                    }
                    Pop-Location
                    $hcJson = $hcResult | ConvertFrom-Json
                    if ($hcJson.valid -eq $true) {
                        Write-Color -Text "ok" -Color Green
                        break
                    } elseif ($hcJson.valid -eq $false) {
                        Write-Color -Text "failed" -Color Red
                        Write-Warn $hcJson.message
                        # Undo the save so user can retry cleanly
                        [System.Environment]::SetEnvironmentVariable($SelectedEnvVar, $null, "User")
                        Remove-Item -Path "Env:\$SelectedEnvVar" -ErrorAction SilentlyContinue
                        Write-Host ""
                        Read-Host " Press Enter to try again"
                        # loop back to key prompt
                    } else {
                        Write-Color -Text "--" -Color Yellow
                        Write-Color -Text " Could not verify key (network issue). The key has been saved." -Color DarkGray
                        break
                    }
                } catch {
                    Pop-Location
                    Write-Color -Text "--" -Color Yellow
                    Write-Color -Text " Could not verify key (network issue). The key has been saved." -Color DarkGray
                    break
                }
            } elseif (-not $existingKey) {
                # No existing key and user skipped
                Write-Host ""
                Write-Warn "Skipped. Set the environment variable manually when ready:"
                Write-Host " [System.Environment]::SetEnvironmentVariable('$SelectedEnvVar', 'your-key', 'User')"
                $SelectedEnvVar = ""
                $SelectedProviderId = ""
                break
            } else {
                # User pressed Enter with existing key -- keep it
                break
            }
        }
    }
    { $_ -eq $SkipChoice } {
        Write-Host ""
        Write-Warn "Skipped. A worker LLM provider is required for worker agents."
        Write-Host " Run this script again when ready."
        Write-Host ""
        $SelectedEnvVar = ""
        $SelectedProviderId = ""
    }
}

# For ZAI subscription: prompt for API key (allow replacement if already set) with verification + retry
if ($SubscriptionMode -eq "zai_code") {
    while ($true) {
        $existingZai = [System.Environment]::GetEnvironmentVariable("ZAI_API_KEY", "User")
        if (-not $existingZai) { $existingZai = $env:ZAI_API_KEY }

        if ($existingZai) {
            $masked = $existingZai.Substring(0, [Math]::Min(4, $existingZai.Length)) + "..." + $existingZai.Substring([Math]::Max(0, $existingZai.Length - 4))
            Write-Host ""
            Write-Color -Text " $([char]0x2B22) Current ZAI key: $masked" -Color Green
            $apiKey = Read-Host " Press Enter to keep, or paste a new key to replace"
        } else {
            Write-Host ""
            $apiKey = Read-Host "Paste your ZAI API key (or press Enter to skip)"
        }

        if ($apiKey) {
            [System.Environment]::SetEnvironmentVariable("ZAI_API_KEY", $apiKey, "User")
            $env:ZAI_API_KEY = $apiKey
            Write-Host ""
            Write-Ok "ZAI API key saved as User environment variable"

            # Health check the new key
            Write-Host " Verifying ZAI API key... " -NoNewline
            try {
                Push-Location $ProjectDir
                $hcResult = & $UvCmd run python (Join-Path $ProjectDir "scripts/check_llm_key.py") "zai" $apiKey "https://api.z.ai/api/coding/paas/v4" 2>$null
                Pop-Location
                $hcJson = $hcResult | ConvertFrom-Json
                if ($hcJson.valid -eq $true) {
                    Write-Color -Text "ok" -Color Green
                    break
                } elseif ($hcJson.valid -eq $false) {
                    Write-Color -Text "failed" -Color Red
                    Write-Warn $hcJson.message
                    # Undo the save so user can retry cleanly
                    [System.Environment]::SetEnvironmentVariable("ZAI_API_KEY", $null, "User")
                    Remove-Item -Path "Env:\ZAI_API_KEY" -ErrorAction SilentlyContinue
                    Write-Host ""
                    Read-Host " Press Enter to try again"
                    # loop back to key prompt
                } else {
                    Write-Color -Text "--" -Color Yellow
                    Write-Color -Text " Could not verify key (network issue). The key has been saved." -Color DarkGray
                    break
                }
            } catch {
                Pop-Location
                Write-Color -Text "--" -Color Yellow
                Write-Color -Text " Could not verify key (network issue). The key has been saved." -Color DarkGray
                break
            }
        } elseif (-not $existingZai) {
            # No existing key and user skipped
            Write-Host ""
            Write-Warn "Skipped. Add your ZAI API key later:"
            Write-Color -Text " [System.Environment]::SetEnvironmentVariable('ZAI_API_KEY', 'your-key', 'User')" -Color Cyan
            $SelectedEnvVar = ""
            $SelectedProviderId = ""
            $SubscriptionMode = ""
            break
        } else {
            # User pressed Enter with existing key -- keep it
            break
        }
    }
}

# For Kimi Code subscription: prompt for API key with verification + retry
if ($SubscriptionMode -eq "kimi_code") {
    while ($true) {
        $existingKimi = [System.Environment]::GetEnvironmentVariable("KIMI_API_KEY", "User")
        if (-not $existingKimi) { $existingKimi = $env:KIMI_API_KEY }

        if ($existingKimi) {
            $masked = $existingKimi.Substring(0, [Math]::Min(4, $existingKimi.Length)) + "..." + $existingKimi.Substring([Math]::Max(0, $existingKimi.Length - 4))
            Write-Host ""
            Write-Color -Text " $([char]0x2B22) Current Kimi key: $masked" -Color Green
            $apiKey = Read-Host " Press Enter to keep, or paste a new key to replace"
        } else {
            Write-Host ""
            Write-Host "Get your API key from: " -NoNewline
            Write-Color -Text "https://www.kimi.com/code" -Color Cyan
            Write-Host ""
            $apiKey = Read-Host "Paste your Kimi API key (or press Enter to skip)"
        }

        if ($apiKey) {
            [System.Environment]::SetEnvironmentVariable("KIMI_API_KEY", $apiKey, "User")
            $env:KIMI_API_KEY = $apiKey
            Write-Host ""
            Write-Ok "Kimi API key saved as User environment variable"

            # Health check the new key
            Write-Host " Verifying Kimi API key... " -NoNewline
            try {
                Push-Location $ProjectDir
                $hcResult = & $UvCmd run python (Join-Path $ProjectDir "scripts/check_llm_key.py") "kimi" $apiKey "https://api.kimi.com/coding" 2>$null
                Pop-Location
                $hcJson = $hcResult | ConvertFrom-Json
                if ($hcJson.valid -eq $true) {
                    Write-Color -Text "ok" -Color Green
                    break
                } elseif ($hcJson.valid -eq $false) {
                    Write-Color -Text "failed" -Color Red
                    Write-Warn $hcJson.message
                    [System.Environment]::SetEnvironmentVariable("KIMI_API_KEY", $null, "User")
                    Remove-Item -Path "Env:\KIMI_API_KEY" -ErrorAction SilentlyContinue
                    Write-Host ""
                    Read-Host " Press Enter to try again"
                } else {
                    Write-Color -Text "--" -Color Yellow
                    Write-Color -Text " Could not verify key (network issue). The key has been saved." -Color DarkGray
                    break
                }
            } catch {
                Pop-Location
                Write-Color -Text "--" -Color Yellow
                Write-Color -Text " Could not verify key (network issue). The key has been saved." -Color DarkGray
                break
            }
        } elseif (-not $existingKimi) {
            Write-Host ""
            Write-Warn "Skipped. Add your Kimi API key later:"
            Write-Color -Text " [System.Environment]::SetEnvironmentVariable('KIMI_API_KEY', 'your-key', 'User')" -Color Cyan
            $SelectedEnvVar = ""
            $SelectedProviderId = ""
            $SubscriptionMode = ""
            break
        } else {
            break
        }
    }
}

# For Hive LLM: prompt for API key with verification + retry
if ($SubscriptionMode -eq "hive_llm") {
    while ($true) {
        $existingHive = [System.Environment]::GetEnvironmentVariable("HIVE_API_KEY", "User")
        if (-not $existingHive) { $existingHive = $env:HIVE_API_KEY }

        if ($existingHive) {
            $masked = $existingHive.Substring(0, [Math]::Min(4, $existingHive.Length)) + "..." + $existingHive.Substring([Math]::Max(0, $existingHive.Length - 4))
            Write-Host ""
            Write-Color -Text " $([char]0x2B22) Current Hive key: $masked" -Color Green
            Write-Host ""
            $apiKey = Read-Host "Paste a new Hive API key (or press Enter to keep current)"
        } else {
            Write-Host ""
            Write-Host " Get your API key from: " -NoNewline
            Write-Color -Text "https://discord.com/invite/hQdU7QDkgR" -Color Cyan
            Write-Host ""
            $apiKey = Read-Host "Paste your Hive API key (or press Enter to skip)"
        }

        if ($apiKey) {
            [System.Environment]::SetEnvironmentVariable("HIVE_API_KEY", $apiKey, "User")
            $env:HIVE_API_KEY = $apiKey
            Write-Host ""
            Write-Ok "Hive API key saved as User environment variable"

            # Health check the new key
            Write-Host " Verifying Hive API key... " -NoNewline
            try {
                Push-Location $ProjectDir
                $hcResult = & $UvCmd run python (Join-Path $ProjectDir "scripts/check_llm_key.py") "hive" $apiKey "$HiveLlmEndpoint" 2>$null
                Pop-Location
                $hcJson = $hcResult | ConvertFrom-Json
                if ($hcJson.valid -eq $true) {
                    Write-Color -Text "ok" -Color Green
                    break
                } elseif ($hcJson.valid -eq $false) {
                    Write-Color -Text "failed" -Color Red
                    Write-Warn $hcJson.message
                    [System.Environment]::SetEnvironmentVariable("HIVE_API_KEY", $null, "User")
                    Remove-Item -Path "Env:\HIVE_API_KEY" -ErrorAction SilentlyContinue
                    Write-Host ""
                    Read-Host " Press Enter to try again"
                } else {
                    Write-Color -Text "--" -Color Yellow
                    Write-Color -Text " Could not verify key (network issue). The key has been saved." -Color DarkGray
                    break
                }
            } catch {
                Pop-Location
                Write-Color -Text "--" -Color Yellow
                break
            }
        } elseif (-not $existingHive) {
            Write-Host ""
            Write-Warn "Skipped. Add your Hive API key later:"
            Write-Color -Text " [System.Environment]::SetEnvironmentVariable('HIVE_API_KEY', 'your-key', 'User')" -Color Cyan
            $SelectedEnvVar = ""
            $SelectedProviderId = ""
            $SubscriptionMode = ""
            break
        } else {
            break
        }
    }
}

# Prompt for model if not already selected (manual provider path)
if ($SelectedProviderId -and -not $SelectedModel) {
    $modelSel = Get-ModelSelection $SelectedProviderId
    $SelectedModel = $modelSel.Model
    $SelectedMaxTokens = $modelSel.MaxTokens
    $SelectedMaxContextTokens = $modelSel.MaxContextTokens
}

# ============================================================
# Save configuration to worker_llm section
# ============================================================

if ($SelectedProviderId) {
    if (-not $SelectedModel) {
        $SelectedModel = $DefaultModels[$SelectedProviderId]
    }
    Write-Host ""
    Write-Host " Saving worker model configuration... " -NoNewline

    if (-not (Test-Path $HiveConfigDir)) {
        New-Item -ItemType Directory -Path $HiveConfigDir -Force | Out-Null
    }

    try {
        if (Test-Path $HiveConfigFile) {
            $config = Get-Content -Path $HiveConfigFile -Raw | ConvertFrom-Json
        } else {
            # PSCustomObject (not a hashtable) so the Add-Member property below
            # survives ConvertTo-Json on a fresh config
            $config = [PSCustomObject]@{}
        }
    } catch {
        $config = [PSCustomObject]@{}
    }

    $workerLlm = @{
        provider = $SelectedProviderId
        model = $SelectedModel
        max_tokens = $SelectedMaxTokens
        max_context_tokens = $SelectedMaxContextTokens
    }

    if ($SubscriptionMode -eq "claude_code") {
        $workerLlm["use_claude_code_subscription"] = $true
    } elseif ($SubscriptionMode -eq "codex") {
        $workerLlm["use_codex_subscription"] = $true
    } elseif ($SubscriptionMode -eq "zai_code") {
        $workerLlm["api_base"] = "https://api.z.ai/api/coding/paas/v4"
        $workerLlm["api_key_env_var"] = $SelectedEnvVar
    } elseif ($SubscriptionMode -eq "kimi_code") {
        $workerLlm["api_base"] = "https://api.kimi.com/coding"
        $workerLlm["api_key_env_var"] = $SelectedEnvVar
    } elseif ($SubscriptionMode -eq "hive_llm") {
        $workerLlm["api_base"] = $HiveLlmEndpoint
        $workerLlm["api_key_env_var"] = $SelectedEnvVar
    } elseif ($SelectedProviderId -eq "openrouter") {
        $workerLlm["api_base"] = "https://openrouter.ai/api/v1"
        $workerLlm["api_key_env_var"] = $SelectedEnvVar
    } else {
        $workerLlm["api_key_env_var"] = $SelectedEnvVar
    }

    $config | Add-Member -NotePropertyName "worker_llm" -NotePropertyValue $workerLlm -Force
    $config | ConvertTo-Json -Depth 4 | Set-Content -Path $HiveConfigFile -Encoding UTF8
    Write-Ok "done"
    Write-Color -Text " ~/.hive/configuration.json (worker_llm section)" -Color DarkGray

    Write-Host ""
    Write-Ok "Worker model configured successfully."
    Write-Color -Text " Worker agents will now use: $SelectedProviderId/$SelectedModel" -Color DarkGray
    Write-Color -Text " Run this script again to change, or remove the worker_llm section" -Color DarkGray
    Write-Color -Text " from ~/.hive/configuration.json to revert to the default." -Color DarkGray
    Write-Host ""
}
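
For reference, a minimal Python sketch of how a consumer could read the section this script writes. The helper name load_worker_llm is illustrative only; the real framework exposes get_preferred_worker_model in framework.config, as the script's current-configuration check shows:

    import json
    from pathlib import Path

    def load_worker_llm() -> dict | None:
        """Return the worker_llm section, or None to fall back to the queen model."""
        config_file = Path.home() / ".hive" / "configuration.json"
        if not config_file.exists():
            return None
        config = json.loads(config_file.read_text(encoding="utf-8"))
        return config.get("worker_llm")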
Executable file, +1196 lines. File diff suppressed because it is too large.
@@ -10,7 +10,7 @@ def get_secure_path(path: str, workspace_id: str, agent_id: str, session_id: str
        raise ValueError("workspace_id, agent_id, and session_id are all required")

    # Ensure session directory exists
    session_dir = os.path.abspath(os.path.join(WORKSPACES_DIR, workspace_id, agent_id, session_id))
    session_dir = os.path.realpath(os.path.join(WORKSPACES_DIR, workspace_id, agent_id, session_id))
    os.makedirs(session_dir, exist_ok=True)

    # Normalize whitespace to prevent bypass via leading spaces/tabs
@@ -21,9 +21,9 @@ def get_secure_path(path: str, workspace_id: str, agent_id: str, session_id: str
        # Strip exactly one leading separator to make path relative to session_dir,
        # preserving any subsequent separators (e.g. UNC paths like //server/share)
        rel_path = path[1:] if path and path[0] in ("/", "\\") else path
        final_path = os.path.abspath(os.path.join(session_dir, rel_path))
        final_path = os.path.realpath(os.path.join(session_dir, rel_path))
    else:
        final_path = os.path.abspath(os.path.join(session_dir, path))
        final_path = os.path.realpath(os.path.join(session_dir, path))

    # Verify path is within session_dir
    try:
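
A standalone sketch of the distinction this change relies on (paths are illustrative; assumes a POSIX system where os.symlink needs no special privileges): os.path.abspath only normalizes the path lexically, while os.path.realpath also resolves symlinks, so a link pointing outside the sandbox is revealed before the containment check runs.

    import os
    import tempfile

    sandbox = tempfile.mkdtemp()
    outside = tempfile.mkdtemp()
    link = os.path.join(sandbox, "escape")
    os.symlink(outside, link)  # a symlink inside the sandbox, pointing outside

    lexical = os.path.abspath(link)    # .../sandbox/escape -- looks contained
    resolved = os.path.realpath(link)  # .../outside        -- escape revealed

    # A commonpath containment check only catches the escape on the resolved path:
    print(os.path.commonpath([lexical, sandbox]) == sandbox)   # True (escape missed)
    print(os.path.commonpath([resolved, sandbox]) == sandbox)  # False (escape caught)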

@@ -391,7 +391,7 @@ def register_tools(
    def google_sheets_update_values(
        spreadsheet_id: str,
        range_name: str,
        values: list[list[Any]],
        values: list[list[Any]] | str,
        value_input_option: str = "USER_ENTERED",
        # Tracking parameters (injected by framework, ignored by tool)
        workspace_id: str | None = None,
@@ -405,16 +405,29 @@ def register_tools(
        Args:
            spreadsheet_id: The spreadsheet ID (from the URL)
            range_name: The A1 notation range (e.g., "Sheet1!A1:B10")
            values: 2D array of values to write
            values: 2D array of values to write. Accepts a list or a JSON string.
            value_input_option: How to interpret input
                (USER_ENTERED parses, RAW stores as-is)

        Returns:
            Dict with update result or error
        """
        # Credentials check first so missing-creds errors aren't masked
        client = _get_client()
        if isinstance(client, dict):
            return client
        # Accept stringified JSON and deserialize
        import json

        if isinstance(values, str):
            try:
                values = json.loads(values)
            except (json.JSONDecodeError, ValueError):
                return {"error": "values is not valid JSON"}
        if not isinstance(values, list):
            return {
                "error": f"values must be a 2D list or JSON string, got {type(values).__name__}"
            }
        try:
            return client.update_values(spreadsheet_id, range_name, values, value_input_option)
        except httpx.TimeoutException:
@@ -426,7 +439,7 @@ def register_tools(
    def google_sheets_append_values(
        spreadsheet_id: str,
        range_name: str,
        values: list[list[Any]],
        values: list[list[Any]] | str,
        value_input_option: str = "USER_ENTERED",
        # Tracking parameters (injected by framework, ignored by tool)
        workspace_id: str | None = None,
@@ -440,16 +453,29 @@ def register_tools(
        Args:
            spreadsheet_id: The spreadsheet ID (from the URL)
            range_name: The A1 notation range (e.g., "Sheet1!A1")
            values: 2D array of values to append
            values: 2D array of values to append. Accepts a list or a JSON string.
            value_input_option: How to interpret input
                (USER_ENTERED parses, RAW stores as-is)

        Returns:
            Dict with append result or error
        """
        # Credentials check first so missing-creds errors aren't masked
        client = _get_client()
        if isinstance(client, dict):
            return client
        # Accept stringified JSON and deserialize
        import json

        if isinstance(values, str):
            try:
                values = json.loads(values)
            except (json.JSONDecodeError, ValueError):
                return {"error": "values is not valid JSON"}
        if not isinstance(values, list):
            return {
                "error": f"values must be a 2D list or JSON string, got {type(values).__name__}"
            }
        try:
            return client.append_values(spreadsheet_id, range_name, values, value_input_option)
        except httpx.TimeoutException:
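
Since the same list-or-JSON-string coercion now appears in both tools, here is the pattern in isolation (a standalone sketch; the helper name and the assert lines are for illustration only):

    import json
    from typing import Any

    def coerce_values(values: list[list[Any]] | str) -> list[list[Any]] | dict:
        """Return a 2D list, or an error dict matching the tools' contract."""
        if isinstance(values, str):
            try:
                values = json.loads(values)
            except (json.JSONDecodeError, ValueError):
                return {"error": "values is not valid JSON"}
        if not isinstance(values, list):
            return {"error": f"values must be a 2D list or JSON string, got {type(values).__name__}"}
        return values

    # Both call styles are accepted:
    assert coerce_values([["a", "b"], [1, 2]]) == [["a", "b"], [1, 2]]
    assert coerce_values('[["a", "b"], [1, 2]]') == [["a", "b"], [1, 2]]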

@@ -2,14 +2,16 @@
PDF Read Tool - Manage Accounting and Financial Operations.

Uses pypdf to read PDF documents and extract text content
along with metadata.
along with metadata. Supports both local file paths and URLs.
"""

from __future__ import annotations

import tempfile
from pathlib import Path
from typing import Any

import httpx
from fastmcp import FastMCP
from pypdf import PdfReader

@@ -98,9 +100,10 @@ def register_tools(mcp: FastMCP) -> None:

        Returns text content with page markers and optional metadata.
        Use for reading PDFs, reports, documents, or any PDF file.
        Supports both local file paths and URLs.

        Args:
            file_path: Path to the PDF file to read (absolute or relative)
            file_path: Path or URL to the PDF file (local path, or http/https URL)
            pages: Page range - 'all'/None for all, '5' for single,
                '1-10' for range, '1,3,5' for specific
            max_pages: Maximum number of pages to process (1-1000, memory safety)
@@ -109,8 +112,48 @@ def register_tools(mcp: FastMCP) -> None:
        Returns:
            Dict with extracted text and metadata, or error dict
        """
        temp_file = None
        try:
            path = Path(file_path).resolve()
            # Check if input is a URL
            is_url = file_path.startswith(("http://", "https://"))

            if is_url:
                # Download PDF from URL to temporary file
                try:
                    response = httpx.get(
                        file_path,
                        headers={"User-Agent": "AdenBot/1.0 (PDF Reader)"},
                        follow_redirects=True,
                        timeout=60.0,
                    )

                    if response.status_code != 200:
                        return {"error": f"Failed to download PDF: HTTP {response.status_code}"}

                    # Validate content-type
                    content_type = response.headers.get("content-type", "").lower()
                    if "application/pdf" not in content_type:
                        return {
                            "error": (
                                f"URL does not point to a PDF file. Content-Type: {content_type}"
                            ),
                            "content_type": content_type,
                            "url": file_path,
                        }

                    # Save to temporary file
                    temp_file = tempfile.NamedTemporaryFile(mode="wb", suffix=".pdf", delete=False)
                    temp_file.write(response.content)
                    temp_file.close()
                    path = Path(temp_file.name)

                except httpx.TimeoutException:
                    return {"error": "PDF download timed out"}
                except httpx.RequestError as e:
                    return {"error": f"Failed to download PDF: {str(e)}"}
            else:
                # Local file path
                path = Path(file_path).resolve()

            # Validate file exists
            if not path.exists():
@@ -192,3 +235,10 @@ def register_tools(mcp: FastMCP) -> None:
            return {"error": f"Permission denied: {file_path}"}
        except Exception as e:
            return {"error": f"Failed to read PDF: {str(e)}"}
        finally:
            # Clean up temporary file if it was created
            if temp_file is not None:
                try:
                    Path(temp_file.name).unlink(missing_ok=True)
                except Exception:
                    pass  # Ignore cleanup errors
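
The download path above, condensed into a standalone helper for reference. This is a sketch under the same behavior (status check, content-type validation); the helper name fetch_pdf_to_temp is illustrative, and unlike the tool it leaves temp-file cleanup to the caller:

    import tempfile
    from pathlib import Path

    import httpx

    def fetch_pdf_to_temp(url: str) -> Path | dict:
        """Download a PDF to a temp file; return its Path or an error dict."""
        response = httpx.get(url, follow_redirects=True, timeout=60.0)
        if response.status_code != 200:
            return {"error": f"Failed to download PDF: HTTP {response.status_code}"}
        content_type = response.headers.get("content-type", "").lower()
        if "application/pdf" not in content_type:
            return {"error": f"URL does not point to a PDF file. Content-Type: {content_type}"}
        tmp = tempfile.NamedTemporaryFile(mode="wb", suffix=".pdf", delete=False)
        tmp.write(response.content)
        tmp.close()
        return Path(tmp.name)  # caller is responsible for unlinking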

@@ -1,7 +1,9 @@
"""Tests for pdf_read tool (FastMCP)."""

from pathlib import Path
from unittest.mock import MagicMock, Mock, patch

import httpx
import pytest
from fastmcp import FastMCP

@@ -111,3 +113,175 @@ class TestPdfReadTool:
        # New behavior: explicit truncation metadata instead of silent truncation
        assert result.get("truncated") is True
        assert "truncation_warning" in result


class TestPdfReadUrlSupport:
    """Tests for URL download support in pdf_read tool."""

    @patch("httpx.get")
    @patch("aden_tools.tools.pdf_read_tool.pdf_read_tool.PdfReader")
    def test_url_download_succeeds(self, mock_pdf_reader, mock_get, pdf_read_fn):
        """Valid PDF URL downloads and parses successfully."""
        # Mock HTTP response
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.headers = {"content-type": "application/pdf"}
        mock_response.content = b"%PDF-1.4\nfake pdf content"
        mock_get.return_value = mock_response

        # Mock PdfReader
        mock_reader_instance = MagicMock()
        mock_reader_instance.is_encrypted = False
        mock_reader_instance.pages = [MagicMock()]
        mock_reader_instance.pages[0].extract_text.return_value = "PDF text content"
        mock_reader_instance.metadata = None
        mock_pdf_reader.return_value = mock_reader_instance

        result = pdf_read_fn(file_path="https://example.com/document.pdf")

        assert "error" not in result
        assert "content" in result
        assert "PDF text content" in result["content"]
        mock_get.assert_called_once()

    @patch("httpx.get")
    def test_url_non_pdf_content_type(self, mock_get, pdf_read_fn):
        """URL returning non-PDF content-type returns error."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.headers = {"content-type": "text/html"}
        mock_response.content = b"<html>Not a PDF</html>"
        mock_get.return_value = mock_response

        result = pdf_read_fn(file_path="https://example.com/page.html")

        assert "error" in result
        assert "does not point to a pdf" in result["error"].lower()
        assert "content_type" in result
        assert "text/html" in result["content_type"]

    @patch("httpx.get")
    def test_url_http_404_error(self, mock_get, pdf_read_fn):
        """URL returning 404 returns appropriate error."""
        mock_response = Mock()
        mock_response.status_code = 404
        mock_get.return_value = mock_response

        result = pdf_read_fn(file_path="https://example.com/missing.pdf")

        assert "error" in result
        assert "404" in result["error"]

    @patch("httpx.get")
    def test_url_http_500_error(self, mock_get, pdf_read_fn):
        """URL returning 500 returns appropriate error."""
        mock_response = Mock()
        mock_response.status_code = 500
        mock_get.return_value = mock_response

        result = pdf_read_fn(file_path="https://example.com/error.pdf")

        assert "error" in result
        assert "500" in result["error"]

    @patch("httpx.get")
    def test_url_timeout_error(self, mock_get, pdf_read_fn):
        """URL request timeout returns appropriate error."""
        mock_get.side_effect = httpx.TimeoutException("Timeout")

        result = pdf_read_fn(file_path="https://example.com/slow.pdf")

        assert "error" in result
        assert "timed out" in result["error"].lower()

    @patch("httpx.get")
    def test_url_network_error(self, mock_get, pdf_read_fn):
        """Network error returns appropriate error."""
        mock_get.side_effect = httpx.RequestError("Connection failed")

        result = pdf_read_fn(file_path="https://example.com/doc.pdf")

        assert "error" in result
        assert "failed to download" in result["error"].lower()

    @patch("httpx.get")
    @patch("aden_tools.tools.pdf_read_tool.pdf_read_tool.PdfReader")
    def test_url_with_http_scheme(self, mock_pdf_reader, mock_get, pdf_read_fn):
        """HTTP URLs (not HTTPS) are handled correctly."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.headers = {"content-type": "application/pdf"}
        mock_response.content = b"%PDF-1.4\ncontent"
        mock_get.return_value = mock_response

        mock_reader_instance = MagicMock()
        mock_reader_instance.is_encrypted = False
        mock_reader_instance.pages = [MagicMock()]
        mock_reader_instance.pages[0].extract_text.return_value = "Text"
        mock_reader_instance.metadata = None
        mock_pdf_reader.return_value = mock_reader_instance

        result = pdf_read_fn(file_path="http://example.com/doc.pdf")

        assert "error" not in result
        mock_get.assert_called_once()

    def test_local_file_path_still_works(self, pdf_read_fn, tmp_path: Path):
        """Local file paths still work (backward compatibility)."""
        pdf_file = tmp_path / "local.pdf"
        pdf_file.write_bytes(b"%PDF-1.4")

        result = pdf_read_fn(file_path=str(pdf_file))

        # Will error due to invalid PDF, but should not treat as URL
        assert isinstance(result, dict)
        # Should not have URL-specific errors
        if "error" in result:
            assert "download" not in result["error"].lower()

    @patch("httpx.get")
    @patch("aden_tools.tools.pdf_read_tool.pdf_read_tool.PdfReader")
    @patch("aden_tools.tools.pdf_read_tool.pdf_read_tool.tempfile.NamedTemporaryFile")
    def test_temporary_file_cleanup(self, mock_tempfile, mock_pdf_reader, mock_get, pdf_read_fn):
        """Temporary file is cleaned up after processing."""
        # Mock HTTP response
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.headers = {"content-type": "application/pdf"}
        mock_response.content = b"%PDF-1.4\ncontent"
        mock_get.return_value = mock_response

        # Mock temporary file
        mock_temp = MagicMock()
        mock_temp.name = "/tmp/test.pdf"
        mock_tempfile.return_value = mock_temp

        # Mock PdfReader
        mock_reader_instance = MagicMock()
        mock_reader_instance.is_encrypted = False
        mock_reader_instance.pages = [MagicMock()]
        mock_reader_instance.pages[0].extract_text.return_value = "Text"
        mock_reader_instance.metadata = None
        mock_pdf_reader.return_value = mock_reader_instance

        pdf_read_fn(file_path="https://example.com/doc.pdf")

        # Verify temp file operations
        mock_temp.write.assert_called_once()
        mock_temp.close.assert_called_once()

    @patch("httpx.get")
    def test_url_json_content_type(self, mock_get, pdf_read_fn):
        """URL returning JSON returns appropriate error."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.headers = {"content-type": "application/json"}
        mock_response.content = b'{"error": "not a pdf"}'
        mock_get.return_value = mock_response

        result = pdf_read_fn(file_path="https://api.example.com/data")

        assert "error" in result
        assert "does not point to a pdf" in result["error"].lower()
        assert "content_type" in result
        assert "application/json" in result["content_type"]
@@ -1,6 +1,5 @@
"""Tests for security.py - get_secure_path() function."""

import os
from unittest.mock import patch

import pytest
@@ -242,19 +241,14 @@ class TestGetSecurePath:
        symlink_path = session_dir / "link_to_target"
        symlink_path.symlink_to(target_file)

        # Path through symlink should resolve
        # Path through symlink should resolve to the real target path
        result = get_secure_path("link_to_target", **ids)

        assert result == str(symlink_path)
        # realpath resolves the symlink, so result points to the real file
        assert result == str(target_file.resolve())

    def test_symlink_escape_detected_with_realpath(self, ids):
        """Symlinks pointing outside sandbox can be detected using realpath.

        Note: get_secure_path uses abspath (not realpath), so it validates the
        lexical path. To fully protect against symlink attacks, callers should
        verify realpath(result) is still within the sandbox before file I/O.
        This test documents that pattern.
        """
    def test_symlink_escape_blocked(self, ids):
        """Symlinks pointing outside sandbox are blocked by get_secure_path."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        # Create session directory
@@ -267,10 +261,22 @@ class TestGetSecurePath:
        symlink_path = session_dir / "escape_link"
        symlink_path.symlink_to(outside_target)

        # get_secure_path accepts the lexical path (symlink is inside session)
        result = get_secure_path("escape_link", **ids)
        assert result == str(symlink_path)
        # get_secure_path now resolves symlinks and blocks the escape
        with pytest.raises(ValueError, match="outside the session sandbox"):
            get_secure_path("escape_link", **ids)

        # However, realpath reveals the escape - callers should check this
        real_path = os.path.realpath(result)
        assert os.path.commonpath([real_path, str(session_dir)]) != str(session_dir)
    def test_symlink_to_root_escape_blocked(self, ids):
        """Symlink to / inside sandbox then traversing through it is blocked."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        # Create session directory
        session_dir = self.workspaces_dir / "test-workspace" / "test-agent" / "test-session"
        session_dir.mkdir(parents=True, exist_ok=True)

        # Create a symlink to root filesystem inside the sandbox
        symlink_path = session_dir / "root"
        symlink_path.symlink_to("/")

        # Attempting to access files through the symlink should be blocked
        with pytest.raises(ValueError, match="outside the session sandbox"):
            get_secure_path("root/etc/passwd", **ids)