Compare commits
235 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f748187391 | |||
| eafbeb78b4 | |||
| 5cb5083f8d | |||
| bf86daee92 | |||
| 43bbd0f31f | |||
| 2cf962b538 | |||
| 4298196700 | |||
| bc1f712e42 | |||
| cccbcc8ec3 | |||
| 0722f83f16 | |||
| 72091d2783 | |||
| 3bb69a5784 | |||
| 63fb089062 | |||
| d5ba985e29 | |||
| 6ee510d2f6 | |||
| 45b350e7c8 | |||
| 7e690de12f | |||
| ae85d2bf59 | |||
| e9fd0158b9 | |||
| 9a68a5d7ee | |||
| 33edf4a207 | |||
| f9fdaf5adc | |||
| eabb17934c | |||
| eba7524955 | |||
| c56440340a | |||
| c889ffd85d | |||
| 905a4f3516 | |||
| 941605720f | |||
| 72e5c5c1c6 | |||
| 0f42c8c8c1 | |||
| c3c3075610 | |||
| 86ef6fd8c5 | |||
| 95bdf4fe32 | |||
| 890d303d26 | |||
| 7fe60991e1 | |||
| a72938a163 | |||
| 326a3dd1b7 | |||
| 183c6e2620 | |||
| 1b40bff7da | |||
| 38b79edaee | |||
| eb4f180192 | |||
| bf0b9a1edb | |||
| 9667dd25cb | |||
| 33e4e8d440 | |||
| c5ac29c81d | |||
| 13c072d731 | |||
| 5e31975cc3 | |||
| 82af76e72a | |||
| a483f8d06a | |||
| e188c26e9f | |||
| 27a2d64a98 | |||
| c2dce3a8c2 | |||
| b52974adcc | |||
| 047ad812af | |||
| 22d9fba1fd | |||
| c7d0afc775 | |||
| 645792fb1a | |||
| 3154e34c7a | |||
| 45aafbc52b | |||
| 567340c05d | |||
| 8ecb728148 | |||
| 4a2141bce9 | |||
| 3b4d6e4602 | |||
| 8d8656193d | |||
| ef317371ce | |||
| d5596ccb0a | |||
| 89ccc664bd | |||
| 4872c01886 | |||
| 5f1530ec5b | |||
| 8af32b421c | |||
| 4620380341 | |||
| fca2deb980 | |||
| d7ce923ca6 | |||
| 403b47db61 | |||
| 0d0e78579f | |||
| 447bfdfab8 | |||
| c77d21e393 | |||
| 6ded508b4d | |||
| 75f8bf5696 | |||
| 62fc02220b | |||
| 5d4f279646 | |||
| 920a840756 | |||
| 8680a35c39 | |||
| 95cc8a4513 | |||
| d648f3d315 | |||
| b43044cf4d | |||
| 4724320946 | |||
| c9134cfd91 | |||
| 55ce751385 | |||
| aca2dfb536 | |||
| 89ab2e0a74 | |||
| d11f539209 | |||
| 64a223353a | |||
| 2d154c2db6 | |||
| a00c934d9d | |||
| 18bee9cb90 | |||
| c1664e47e5 | |||
| 2cb972fc5a | |||
| 0bd841ce01 | |||
| 88ec4b7e64 | |||
| 27d5061d97 | |||
| ee4682c565 | |||
| a2cd96a1a7 | |||
| 07b82a51f6 | |||
| 3e1282b31e | |||
| 736756b257 | |||
| 90efe7009d | |||
| 4adb369bde | |||
| d4a30eb2f3 | |||
| 94bb4a2984 | |||
| 648bad26ed | |||
| f0c7470f3d | |||
| fe533b72a6 | |||
| e581767cab | |||
| 0663ee5950 | |||
| 4b97baa34b | |||
| a89296d397 | |||
| d568912ba2 | |||
| c4d7980058 | |||
| 8549fe8238 | |||
| 2b8d85bb95 | |||
| 07f7801166 | |||
| 1f12a45151 | |||
| 936e02e8e6 | |||
| d59fe1e109 | |||
| 274318d3e5 | |||
| 0f0884c2e0 | |||
| 9b59255770 | |||
| 49fd443da8 | |||
| 764012c598 | |||
| fd4dc1a69a | |||
| 377cd39c2a | |||
| e92caeef24 | |||
| b7e6226478 | |||
| a995818db2 | |||
| 0772b4d300 | |||
| 684e0d8dc6 | |||
| d284c5d790 | |||
| 7a9b9666c4 | |||
| a852cb91bf | |||
| 2f21e9eb4b | |||
| 8390ef8731 | |||
| 8d21479c24 | |||
| 965dec3ba1 | |||
| d4b54446be | |||
| 7992b862c2 | |||
| 44b3e0eaa2 | |||
| f480fc2b94 | |||
| b599a760e8 | |||
| b4a37cdb03 | |||
| 2844dbf19f | |||
| 4885db318e | |||
| fa7ce53fb3 | |||
| 75a2ef2c4a | |||
| a0b9d6afaf | |||
| 74c0a85e3f | |||
| 22b7e4b0c3 | |||
| 5413833a69 | |||
| 02e1a4584a | |||
| 520840b1dd | |||
| ee96147336 | |||
| 705cef4dc1 | |||
| ab26e64122 | |||
| f365e219cb | |||
| 01621881c2 | |||
| f7639f8572 | |||
| fc643060ce | |||
| 9aebeb181e | |||
| acbbfaaa79 | |||
| bf170bce10 | |||
| 0a090d058b | |||
| 47bfadaad9 | |||
| d968dcd44c | |||
| 6fdaa9ea50 | |||
| 4d251fbdc2 | |||
| 6acceed288 | |||
| 8dd1d6e3aa | |||
| 1da28644a6 | |||
| 6452fe7fef | |||
| acff008bd2 | |||
| 651d6850a1 | |||
| c7fdc92594 | |||
| 43602a8801 | |||
| 3da04265a6 | |||
| 4c98f0d2d0 | |||
| d84c3364d0 | |||
| ae921f6cee | |||
| 6b506a1c08 | |||
| 0c9f4fa97e | |||
| 95e30bc607 | |||
| 0f1f0090b0 | |||
| c0da3bec02 | |||
| 9dadb5264d | |||
| e39e6a75cc | |||
| 23c66d1059 | |||
| b9d529d94e | |||
| 1c9b09fb78 | |||
| 9fb14f23d2 | |||
| 96609386a3 | |||
| 0cef0e6990 | |||
| 4795dc4f68 | |||
| acf0f804c5 | |||
| 4e2951854b | |||
| 80dfb429d7 | |||
| 9c0ba77e22 | |||
| 46b4651073 | |||
| b799789dbe | |||
| 2cd73dfccc | |||
| 57d77d5479 | |||
| 5814021773 | |||
| 4f4cc9c8ce | |||
| d9c840eee5 | |||
| 03842353e4 | |||
| 88253883a3 | |||
| 6ed6e5b286 | |||
| 30bb0ad5d8 | |||
| cb0845f5ba | |||
| ce2525b59c | |||
| 1f77ec3831 | |||
| 6ab5aa8004 | |||
| 4449cd8ee8 | |||
| 8b60c03a0a | |||
| 2f15a16159 | |||
| 0e98023e40 | |||
| 22bb07f00e | |||
| 660f883197 | |||
| 988de80b66 | |||
| dc6aa226ee | |||
| 48a54b4ee2 | |||
| a7b6b080ab | |||
| 9202cbd4d4 | |||
| d433cda209 | |||
| 151fbd7b00 | |||
| f88483f964 | |||
| b61ec8c94d |
@@ -2,14 +2,22 @@ name: Bounty completed
|
||||
description: Awards points and notifies Discord when a bounty PR is merged
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
pull_request_target:
|
||||
types: [closed]
|
||||
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
pr_number:
|
||||
description: "PR number to process (for missed bounties)"
|
||||
required: true
|
||||
type: number
|
||||
|
||||
jobs:
|
||||
bounty-notify:
|
||||
if: >
|
||||
github.event.pull_request.merged == true &&
|
||||
contains(join(github.event.pull_request.labels.*.name, ','), 'bounty:')
|
||||
github.event_name == 'workflow_dispatch' ||
|
||||
(github.event.pull_request.merged == true &&
|
||||
contains(join(github.event.pull_request.labels.*.name, ','), 'bounty:'))
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5
|
||||
permissions:
|
||||
@@ -32,6 +40,8 @@ jobs:
|
||||
GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
|
||||
GITHUB_REPOSITORY_NAME: ${{ github.event.repository.name }}
|
||||
DISCORD_WEBHOOK_URL: ${{ secrets.DISCORD_BOUNTY_WEBHOOK_URL }}
|
||||
BOT_API_URL: ${{ secrets.BOT_API_URL }}
|
||||
BOT_API_KEY: ${{ secrets.BOT_API_KEY }}
|
||||
LURKR_API_KEY: ${{ secrets.LURKR_API_KEY }}
|
||||
LURKR_GUILD_ID: ${{ secrets.LURKR_GUILD_ID }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
PR_NUMBER: ${{ inputs.pr_number || github.event.pull_request.number }}
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
name: Link Discord account
|
||||
description: Auto-creates a PR to add contributor to contributors.yml when a link-discord issue is opened
|
||||
|
||||
on:
|
||||
issues:
|
||||
types: [opened]
|
||||
|
||||
jobs:
|
||||
link-discord:
|
||||
if: contains(github.event.issue.labels.*.name, 'link-discord')
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 2
|
||||
permissions:
|
||||
contents: write
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Parse issue and update contributors.yml
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
|
||||
const issue = context.payload.issue;
|
||||
const githubUsername = issue.user.login;
|
||||
|
||||
// Parse the issue body for form fields
|
||||
const body = issue.body || '';
|
||||
|
||||
// Extract Discord ID — look for the numeric value after the "Discord User ID" heading
|
||||
const discordMatch = body.match(/### Discord User ID\s*\n\s*(\d{17,20})/);
|
||||
if (!discordMatch) {
|
||||
await github.rest.issues.createComment({
|
||||
...context.repo,
|
||||
issue_number: issue.number,
|
||||
body: `Could not find a valid Discord ID in the issue body. Please make sure you entered a numeric ID (17-20 digits), not a username.\n\nExample: \`123456789012345678\``
|
||||
});
|
||||
await github.rest.issues.update({
|
||||
...context.repo,
|
||||
issue_number: issue.number,
|
||||
state: 'closed',
|
||||
state_reason: 'not_planned'
|
||||
});
|
||||
return;
|
||||
}
|
||||
const discordId = discordMatch[1];
|
||||
|
||||
// Extract display name (optional)
|
||||
const nameMatch = body.match(/### Display Name \(optional\)\s*\n\s*(.+)/);
|
||||
const displayName = nameMatch ? nameMatch[1].trim() : '';
|
||||
|
||||
// Check if user already exists
|
||||
const yml = fs.readFileSync('contributors.yml', 'utf-8');
|
||||
if (yml.includes(`github: ${githubUsername}`)) {
|
||||
await github.rest.issues.createComment({
|
||||
...context.repo,
|
||||
issue_number: issue.number,
|
||||
body: `@${githubUsername} is already in \`contributors.yml\`. If you need to update your Discord ID, please edit the file directly via PR.`
|
||||
});
|
||||
await github.rest.issues.update({
|
||||
...context.repo,
|
||||
issue_number: issue.number,
|
||||
state: 'closed',
|
||||
state_reason: 'completed'
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Append entry to contributors.yml
|
||||
let entry = ` - github: ${githubUsername}\n discord: "${discordId}"`;
|
||||
if (displayName && displayName !== '_No response_') {
|
||||
entry += `\n name: ${displayName}`;
|
||||
}
|
||||
entry += '\n';
|
||||
|
||||
const updated = yml.trimEnd() + '\n' + entry;
|
||||
fs.writeFileSync('contributors.yml', updated);
|
||||
|
||||
// Set outputs for commit step
|
||||
core.exportVariable('GITHUB_USERNAME', githubUsername);
|
||||
core.exportVariable('DISCORD_ID', discordId);
|
||||
core.exportVariable('ISSUE_NUMBER', issue.number.toString());
|
||||
|
||||
- name: Create PR
|
||||
run: |
|
||||
# Check if there are changes
|
||||
if git diff --quiet contributors.yml; then
|
||||
echo "No changes to contributors.yml"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
BRANCH="docs/link-discord-${GITHUB_USERNAME}"
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
git checkout -b "$BRANCH"
|
||||
git add contributors.yml
|
||||
git commit -m "docs: link @${GITHUB_USERNAME} to Discord"
|
||||
git push origin "$BRANCH"
|
||||
|
||||
gh pr create \
|
||||
--title "docs: link @${GITHUB_USERNAME} to Discord" \
|
||||
--body "Adds @${GITHUB_USERNAME} (Discord \`${DISCORD_ID}\`) to \`contributors.yml\` for bounty XP tracking.
|
||||
|
||||
Closes #${ISSUE_NUMBER}" \
|
||||
--base main \
|
||||
--head "$BRANCH" \
|
||||
--label "link-discord"
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Notify on issue
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const username = process.env.GITHUB_USERNAME;
|
||||
const issueNumber = parseInt(process.env.ISSUE_NUMBER);
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
...context.repo,
|
||||
issue_number: issueNumber,
|
||||
body: `A PR has been created to link your account. A maintainer will merge it shortly — once merged, you'll receive XP and Discord pings when your bounty PRs are merged.`
|
||||
});
|
||||
@@ -35,6 +35,8 @@ jobs:
|
||||
GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
|
||||
GITHUB_REPOSITORY_NAME: ${{ github.event.repository.name }}
|
||||
DISCORD_WEBHOOK_URL: ${{ secrets.DISCORD_BOUNTY_WEBHOOK_URL }}
|
||||
BOT_API_URL: ${{ secrets.BOT_API_URL }}
|
||||
BOT_API_KEY: ${{ secrets.BOT_API_KEY }}
|
||||
LURKR_API_KEY: ${{ secrets.LURKR_API_KEY }}
|
||||
LURKR_GUILD_ID: ${{ secrets.LURKR_GUILD_ID }}
|
||||
SINCE_DATE: ${{ github.event.inputs.since_date || '' }}
|
||||
|
||||
+4
-2
@@ -13,6 +13,10 @@ out/
|
||||
.env
|
||||
.env.local
|
||||
.env.*.local
|
||||
.venv
|
||||
/venv
|
||||
tools/src/uv.lock
|
||||
|
||||
|
||||
# User configuration (copied from .example)
|
||||
config.yaml
|
||||
@@ -69,8 +73,6 @@ exports/*
|
||||
|
||||
.claude/settings.local.json
|
||||
|
||||
.venv
|
||||
|
||||
docs/github-issues/*
|
||||
core/tests/*dumps/*
|
||||
|
||||
|
||||
+8
-2
@@ -4,7 +4,7 @@
|
||||
|
||||
Welcome to Aden Hive, an open-source AI agent framework built for developers who demand production-grade reliability, cross-platform support, and real-world performance. This guide will help you contribute effectively, whether you're fixing bugs, adding features, improving documentation, or building new tools.
|
||||
|
||||
Thank you for your interest in contributing! We're especially looking for help building tools, integrations ([check #2805](https://github.com/adenhq/hive/issues/2805)), and example agents for the framework.
|
||||
Thank you for your interest in contributing! We're especially looking for help building tools, integrations ([check #2805](https://github.com/aden-hive/hive/issues/2805)), and example agents for the framework.
|
||||
|
||||
---
|
||||
|
||||
@@ -390,6 +390,8 @@ Aden Hive supports **100+ LLM providers** via LiteLLM, giving users maximum flex
|
||||
|----------|--------|-------|
|
||||
| **Anthropic** | Claude 3.5 Sonnet, Haiku, Opus | Default provider, best for reasoning |
|
||||
| **OpenAI** | GPT-4, GPT-4 Turbo, GPT-4o | Function calling, vision |
|
||||
| **OpenRouter** | Any OpenRouter catalog model | Uses `OPENROUTER_API_KEY` and `https://openrouter.ai/api/v1` |
|
||||
| **Hive LLM** | `queen`, `kimi-2.5`, `GLM-5` | Uses `HIVE_API_KEY` and the Hive-managed endpoint |
|
||||
| **Google** | Gemini 1.5 Pro, Flash | Long context windows |
|
||||
| **DeepSeek** | DeepSeek V3 | Cost-effective, strong reasoning |
|
||||
| **Mistral** | Mistral Large, Medium, Small | Open weights, EU hosting |
|
||||
@@ -415,6 +417,10 @@ DEFAULT_MODEL = "claude-haiku-4-5-20251001"
|
||||
- **Cost**: DeepSeek or Gemini Flash (budget-conscious)
|
||||
- **Privacy**: Ollama with local models (no data leaves server)
|
||||
|
||||
**Provider-Specific Notes**
|
||||
- **OpenRouter**: store `provider` as `openrouter`, use the raw OpenRouter model ID in `model` (for example `x-ai/grok-4.20-beta`), and use `OPENROUTER_API_KEY`
|
||||
- **Hive LLM**: store `provider` as `hive`, use Hive model names such as `queen`, `kimi-2.5`, or `GLM-5`, and use `HIVE_API_KEY`
|
||||
|
||||
**For Development**
|
||||
- Use cheaper/faster models (Haiku, GPT-4o-mini)
|
||||
- Test with multiple providers to catch provider-specific issues
|
||||
@@ -426,7 +432,7 @@ DEFAULT_MODEL = "claude-haiku-4-5-20251001"
|
||||
2. **Add credential handling** in `core/framework/credentials/`
|
||||
3. **Add provider-specific configuration** in `core/framework/llm/`
|
||||
4. **Write tests** in `core/tests/test_llm_provider.py`
|
||||
5. **Update documentation** in `docs/llm_providers.md`
|
||||
5. **Update documentation** in `README.md`, `docs/configuration.md`, and any setup guides that mention provider configuration
|
||||
|
||||
**Example: Testing LLM Integration**
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<img src="https://img.shields.io/badge/Agent_Harness-Runtime_Layer-ff6600?style=flat-square" alt="Agent Harness" />
|
||||
<img src="https://img.shields.io/badge/AI_Agents-Self--Improving-brightgreen?style=flat-square" alt="AI Agents" />
|
||||
<img src="https://img.shields.io/badge/Multi--Agent-Systems-blue?style=flat-square" alt="Multi-Agent" />
|
||||
<img src="https://img.shields.io/badge/Headless-Development-purple?style=flat-square" alt="Headless" />
|
||||
@@ -35,37 +36,42 @@
|
||||
<img src="https://img.shields.io/badge/Google_Gemini-supported-4285F4?style=flat-square&logo=google" alt="Gemini" />
|
||||
</p>
|
||||
|
||||
<p align="center"><em>The agent harness for production workloads — state management, failure recovery, observability, and human oversight so your agents actually run.</em></p>
|
||||
|
||||
## Overview
|
||||
|
||||
Generate a swarm of worker agents with a coding agent(queen) that control them. Define your goal through conversation with hive queen, and the framework generates a node graph with dynamically created connection code. When things break, the framework captures failure data, evolves the agent through the coding agent, and redeploys. Built-in human-in-the-loop nodes, browser use, credential management, and real-time monitoring give you control without sacrificing adaptability.
|
||||
Hive is a runtime harness for AI agents in production. You describe your goal in natural language; a coding agent (the queen) generates the agent graph and connection code to achieve it. During execution, the harness manages state isolation, checkpoint-based crash recovery, cost enforcement, and real-time observability. When agents fail, the framework captures failure data, evolves the graph through the coding agent, and redeploys automatically. Built-in human-in-the-loop nodes, browser control, credential management, and parallel execution give you production reliability without sacrificing adaptability.
|
||||
|
||||
Visit [adenhq.com](https://adenhq.com) for complete documentation, examples, and guides.
|
||||
|
||||
[](https://www.youtube.com/watch?v=XDOG9fOaLjU)
|
||||
Visit [HoneyComb](http://honeycomb.open-hive.com/) to see what jobs are being automated by AI. It’s a stock market for jobs, driven by our community’s AI agent progress. You can long and short jobs (with no real money but compute token)based on how much you think a job is going to be replaced by AI.
|
||||
|
||||
https://github.com/user-attachments/assets/bf10edc3-06ba-48b6-98ba-d069b15fb69d
|
||||
|
||||
|
||||
## Who Is Hive For?
|
||||
|
||||
Hive is designed for developers and teams who want to build many **autonomous AI agents** fast without manually wiring complex workflows.
|
||||
Hive is the harness layer for teams moving AI agents from prototype to production. Models are getting better on their own — the bottleneck is the infrastructure around them: state management, failure recovery, cost control, and observability.
|
||||
|
||||
Hive is a good fit if you:
|
||||
|
||||
- Want AI agents that **execute real business processes**, not demos
|
||||
- Need **fast or high volume agent execution** over open workflow
|
||||
- Need a **runtime that handles state, recovery, and parallel execution** at scale
|
||||
- Need **self-healing and adaptive agents** that improve over time
|
||||
- Require **human-in-the-loop control**, observability, and cost limits
|
||||
- Plan to run agents in **production environments**
|
||||
- Plan to run agents in **production** where uptime, cost, and auditability matter
|
||||
|
||||
Hive may not be the best fit if you’re only experimenting with simple agent chains or one-off scripts.
|
||||
|
||||
## When Should You Use Hive?
|
||||
|
||||
Use Hive when you need:
|
||||
Use Hive when the bottleneck is no longer the model but the harness around it:
|
||||
|
||||
- Long-running, autonomous agents
|
||||
- Strong guardrails, process, and controls
|
||||
- Continuous improvement based on failures
|
||||
- Multi-agent coordination
|
||||
- A framework that evolves with your goals
|
||||
- Long-running agents that need **state persistence and crash recovery**
|
||||
- Production workloads requiring **cost enforcement, observability, and audit trails**
|
||||
- Agents that **self-heal** through failure capture and graph evolution
|
||||
- Multi-agent coordination with **session isolation and shared memory**
|
||||
- A framework that **scales with model improvements** rather than fighting them
|
||||
|
||||
## Quick Links
|
||||
|
||||
@@ -73,7 +79,7 @@ Use Hive when you need:
|
||||
- **[Self-Hosting Guide](https://docs.adenhq.com/getting-started/quickstart)** - Deploy Hive on your infrastructure
|
||||
- **[Changelog](https://github.com/aden-hive/hive/releases)** - Latest updates and releases
|
||||
- **[Roadmap](docs/roadmap.md)** - Upcoming features and plans
|
||||
- **[Report Issues](https://github.com/adenhq/hive/issues)** - Bug reports and feature requests
|
||||
- **[Report Issues](https://github.com/aden-hive/hive/issues)** - Bug reports and feature requests
|
||||
- **[Contributing](CONTRIBUTING.md)** - How to contribute and submit PRs
|
||||
|
||||
## Quick Start
|
||||
@@ -98,9 +104,11 @@ Use Hive when you need:
|
||||
git clone https://github.com/aden-hive/hive.git
|
||||
cd hive
|
||||
|
||||
|
||||
# Run quickstart setup
|
||||
# Run quickstart setup (macOS/Linux)
|
||||
./quickstart.sh
|
||||
|
||||
# Windows (PowerShell)
|
||||
.\quickstart.ps1
|
||||
```
|
||||
|
||||
This sets up:
|
||||
@@ -108,7 +116,7 @@ This sets up:
|
||||
- **framework** - Core agent runtime and graph executor (in `core/.venv`)
|
||||
- **aden_tools** - MCP tools for agent capabilities (in `tools/.venv`)
|
||||
- **credential store** - Encrypted API key storage (`~/.hive/credentials`)
|
||||
- **LLM provider** - Interactive default model configuration
|
||||
- **LLM provider** - Interactive default model configuration, including Hive LLM and OpenRouter
|
||||
- All required Python dependencies with `uv`
|
||||
|
||||
- Finally, it will open the Hive interface in your browser
|
||||
@@ -147,12 +155,12 @@ Now you can run an agent by selecting the agent (either an existing agent or exa
|
||||
<a href="https://github.com/aden-hive/hive/tree/main/tools/src/aden_tools/tools"><img width="100%" alt="Integration" src="https://github.com/user-attachments/assets/a1573f93-cf02-4bb8-b3d5-b305b05b1e51" /></a>
|
||||
Hive is built to be model-agnostic and system-agnostic.
|
||||
|
||||
- **LLM flexibility** - Hive Framework is designed to support various types of LLMs, including hosted and local models through LiteLLM-compatible providers.
|
||||
- **LLM flexibility** - Hive Framework supports Anthropic, OpenAI, OpenRouter, Hive LLM, and other hosted or local models through LiteLLM-compatible providers.
|
||||
- **Business system connectivity** - Hive Framework is designed to connect to all kinds of business systems as tools, such as CRM, support, messaging, data, file, and internal APIs via MCP.
|
||||
|
||||
## Why Aden
|
||||
## Why Hive
|
||||
|
||||
Hive focuses on generating agents that run real business processes rather than generic agents. Instead of requiring you to manually design workflows, define agent interactions, and handle failures reactively, Hive flips the paradigm: **you describe outcomes, and the system builds itself**—delivering an outcome-driven, adaptive experience with an easy-to-use set of tools and integrations.
|
||||
As models improve, the upper bound of what agents can do rises — but their reliability and production value are determined by the harness. Hive focuses on generating agents that run real business processes rather than generic agents. Instead of requiring you to manually design workflows, define agent interactions, and handle failures reactively, Hive flips the paradigm: **you describe outcomes, and the system builds itself**—delivering an outcome-driven, adaptive experience with an easy-to-use set of tools and integrations.
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
@@ -188,8 +196,9 @@ flowchart LR
|
||||
|
||||
### The Hive Advantage
|
||||
|
||||
| Traditional Frameworks | Hive |
|
||||
| Typical Agent Frameworks | Hive |
|
||||
| -------------------------- | -------------------------------------- |
|
||||
| Focus on model orchestration | **Production harness**: state, recovery, observability |
|
||||
| Hardcode agent workflows | Describe goals in natural language |
|
||||
| Manual graph definition | Auto-generated agent graphs |
|
||||
| Reactive error handling | Outcome-evaluation and adaptiveness |
|
||||
@@ -375,7 +384,7 @@ This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENS
|
||||
|
||||
**Q: What LLM providers does Hive support?**
|
||||
|
||||
Hive supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, DeepSeek, Mistral, Groq, and many more. Simply set the appropriate API key environment variable and specify the model name. We recommend using Claude, GLM and Gemini as they have the best performance.
|
||||
Hive supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, DeepSeek, Mistral, Groq, OpenRouter, and Hive LLM. Simply set the appropriate API key environment variable and specify the model name. See [docs/configuration.md](docs/configuration.md) for provider-specific configuration examples.
|
||||
|
||||
**Q: Can I use Hive with local AI models like Ollama?**
|
||||
|
||||
@@ -383,7 +392,7 @@ Yes! Hive supports local models through LiteLLM. Simply use the model name forma
|
||||
|
||||
**Q: What makes Hive different from other agent frameworks?**
|
||||
|
||||
Hive generates your entire agent system from natural language goals using a coding agent—you don't hardcode workflows or manually define graphs. When agents fail, the framework automatically captures failure data, [evolves the agent graph](docs/key_concepts/evolution.md), and redeploys. This self-improving loop is unique to Aden.
|
||||
Hive is an agent harness, not just an orchestration framework. It provides the production runtime layer — session isolation, checkpoint-based crash recovery, cost enforcement, real-time observability, and human-in-the-loop controls — that makes agents reliable enough to run real workloads. On top of that, Hive generates your entire agent system from natural language goals and automatically [evolves the graph](docs/key_concepts/evolution.md) when agents fail. The combination of a robust harness with self-improving generation is what sets Hive apart.
|
||||
|
||||
**Q: Is Hive open-source?**
|
||||
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
perf: reduce subprocess spawning in quickstart scripts (#4427)
|
||||
|
||||
## Problem
|
||||
Windows process creation (CreateProcess) is 10-100x slower than Linux fork/exec.
|
||||
The quickstart scripts were spawning 4+ separate `uv run python -c "import X"`
|
||||
processes to verify imports, adding ~600ms overhead on Windows.
|
||||
|
||||
## Solution
|
||||
Consolidated all import checks into a single batch script that checks multiple
|
||||
modules in one subprocess call, reducing spawn overhead by ~75%.
|
||||
|
||||
## Changes
|
||||
- **New**: `scripts/check_requirements.py` - Batched import checker
|
||||
- **New**: `scripts/test_check_requirements.py` - Test suite
|
||||
- **New**: `scripts/benchmark_quickstart.ps1` - Performance benchmark tool
|
||||
- **Modified**: `quickstart.ps1` - Updated import verification (2 sections)
|
||||
- **Modified**: `quickstart.sh` - Updated import verification
|
||||
|
||||
## Performance Impact
|
||||
**Benchmark results on Windows:**
|
||||
- Before: ~19.8 seconds for import checks
|
||||
- After: ~4.9 seconds for import checks
|
||||
- **Improvement: 14.9 seconds saved (75.2% faster)**
|
||||
|
||||
## Testing
|
||||
- ✅ All functional tests pass (`scripts/test_check_requirements.py`)
|
||||
- ✅ Quickstart scripts work correctly on Windows
|
||||
- ✅ Error handling verified (invalid imports reported correctly)
|
||||
- ✅ Performance benchmark confirms 75%+ improvement
|
||||
|
||||
Fixes #4427
|
||||
@@ -1,27 +0,0 @@
|
||||
# Identity mapping: GitHub username -> Discord ID
|
||||
#
|
||||
# This file links GitHub accounts to Discord accounts for the
|
||||
# Integration Bounty Program. When a bounty PR is merged, the
|
||||
# GitHub Action uses this file to ping the contributor on Discord.
|
||||
#
|
||||
# HOW TO ADD YOURSELF:
|
||||
# Open a "Link Discord Account" issue:
|
||||
# https://github.com/aden-hive/hive/issues/new?template=link-discord.yml
|
||||
# A GitHub Action will automatically add your entry here.
|
||||
#
|
||||
# To find your Discord ID:
|
||||
# 1. Open Discord Settings > Advanced > Enable Developer Mode
|
||||
# 2. Right-click your name > Copy User ID
|
||||
#
|
||||
# Format:
|
||||
# - github: your-github-username
|
||||
# discord: "your-discord-id" # quotes required (it's a number)
|
||||
# name: Your Display Name # optional
|
||||
|
||||
contributors:
|
||||
# - github: example-user
|
||||
# discord: "123456789012345678"
|
||||
# name: Example User
|
||||
- github: TimothyZhang7
|
||||
discord: "408460790061072384"
|
||||
name: Timothy@Aden
|
||||
@@ -6,7 +6,7 @@ This guide explains how to integrate Model Context Protocol (MCP) servers with t
|
||||
|
||||
The framework provides built-in support for MCP servers, allowing you to:
|
||||
|
||||
- **Register MCP servers** via STDIO or HTTP transport
|
||||
- **Register MCP servers** via STDIO, HTTP, Unix socket, or SSE transport
|
||||
- **Auto-discover tools** from registered servers
|
||||
- **Use MCP tools** seamlessly in your agents
|
||||
- **Manage multiple MCP servers** simultaneously
|
||||
@@ -104,6 +104,48 @@ runner.register_mcp_server(
|
||||
- `url`: Base URL of the MCP server
|
||||
- `headers`: HTTP headers to include (optional)
|
||||
|
||||
### Unix Socket Transport
|
||||
|
||||
Best for same-host inter-process communication with lower overhead than TCP:
|
||||
|
||||
```python
|
||||
runner.register_mcp_server(
|
||||
name="local-ipc-tools",
|
||||
transport="unix",
|
||||
url="http://localhost",
|
||||
socket_path="/tmp/mcp_server.sock",
|
||||
headers={
|
||||
"Authorization": "Bearer token"
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
**Configuration:**
|
||||
|
||||
- `url`: Base URL for HTTP requests over the socket (required, e.g., `"http://localhost"`)
|
||||
- `socket_path`: Absolute path to the Unix socket file (required, e.g., `"/tmp/mcp_server.sock"`)
|
||||
- `headers`: HTTP headers to include (optional)
|
||||
|
||||
### SSE Transport
|
||||
|
||||
Best for real-time, event-driven connections using the MCP SDK's SSE client:
|
||||
|
||||
```python
|
||||
runner.register_mcp_server(
|
||||
name="streaming-tools",
|
||||
transport="sse",
|
||||
url="http://localhost:8000/sse",
|
||||
headers={
|
||||
"Authorization": "Bearer token"
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
**Configuration:**
|
||||
|
||||
- `url`: SSE endpoint URL (required, e.g., `"http://localhost:8000/sse"`)
|
||||
- `headers`: HTTP headers for the SSE connection (optional)
|
||||
|
||||
## Using MCP Tools in Agents
|
||||
|
||||
Once registered, MCP tools are available just like any other tool:
|
||||
@@ -258,7 +300,32 @@ runner.register_mcp_server(
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Handle Cleanup
|
||||
### 3. Use Unix Socket for Same-Host IPC
|
||||
|
||||
When both the agent and MCP server run on the same machine, Unix sockets avoid TCP overhead:
|
||||
|
||||
```python
|
||||
runner.register_mcp_server(
|
||||
name="fast-local-tools",
|
||||
transport="unix",
|
||||
url="http://localhost",
|
||||
socket_path="/tmp/mcp_server.sock"
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Use SSE for Streaming and Real-Time Tools
|
||||
|
||||
SSE transport maintains a persistent connection, ideal for event-driven servers:
|
||||
|
||||
```python
|
||||
runner.register_mcp_server(
|
||||
name="realtime-tools",
|
||||
transport="sse",
|
||||
url="http://realtime-server:8000/sse"
|
||||
)
|
||||
```
|
||||
|
||||
### 5. Handle Cleanup
|
||||
|
||||
Always clean up MCP connections when done:
|
||||
|
||||
@@ -280,7 +347,7 @@ async with AgentRunner.load("exports/my-agent") as runner:
|
||||
# Automatic cleanup
|
||||
```
|
||||
|
||||
### 4. Tool Name Conflicts
|
||||
### 6. Tool Name Conflicts
|
||||
|
||||
If multiple MCP servers provide tools with the same name, the last registered server wins. To avoid conflicts:
|
||||
|
||||
@@ -315,6 +382,24 @@ If HTTP transport fails:
|
||||
2. Check firewall settings
|
||||
3. Verify the URL and port are correct
|
||||
|
||||
### Unix Socket Not Connecting
|
||||
|
||||
If Unix socket transport fails:
|
||||
|
||||
1. Verify the socket file exists: `ls -la /tmp/mcp_server.sock`
|
||||
2. Check file permissions on the socket
|
||||
3. Ensure no other process has locked the socket
|
||||
4. Verify the `url` field is set (e.g., `"http://localhost"`)
|
||||
|
||||
### SSE Connection Issues
|
||||
|
||||
If SSE transport fails:
|
||||
|
||||
1. Verify the server supports SSE at the given URL
|
||||
2. Check that the `mcp` Python package is installed (`pip install mcp`)
|
||||
3. Ensure the SSE endpoint is accessible: `curl http://localhost:8000/sse`
|
||||
4. Check for firewall or proxy issues blocking long-lived connections
|
||||
|
||||
## Example: Full Agent with MCP Tools
|
||||
|
||||
Here's a complete example of an agent that uses MCP tools:
|
||||
|
||||
@@ -0,0 +1,583 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Antigravity authentication CLI.
|
||||
|
||||
Implements OAuth2 flow for Google's Antigravity Code Assist gateway.
|
||||
Credentials are stored in ~/.hive/antigravity-accounts.json.
|
||||
|
||||
Usage:
|
||||
python -m antigravity_auth auth account add
|
||||
python -m antigravity_auth auth account list
|
||||
python -m antigravity_auth auth account remove <email>
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
import webbrowser
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OAuth endpoints
|
||||
_OAUTH_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
_OAUTH_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
# Scopes for Antigravity/Cloud Code Assist
|
||||
_OAUTH_SCOPES = [
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
]
|
||||
|
||||
# Credentials file path in ~/.hive/
|
||||
_ACCOUNTS_FILE = Path.home() / ".hive" / "antigravity-accounts.json"
|
||||
|
||||
# Default project ID
|
||||
_DEFAULT_PROJECT_ID = "rising-fact-p41fc"
|
||||
_DEFAULT_REDIRECT_PORT = 51121
|
||||
|
||||
# OAuth credentials fetched from the opencode-antigravity-auth project.
|
||||
# This project reverse-engineered and published the public OAuth credentials
|
||||
# for Google's Antigravity/Cloud Code Assist API.
|
||||
# Source: https://github.com/NoeFabris/opencode-antigravity-auth
|
||||
_CREDENTIALS_URL = (
|
||||
"https://raw.githubusercontent.com/NoeFabris/opencode-antigravity-auth/dev/src/constants.ts"
|
||||
)
|
||||
|
||||
# Cached credentials fetched from public source
|
||||
_cached_client_id: str | None = None
|
||||
_cached_client_secret: str | None = None
|
||||
|
||||
|
||||
def _fetch_credentials_from_public_source() -> tuple[str | None, str | None]:
|
||||
"""Fetch OAuth client ID and secret from the public npm package source on GitHub."""
|
||||
global _cached_client_id, _cached_client_secret
|
||||
if _cached_client_id and _cached_client_secret:
|
||||
return _cached_client_id, _cached_client_secret
|
||||
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
_CREDENTIALS_URL, headers={"User-Agent": "Hive-Antigravity-Auth/1.0"}
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
content = resp.read().decode("utf-8")
|
||||
import re
|
||||
|
||||
id_match = re.search(r'ANTIGRAVITY_CLIENT_ID\s*=\s*"([^"]+)"', content)
|
||||
secret_match = re.search(r'ANTIGRAVITY_CLIENT_SECRET\s*=\s*"([^"]+)"', content)
|
||||
if id_match:
|
||||
_cached_client_id = id_match.group(1)
|
||||
if secret_match:
|
||||
_cached_client_secret = secret_match.group(1)
|
||||
return _cached_client_id, _cached_client_secret
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to fetch credentials from public source: {e}")
|
||||
return None, None
|
||||
|
||||
|
||||
def get_client_id() -> str:
|
||||
"""Get OAuth client ID from env, config, or public source."""
|
||||
env_id = os.environ.get("ANTIGRAVITY_CLIENT_ID")
|
||||
if env_id:
|
||||
return env_id
|
||||
|
||||
# Try hive config
|
||||
hive_cfg = Path.home() / ".hive" / "configuration.json"
|
||||
if hive_cfg.exists():
|
||||
try:
|
||||
with open(hive_cfg) as f:
|
||||
cfg = json.load(f)
|
||||
cfg_id = cfg.get("llm", {}).get("antigravity_client_id")
|
||||
if cfg_id:
|
||||
return cfg_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fetch from public source
|
||||
client_id, _ = _fetch_credentials_from_public_source()
|
||||
if client_id:
|
||||
return client_id
|
||||
|
||||
raise RuntimeError("Could not obtain Antigravity OAuth client ID")
|
||||
|
||||
|
||||
def get_client_secret() -> str | None:
|
||||
"""Get OAuth client secret from env, config, or public source."""
|
||||
secret = os.environ.get("ANTIGRAVITY_CLIENT_SECRET")
|
||||
if secret:
|
||||
return secret
|
||||
|
||||
# Try to read from hive config
|
||||
hive_cfg = Path.home() / ".hive" / "configuration.json"
|
||||
if hive_cfg.exists():
|
||||
try:
|
||||
with open(hive_cfg) as f:
|
||||
cfg = json.load(f)
|
||||
secret = cfg.get("llm", {}).get("antigravity_client_secret")
|
||||
if secret:
|
||||
return secret
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fetch from public source (npm package on GitHub)
|
||||
_, secret = _fetch_credentials_from_public_source()
|
||||
return secret
|
||||
|
||||
|
||||
def find_free_port() -> int:
|
||||
"""Find an available local port."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
s.listen(1)
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
class OAuthCallbackHandler(BaseHTTPRequestHandler):
|
||||
"""Handle OAuth callback from browser."""
|
||||
|
||||
auth_code: str | None = None
|
||||
state: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
def log_message(self, format: str, *args: Any) -> None:
|
||||
pass # Suppress default logging
|
||||
|
||||
def do_GET(self) -> None:
|
||||
parsed = urllib.parse.urlparse(self.path)
|
||||
|
||||
if parsed.path == "/oauth-callback":
|
||||
query = urllib.parse.parse_qs(parsed.query)
|
||||
|
||||
if "error" in query:
|
||||
self.error = query["error"][0]
|
||||
self._send_response("Authentication failed. You can close this window.")
|
||||
return
|
||||
|
||||
if "code" in query and "state" in query:
|
||||
OAuthCallbackHandler.auth_code = query["code"][0]
|
||||
OAuthCallbackHandler.state = query["state"][0]
|
||||
self._send_response(
|
||||
"Authentication successful! You can close this window "
|
||||
"and return to the terminal."
|
||||
)
|
||||
return
|
||||
|
||||
self._send_response("Waiting for authentication...")
|
||||
|
||||
def _send_response(self, message: str) -> None:
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/html")
|
||||
self.end_headers()
|
||||
html = f"""<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>Antigravity Auth</title></head>
|
||||
<body style="font-family: system-ui; display: flex; align-items: center;
|
||||
justify-content: center; height: 100vh; margin: 0; background: #1a1a2e;
|
||||
color: #eee;">
|
||||
<div style="text-align: center;">
|
||||
<h2>{message}</h2>
|
||||
</div>
|
||||
</body>
|
||||
</html>"""
|
||||
self.wfile.write(html.encode())
|
||||
|
||||
|
||||
def wait_for_callback(port: int, timeout: int = 300) -> tuple[str | None, str | None, str | None]:
|
||||
"""Start local server and wait for OAuth callback."""
|
||||
server = HTTPServer(("localhost", port), OAuthCallbackHandler)
|
||||
server.timeout = 1
|
||||
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
if OAuthCallbackHandler.auth_code:
|
||||
return (
|
||||
OAuthCallbackHandler.auth_code,
|
||||
OAuthCallbackHandler.state,
|
||||
OAuthCallbackHandler.error,
|
||||
)
|
||||
server.handle_request()
|
||||
|
||||
return None, None, "timeout"
|
||||
|
||||
|
||||
def exchange_code_for_tokens(
|
||||
code: str, redirect_uri: str, client_id: str, client_secret: str | None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Exchange authorization code for tokens."""
|
||||
data = {
|
||||
"code": code,
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
}
|
||||
if client_secret:
|
||||
data["client_secret"] = client_secret
|
||||
|
||||
body = urllib.parse.urlencode(data).encode()
|
||||
|
||||
req = urllib.request.Request(
|
||||
_OAUTH_TOKEN_URL,
|
||||
data=body,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
method="POST",
|
||||
)
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||
return json.loads(resp.read())
|
||||
except Exception as e:
|
||||
logger.error(f"Token exchange failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_user_email(access_token: str) -> str | None:
|
||||
"""Get user email from Google API."""
|
||||
req = urllib.request.Request(
|
||||
"https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
data = json.loads(resp.read())
|
||||
return data.get("email")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def load_accounts() -> dict[str, Any]:
|
||||
"""Load existing accounts from file."""
|
||||
if not _ACCOUNTS_FILE.exists():
|
||||
return {"schemaVersion": 4, "accounts": []}
|
||||
try:
|
||||
with open(_ACCOUNTS_FILE) as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
return {"schemaVersion": 4, "accounts": []}
|
||||
|
||||
|
||||
def save_accounts(data: dict[str, Any]) -> None:
|
||||
"""Save accounts to file."""
|
||||
_ACCOUNTS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(_ACCOUNTS_FILE, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
logger.info(f"Saved credentials to {_ACCOUNTS_FILE}")
|
||||
|
||||
|
||||
def validate_credentials(access_token: str, project_id: str = _DEFAULT_PROJECT_ID) -> bool:
|
||||
"""Test if credentials work by making a simple API call to Antigravity.
|
||||
|
||||
Returns True if credentials are valid, False otherwise.
|
||||
"""
|
||||
endpoint = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
body = {
|
||||
"project": project_id,
|
||||
"model": "gemini-3-flash",
|
||||
"request": {
|
||||
"contents": [{"role": "user", "parts": [{"text": "hi"}]}],
|
||||
"generationConfig": {"maxOutputTokens": 10},
|
||||
},
|
||||
"requestType": "agent",
|
||||
"userAgent": "antigravity",
|
||||
"requestId": "validation-test",
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) Antigravity/1.18.3"
|
||||
),
|
||||
"X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
|
||||
}
|
||||
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
f"{endpoint}/v1internal:generateContent",
|
||||
data=json.dumps(body).encode("utf-8"),
|
||||
headers=headers,
|
||||
method="POST",
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||
json.loads(resp.read())
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def refresh_access_token(
|
||||
refresh_token: str, client_id: str, client_secret: str | None
|
||||
) -> dict | None:
|
||||
"""Refresh the access token using the refresh token."""
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": client_id,
|
||||
}
|
||||
if client_secret:
|
||||
data["client_secret"] = client_secret
|
||||
|
||||
body = urllib.parse.urlencode(data).encode()
|
||||
req = urllib.request.Request(
|
||||
_OAUTH_TOKEN_URL,
|
||||
data=body,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
method="POST",
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||
return json.loads(resp.read())
|
||||
except Exception as e:
|
||||
logger.debug(f"Token refresh failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def cmd_account_add(args: argparse.Namespace) -> int:
|
||||
"""Add a new Antigravity account via OAuth2.
|
||||
|
||||
First checks if valid credentials already exist. If so, validates them
|
||||
and skips OAuth if they work. Otherwise, proceeds with OAuth flow.
|
||||
"""
|
||||
client_id = get_client_id()
|
||||
client_secret = get_client_secret()
|
||||
|
||||
# Check if credentials already exist
|
||||
accounts_data = load_accounts()
|
||||
accounts = accounts_data.get("accounts", [])
|
||||
|
||||
if accounts:
|
||||
account = next((a for a in accounts if a.get("enabled", True) is not False), accounts[0])
|
||||
access_token = account.get("access")
|
||||
refresh_token_str = account.get("refresh", "")
|
||||
refresh_token = refresh_token_str.split("|")[0] if refresh_token_str else None
|
||||
project_id = (
|
||||
refresh_token_str.split("|")[1] if "|" in refresh_token_str else _DEFAULT_PROJECT_ID
|
||||
)
|
||||
email = account.get("email", "unknown")
|
||||
expires_ms = account.get("expires", 0)
|
||||
expires_at = expires_ms / 1000.0 if expires_ms else 0.0
|
||||
|
||||
# Check if token is expired or near expiry
|
||||
if access_token and expires_at and time.time() < expires_at - 60:
|
||||
# Token still valid, test it
|
||||
logger.info(f"Found existing credentials for: {email}")
|
||||
logger.info("Validating existing credentials...")
|
||||
if validate_credentials(access_token, project_id):
|
||||
logger.info("✓ Credentials valid! Skipping OAuth.")
|
||||
return 0
|
||||
else:
|
||||
logger.info("Credentials failed validation, refreshing...")
|
||||
elif refresh_token:
|
||||
logger.info(f"Found expired credentials for: {email}")
|
||||
logger.info("Attempting token refresh...")
|
||||
|
||||
tokens = refresh_access_token(refresh_token, client_id, client_secret)
|
||||
if tokens:
|
||||
new_access = tokens.get("access_token")
|
||||
expires_in = tokens.get("expires_in", 3600)
|
||||
if new_access:
|
||||
# Update the account
|
||||
account["access"] = new_access
|
||||
account["expires"] = int((time.time() + expires_in) * 1000)
|
||||
accounts_data["last_refresh"] = time.strftime(
|
||||
"%Y-%m-%dT%H:%M:%SZ", time.gmtime()
|
||||
)
|
||||
save_accounts(accounts_data)
|
||||
|
||||
# Validate the refreshed token
|
||||
logger.info("Validating refreshed credentials...")
|
||||
if validate_credentials(new_access, project_id):
|
||||
logger.info("✓ Credentials refreshed and validated!")
|
||||
return 0
|
||||
else:
|
||||
logger.info("Refreshed token failed validation, proceeding with OAuth...")
|
||||
else:
|
||||
logger.info("Token refresh failed, proceeding with OAuth...")
|
||||
|
||||
# No valid credentials, proceed with OAuth
|
||||
if not client_secret:
|
||||
logger.warning(
|
||||
"No client secret configured. Token refresh may fail.\n"
|
||||
"Set ANTIGRAVITY_CLIENT_SECRET env var or add "
|
||||
"'antigravity_client_secret' to ~/.hive/configuration.json"
|
||||
)
|
||||
|
||||
# Use fixed port and path matching Google's expected OAuth redirect URI
|
||||
port = _DEFAULT_REDIRECT_PORT
|
||||
redirect_uri = f"http://localhost:{port}/oauth-callback"
|
||||
|
||||
# Generate state for CSRF protection
|
||||
state = secrets.token_urlsafe(16)
|
||||
|
||||
# Build authorization URL
|
||||
params = {
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": " ".join(_OAUTH_SCOPES),
|
||||
"state": state,
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
}
|
||||
auth_url = f"{_OAUTH_AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
logger.info("Opening browser for authentication...")
|
||||
logger.info(f"If the browser doesn't open, visit: {auth_url}\n")
|
||||
|
||||
# Open browser
|
||||
webbrowser.open(auth_url)
|
||||
|
||||
# Wait for callback
|
||||
logger.info(f"Listening for callback on port {port}...")
|
||||
code, received_state, error = wait_for_callback(port)
|
||||
|
||||
if error:
|
||||
logger.error(f"Authentication failed: {error}")
|
||||
return 1
|
||||
|
||||
if not code:
|
||||
logger.error("No authorization code received")
|
||||
return 1
|
||||
|
||||
if received_state != state:
|
||||
logger.error("State mismatch - possible CSRF attack")
|
||||
return 1
|
||||
|
||||
# Exchange code for tokens
|
||||
logger.info("Exchanging authorization code for tokens...")
|
||||
tokens = exchange_code_for_tokens(code, redirect_uri, client_id, client_secret)
|
||||
|
||||
if not tokens:
|
||||
return 1
|
||||
|
||||
access_token = tokens.get("access_token")
|
||||
refresh_token = tokens.get("refresh_token")
|
||||
expires_in = tokens.get("expires_in", 3600)
|
||||
|
||||
if not access_token:
|
||||
logger.error("No access token in response")
|
||||
return 1
|
||||
|
||||
# Get user email
|
||||
email = get_user_email(access_token)
|
||||
if email:
|
||||
logger.info(f"Authenticated as: {email}")
|
||||
|
||||
# Load existing accounts and add/update
|
||||
accounts_data = load_accounts()
|
||||
accounts = accounts_data.get("accounts", [])
|
||||
|
||||
# Build new account entry (V4 schema)
|
||||
expires_ms = int((time.time() + expires_in) * 1000)
|
||||
refresh_entry = f"{refresh_token}|{_DEFAULT_PROJECT_ID}"
|
||||
|
||||
new_account = {
|
||||
"access": access_token,
|
||||
"refresh": refresh_entry,
|
||||
"expires": expires_ms,
|
||||
"email": email,
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
# Update existing account or add new one
|
||||
existing_idx = next((i for i, a in enumerate(accounts) if a.get("email") == email), None)
|
||||
if existing_idx is not None:
|
||||
accounts[existing_idx] = new_account
|
||||
logger.info(f"Updated existing account: {email}")
|
||||
else:
|
||||
accounts.append(new_account)
|
||||
logger.info(f"Added new account: {email}")
|
||||
|
||||
accounts_data["accounts"] = accounts
|
||||
accounts_data["schemaVersion"] = 4
|
||||
accounts_data["last_refresh"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||
|
||||
save_accounts(accounts_data)
|
||||
logger.info("\n✓ Authentication complete!")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_account_list(args: argparse.Namespace) -> int:
|
||||
"""List all stored accounts."""
|
||||
data = load_accounts()
|
||||
accounts = data.get("accounts", [])
|
||||
|
||||
if not accounts:
|
||||
logger.info("No accounts configured.")
|
||||
logger.info("Run 'antigravity auth account add' to add one.")
|
||||
return 0
|
||||
|
||||
logger.info("Configured accounts:\n")
|
||||
for i, account in enumerate(accounts, 1):
|
||||
email = account.get("email", "unknown")
|
||||
enabled = "enabled" if account.get("enabled", True) else "disabled"
|
||||
logger.info(f" {i}. {email} ({enabled})")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_account_remove(args: argparse.Namespace) -> int:
|
||||
"""Remove an account by email."""
|
||||
email = args.email
|
||||
data = load_accounts()
|
||||
accounts = data.get("accounts", [])
|
||||
|
||||
original_len = len(accounts)
|
||||
accounts = [a for a in accounts if a.get("email") != email]
|
||||
|
||||
if len(accounts) == original_len:
|
||||
logger.error(f"No account found with email: {email}")
|
||||
return 1
|
||||
|
||||
data["accounts"] = accounts
|
||||
save_accounts(data)
|
||||
logger.info(f"Removed account: {email}")
|
||||
return 0
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Antigravity authentication CLI",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", help="Commands")
|
||||
|
||||
# auth account add
|
||||
auth_parser = subparsers.add_parser("auth", help="Authentication commands")
|
||||
auth_subparsers = auth_parser.add_subparsers(dest="auth_command")
|
||||
|
||||
account_parser = auth_subparsers.add_parser("account", help="Account management")
|
||||
account_subparsers = account_parser.add_subparsers(dest="account_command")
|
||||
|
||||
add_parser = account_subparsers.add_parser("add", help="Add a new account via OAuth2")
|
||||
add_parser.set_defaults(func=cmd_account_add)
|
||||
|
||||
list_parser = account_subparsers.add_parser("list", help="List configured accounts")
|
||||
list_parser.set_defaults(func=cmd_account_list)
|
||||
|
||||
remove_parser = account_subparsers.add_parser("remove", help="Remove an account")
|
||||
remove_parser.add_argument("email", help="Email of account to remove")
|
||||
remove_parser.set_defaults(func=cmd_account_remove)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if hasattr(args, "func"):
|
||||
return args.func(args)
|
||||
|
||||
parser.print_help()
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
+81
-27
@@ -17,6 +17,7 @@ import http.server
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import queue
|
||||
import secrets
|
||||
import subprocess
|
||||
import sys
|
||||
@@ -27,6 +28,7 @@ import urllib.parse
|
||||
import urllib.request
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import TextIO
|
||||
|
||||
# OAuth constants (from the Codex CLI binary)
|
||||
CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
@@ -165,11 +167,11 @@ def open_browser(url: str) -> bool:
|
||||
if system == "Darwin":
|
||||
subprocess.Popen(["open", url], stdout=devnull, stderr=devnull)
|
||||
elif system == "Windows":
|
||||
subprocess.Popen(["cmd", "/c", "start", url], stdout=devnull, stderr=devnull)
|
||||
os.startfile(url) # type: ignore[attr-defined]
|
||||
else:
|
||||
subprocess.Popen(["xdg-open", url], stdout=devnull, stderr=devnull)
|
||||
return True
|
||||
except OSError:
|
||||
except (AttributeError, OSError):
|
||||
return False
|
||||
|
||||
|
||||
@@ -266,6 +268,71 @@ def parse_manual_input(value: str, expected_state: str) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _read_manual_input_lines(
|
||||
manual_inputs: queue.Queue[str],
|
||||
stop_event: threading.Event,
|
||||
stdin: TextIO | None = None,
|
||||
) -> None:
|
||||
stream = sys.stdin if stdin is None else stdin
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
manual = stream.readline()
|
||||
except (EOFError, OSError):
|
||||
return
|
||||
|
||||
if not manual:
|
||||
return
|
||||
|
||||
if manual.strip():
|
||||
manual_inputs.put(manual)
|
||||
|
||||
|
||||
def wait_for_code_from_callback_or_stdin(
|
||||
expected_state: str,
|
||||
callback_result: list[str | None],
|
||||
callback_done: threading.Event,
|
||||
timeout_secs: float = 120,
|
||||
poll_interval: float = 0.1,
|
||||
stdin: TextIO | None = None,
|
||||
) -> str | None:
|
||||
manual_inputs: queue.Queue[str] = queue.Queue()
|
||||
stop_event = threading.Event()
|
||||
|
||||
# Read stdin on a daemon thread so manual paste works on platforms where
|
||||
# select() cannot poll console handles, including Windows terminals.
|
||||
threading.Thread(
|
||||
target=_read_manual_input_lines,
|
||||
args=(manual_inputs, stop_event, stdin),
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
deadline = time.time() + timeout_secs
|
||||
try:
|
||||
while time.time() < deadline:
|
||||
if callback_result[0]:
|
||||
return callback_result[0]
|
||||
|
||||
while True:
|
||||
try:
|
||||
manual = manual_inputs.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
code = parse_manual_input(manual, expected_state)
|
||||
if code:
|
||||
return code
|
||||
|
||||
if callback_done.is_set():
|
||||
return callback_result[0]
|
||||
|
||||
time.sleep(poll_interval)
|
||||
|
||||
return callback_result[0]
|
||||
finally:
|
||||
stop_event.set()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
# Generate PKCE and state
|
||||
verifier, challenge = generate_pkce()
|
||||
@@ -315,41 +382,28 @@ def main() -> int:
|
||||
|
||||
# Start callback server in background
|
||||
callback_result: list[str | None] = [None]
|
||||
callback_done = threading.Event()
|
||||
|
||||
def run_server() -> None:
|
||||
callback_result[0] = wait_for_callback(state, timeout_secs=120)
|
||||
try:
|
||||
callback_result[0] = wait_for_callback(state, timeout_secs=120)
|
||||
finally:
|
||||
callback_done.set()
|
||||
|
||||
server_thread = threading.Thread(target=run_server)
|
||||
server_thread.daemon = True
|
||||
server_thread.start()
|
||||
|
||||
# Also accept manual input in parallel
|
||||
# We poll for both the server result and stdin
|
||||
try:
|
||||
import select
|
||||
|
||||
while server_thread.is_alive():
|
||||
# Check if stdin has data (non-blocking on unix)
|
||||
if hasattr(select, "select"):
|
||||
ready, _, _ = select.select([sys.stdin], [], [], 0.5)
|
||||
if ready:
|
||||
manual = sys.stdin.readline()
|
||||
if manual.strip():
|
||||
code = parse_manual_input(manual, state)
|
||||
if code:
|
||||
break
|
||||
else:
|
||||
time.sleep(0.5)
|
||||
|
||||
if callback_result[0]:
|
||||
code = callback_result[0]
|
||||
break
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
code = wait_for_code_from_callback_or_stdin(
|
||||
state,
|
||||
callback_result,
|
||||
callback_done,
|
||||
timeout_secs=120,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
print("\n\033[0;31mCancelled.\033[0m")
|
||||
return 1
|
||||
|
||||
if not code:
|
||||
code = callback_result[0]
|
||||
else:
|
||||
# Manual paste mode
|
||||
try:
|
||||
|
||||
@@ -79,7 +79,7 @@ async def example_3_config_file():
|
||||
# Copy example config (in practice, you'd place this in your agent folder)
|
||||
import shutil
|
||||
|
||||
shutil.copy("examples/mcp_servers.json", test_agent_path / "mcp_servers.json")
|
||||
shutil.copy(Path(__file__).parent / "mcp_servers.json", test_agent_path / "mcp_servers.json")
|
||||
|
||||
# Load agent - MCP servers will be auto-discovered
|
||||
runner = AgentRunner.load(test_agent_path)
|
||||
|
||||
@@ -16,6 +16,7 @@ after the user picks an account programmatically.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -25,6 +26,7 @@ from framework.graph.checkpoint_config import CheckpointConfig
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.executor import ExecutionResult
|
||||
from framework.llm import LiteLLMProvider
|
||||
from framework.runner.mcp_registry import MCPRegistry
|
||||
from framework.runner.tool_registry import ToolRegistry
|
||||
from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime
|
||||
from framework.runtime.execution_stream import EntryPointSpec
|
||||
@@ -32,9 +34,13 @@ from framework.runtime.execution_stream import EntryPointSpec
|
||||
from .config import default_config
|
||||
from .nodes import build_tester_node
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.runner import AgentRunner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Goal
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -107,7 +113,11 @@ def _list_aden_accounts() -> list[dict]:
|
||||
for c in integrations
|
||||
if c.status == "active"
|
||||
]
|
||||
except (ImportError, OSError) as exc:
|
||||
logger.debug("Could not list Aden accounts: %s", exc)
|
||||
return []
|
||||
except Exception:
|
||||
logger.warning("Unexpected error listing Aden accounts", exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
@@ -119,7 +129,11 @@ def _list_local_accounts() -> list[dict]:
|
||||
return [
|
||||
info.to_account_dict() for info in LocalCredentialRegistry.default().list_accounts()
|
||||
]
|
||||
except ImportError as exc:
|
||||
logger.debug("Local credential registry unavailable: %s", exc)
|
||||
return []
|
||||
except Exception:
|
||||
logger.warning("Unexpected error listing local accounts", exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
@@ -140,7 +154,11 @@ def _list_env_fallback_accounts() -> list[dict]:
|
||||
from framework.credentials.storage import EncryptedFileStorage
|
||||
|
||||
encrypted_ids: set[str] = set(EncryptedFileStorage().list_all())
|
||||
except (ImportError, OSError) as exc:
|
||||
logger.debug("Could not read encrypted store: %s", exc)
|
||||
encrypted_ids = set()
|
||||
except Exception:
|
||||
logger.warning("Unexpected error reading encrypted store", exc_info=True)
|
||||
encrypted_ids = set()
|
||||
|
||||
def _is_configured(cred_name: str, spec) -> bool:
|
||||
@@ -300,8 +318,10 @@ def _activate_local_account(credential_id: str, alias: str) -> None:
|
||||
|
||||
if key:
|
||||
os.environ[spec.env_var] = key
|
||||
except (ImportError, KeyError, OSError) as exc:
|
||||
logger.debug("Could not inject credentials: %s", exc)
|
||||
except Exception:
|
||||
pass
|
||||
logger.warning("Unexpected error injecting credentials", exc_info=True)
|
||||
|
||||
|
||||
def _configure_aden_node(
|
||||
@@ -563,6 +583,23 @@ class CredentialTesterAgent:
|
||||
if mcp_config_path.exists():
|
||||
self._tool_registry.load_mcp_config(mcp_config_path)
|
||||
|
||||
try:
|
||||
agent_dir = Path(__file__).parent
|
||||
registry = MCPRegistry()
|
||||
registry.initialize()
|
||||
if (agent_dir / "mcp_registry.json").is_file():
|
||||
self._tool_registry.set_mcp_registry_agent_path(agent_dir)
|
||||
registry_configs, selection_max_tools = registry.load_agent_selection(agent_dir)
|
||||
if registry_configs:
|
||||
self._tool_registry.load_registry_servers(
|
||||
registry_configs,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=True,
|
||||
max_tools=selection_max_tools,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("MCP registry config failed to load", exc_info=True)
|
||||
|
||||
extra_kwargs = getattr(self.config, "extra_kwargs", {}) or {}
|
||||
llm = LiteLLMProvider(
|
||||
model=self.config.model,
|
||||
|
||||
@@ -702,6 +702,15 @@ stop_worker() to return to STAGING phase.
|
||||
_queen_behavior_always = """
|
||||
# Behavior
|
||||
|
||||
## Images attached by the user
|
||||
|
||||
Users can attach images directly to their chat messages. When you see an \
|
||||
image in the conversation, analyze it using your native vision capability — \
|
||||
do NOT say you cannot see images or that you lack access to files. The image \
|
||||
is embedded in the message; no tool call is needed to view it. Describe what \
|
||||
you see, answer questions about it, and use the visual content to inform your \
|
||||
response just as you would text.
|
||||
|
||||
## CRITICAL RULE — ask_user / ask_user_multiple
|
||||
|
||||
Every response that ends with a question, a prompt, or expects user \
|
||||
|
||||
@@ -31,6 +31,11 @@ def _queen_dir() -> Path:
|
||||
return Path.home() / ".hive" / "queen"
|
||||
|
||||
|
||||
def format_memory_date(d: date) -> str:
|
||||
"""Return a cross-platform long date label without a zero-padded day."""
|
||||
return f"{d.strftime('%B')} {d.day}, {d.year}"
|
||||
|
||||
|
||||
def semantic_memory_path() -> Path:
|
||||
return _queen_dir() / "MEMORY.md"
|
||||
|
||||
@@ -91,9 +96,9 @@ def format_for_injection() -> str:
|
||||
content = content[:_EPISODIC_CHAR_BUDGET] + "\n\n…(truncated)"
|
||||
today = date.today()
|
||||
if d == today:
|
||||
label = f"## Today — {d.strftime('%B %-d, %Y')}"
|
||||
label = f"## Today — {format_memory_date(d)}"
|
||||
else:
|
||||
label = f"## {d.strftime('%B %-d, %Y')}"
|
||||
label = f"## {format_memory_date(d)}"
|
||||
parts.append(f"{label}\n\n{content}")
|
||||
|
||||
if not parts:
|
||||
@@ -127,7 +132,7 @@ def append_episodic_entry(content: str) -> None:
|
||||
ep_path = episodic_memory_path()
|
||||
ep_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
today = date.today()
|
||||
today_str = f"{today.strftime('%B')} {today.day}, {today.year}"
|
||||
today_str = format_memory_date(today)
|
||||
timestamp = datetime.now().strftime("%H:%M")
|
||||
if not ep_path.exists():
|
||||
header = f"# {today_str}\n\n"
|
||||
@@ -226,7 +231,11 @@ def read_session_context(session_dir: Path, max_messages: int = 80) -> str:
|
||||
elif content:
|
||||
label = "user" if role == "user" else "queen"
|
||||
lines.append(f"[{label}]: {content[:600]}")
|
||||
except (KeyError, TypeError) as exc:
|
||||
logger.debug("Skipping malformed conversation message: %s", exc)
|
||||
continue
|
||||
except Exception:
|
||||
logger.warning("Unexpected error parsing conversation message", exc_info=True)
|
||||
continue
|
||||
if lines:
|
||||
parts.append("## Conversation\n\n" + "\n".join(lines))
|
||||
@@ -327,7 +336,7 @@ async def consolidate_queen_memory(
|
||||
existing_semantic = read_semantic_memory()
|
||||
today_journal = read_episodic_memory()
|
||||
today = date.today()
|
||||
today_str = f"{today.strftime('%B')} {today.day}, {today.year}"
|
||||
today_str = format_memory_date(today)
|
||||
adapt_path = session_dir / "data" / "adapt.md"
|
||||
|
||||
user_msg = (
|
||||
@@ -395,5 +404,5 @@ async def consolidate_queen_memory(
|
||||
f"session: {session_id}\ntime: {datetime.now().isoformat()}\n\n{tb}",
|
||||
encoding="utf-8",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
except OSError:
|
||||
pass # Cannot write error file; original exception already logged
|
||||
|
||||
@@ -150,7 +150,7 @@ Call all three subagents in a single response to run them in parallel:
|
||||
|
||||
## GCU Anti-Patterns
|
||||
|
||||
- Using `browser_screenshot` to read text (use `browser_snapshot`)
|
||||
- Using `browser_screenshot` to read text (use `browser_snapshot` instead; screenshots are for visual context only)
|
||||
- Re-navigating after scrolling (resets scroll position)
|
||||
- Attempting login on auth walls
|
||||
- Forgetting `target_id` in multi-tab scenarios
|
||||
|
||||
@@ -89,6 +89,21 @@ def main():
|
||||
|
||||
register_testing_commands(subparsers)
|
||||
|
||||
# Register skill commands (skill list, skill trust, ...)
|
||||
from framework.skills.cli import register_skill_commands
|
||||
|
||||
register_skill_commands(subparsers)
|
||||
|
||||
# Register debugger commands (debugger)
|
||||
from framework.debugger.cli import register_debugger_commands
|
||||
|
||||
register_debugger_commands(subparsers)
|
||||
|
||||
# Register MCP registry commands (mcp install, mcp add, ...)
|
||||
from framework.runner.mcp_registry_cli import register_mcp_commands
|
||||
|
||||
register_mcp_commands(subparsers)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if hasattr(args, "func"):
|
||||
|
||||
+258
-2
@@ -51,16 +51,169 @@ def get_preferred_model() -> str:
|
||||
"""Return the user's preferred LLM model string (e.g. 'anthropic/claude-sonnet-4-20250514')."""
|
||||
llm = get_hive_config().get("llm", {})
|
||||
if llm.get("provider") and llm.get("model"):
|
||||
return f"{llm['provider']}/{llm['model']}"
|
||||
provider = str(llm["provider"])
|
||||
model = str(llm["model"]).strip()
|
||||
# OpenRouter quickstart stores raw model IDs; tolerate pasted "openrouter/<id>" too.
|
||||
if provider.lower() == "openrouter" and model.lower().startswith("openrouter/"):
|
||||
model = model[len("openrouter/") :]
|
||||
if model:
|
||||
return f"{provider}/{model}"
|
||||
return "anthropic/claude-sonnet-4-20250514"
|
||||
|
||||
|
||||
def get_preferred_worker_model() -> str | None:
|
||||
"""Return the user's preferred worker LLM model, or None if not configured.
|
||||
|
||||
Reads from the ``worker_llm`` section of ~/.hive/configuration.json.
|
||||
Returns None when no worker-specific model is set, so callers can
|
||||
fall back to the default (queen) model via ``get_preferred_model()``.
|
||||
"""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if worker_llm.get("provider") and worker_llm.get("model"):
|
||||
provider = str(worker_llm["provider"])
|
||||
model = str(worker_llm["model"]).strip()
|
||||
if provider.lower() == "openrouter" and model.lower().startswith("openrouter/"):
|
||||
model = model[len("openrouter/") :]
|
||||
if model:
|
||||
return f"{provider}/{model}"
|
||||
return None
|
||||
|
||||
|
||||
def get_worker_api_key() -> str | None:
|
||||
"""Return the API key for the worker LLM, falling back to the default key."""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if not worker_llm:
|
||||
return get_api_key()
|
||||
|
||||
# Worker-specific subscription / env var
|
||||
if worker_llm.get("use_claude_code_subscription"):
|
||||
try:
|
||||
from framework.runner.runner import get_claude_code_token
|
||||
|
||||
token = get_claude_code_token()
|
||||
if token:
|
||||
return token
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if worker_llm.get("use_codex_subscription"):
|
||||
try:
|
||||
from framework.runner.runner import get_codex_token
|
||||
|
||||
token = get_codex_token()
|
||||
if token:
|
||||
return token
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if worker_llm.get("use_kimi_code_subscription"):
|
||||
try:
|
||||
from framework.runner.runner import get_kimi_code_token
|
||||
|
||||
token = get_kimi_code_token()
|
||||
if token:
|
||||
return token
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if worker_llm.get("use_antigravity_subscription"):
|
||||
try:
|
||||
from framework.runner.runner import get_antigravity_token
|
||||
|
||||
token = get_antigravity_token()
|
||||
if token:
|
||||
return token
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
api_key_env_var = worker_llm.get("api_key_env_var")
|
||||
if api_key_env_var:
|
||||
return os.environ.get(api_key_env_var)
|
||||
|
||||
# Fall back to default key
|
||||
return get_api_key()
|
||||
|
||||
|
||||
def get_worker_api_base() -> str | None:
|
||||
"""Return the api_base for the worker LLM, falling back to the default."""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if not worker_llm:
|
||||
return get_api_base()
|
||||
|
||||
if worker_llm.get("use_codex_subscription"):
|
||||
return "https://chatgpt.com/backend-api/codex"
|
||||
if worker_llm.get("use_kimi_code_subscription"):
|
||||
return "https://api.kimi.com/coding"
|
||||
if worker_llm.get("use_antigravity_subscription"):
|
||||
# Antigravity uses AntigravityProvider directly — no api_base needed.
|
||||
return None
|
||||
if worker_llm.get("api_base"):
|
||||
return worker_llm["api_base"]
|
||||
if str(worker_llm.get("provider", "")).lower() == "openrouter":
|
||||
return OPENROUTER_API_BASE
|
||||
return None
|
||||
|
||||
|
||||
def get_worker_llm_extra_kwargs() -> dict[str, Any]:
|
||||
"""Return extra kwargs for the worker LLM provider."""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if not worker_llm:
|
||||
return get_llm_extra_kwargs()
|
||||
|
||||
if worker_llm.get("use_claude_code_subscription"):
|
||||
api_key = get_worker_api_key()
|
||||
if api_key:
|
||||
return {
|
||||
"extra_headers": {"authorization": f"Bearer {api_key}"},
|
||||
}
|
||||
if worker_llm.get("use_codex_subscription"):
|
||||
api_key = get_worker_api_key()
|
||||
if api_key:
|
||||
headers: dict[str, str] = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"User-Agent": "CodexBar",
|
||||
}
|
||||
try:
|
||||
from framework.runner.runner import get_codex_account_id
|
||||
|
||||
account_id = get_codex_account_id()
|
||||
if account_id:
|
||||
headers["ChatGPT-Account-Id"] = account_id
|
||||
except ImportError:
|
||||
pass
|
||||
return {
|
||||
"extra_headers": headers,
|
||||
"store": False,
|
||||
"allowed_openai_params": ["store"],
|
||||
}
|
||||
if worker_llm.get("provider") == "ollama":
|
||||
return {"num_ctx": worker_llm.get("num_ctx", 16384)}
|
||||
return {}
|
||||
|
||||
|
||||
def get_worker_max_tokens() -> int:
|
||||
"""Return max_tokens for the worker LLM, falling back to default."""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if worker_llm and "max_tokens" in worker_llm:
|
||||
return worker_llm["max_tokens"]
|
||||
return get_max_tokens()
|
||||
|
||||
|
||||
def get_worker_max_context_tokens() -> int:
|
||||
"""Return max_context_tokens for the worker LLM, falling back to default."""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if worker_llm and "max_context_tokens" in worker_llm:
|
||||
return worker_llm["max_context_tokens"]
|
||||
return get_max_context_tokens()
|
||||
|
||||
|
||||
def get_max_tokens() -> int:
|
||||
"""Return the configured max_tokens, falling back to DEFAULT_MAX_TOKENS."""
|
||||
return get_hive_config().get("llm", {}).get("max_tokens", DEFAULT_MAX_TOKENS)
|
||||
|
||||
|
||||
DEFAULT_MAX_CONTEXT_TOKENS = 32_000
|
||||
OPENROUTER_API_BASE = "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
def get_max_context_tokens() -> int:
|
||||
@@ -113,6 +266,17 @@ def get_api_key() -> str | None:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Antigravity subscription: read OAuth token from accounts JSON
|
||||
if llm.get("use_antigravity_subscription"):
|
||||
try:
|
||||
from framework.runner.runner import get_antigravity_token
|
||||
|
||||
token = get_antigravity_token()
|
||||
if token:
|
||||
return token
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Standard env-var path (covers ZAI Code and all API-key providers)
|
||||
api_key_env_var = llm.get("api_key_env_var")
|
||||
if api_key_env_var:
|
||||
@@ -120,6 +284,86 @@ def get_api_key() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
# OAuth credentials for Antigravity are fetched from the opencode-antigravity-auth project.
|
||||
# This project reverse-engineered and published the public OAuth credentials
|
||||
# for Google's Antigravity/Cloud Code Assist API.
|
||||
# Source: https://github.com/NoeFabris/opencode-antigravity-auth
|
||||
_ANTIGRAVITY_CREDENTIALS_URL = (
|
||||
"https://raw.githubusercontent.com/NoeFabris/opencode-antigravity-auth/dev/src/constants.ts"
|
||||
)
|
||||
_antigravity_credentials_cache: tuple[str | None, str | None] = (None, None)
|
||||
|
||||
|
||||
def _fetch_antigravity_credentials() -> tuple[str | None, str | None]:
|
||||
"""Fetch OAuth client ID and secret from the public npm package source on GitHub."""
|
||||
global _antigravity_credentials_cache
|
||||
if _antigravity_credentials_cache[0] and _antigravity_credentials_cache[1]:
|
||||
return _antigravity_credentials_cache
|
||||
|
||||
import re
|
||||
import urllib.request
|
||||
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
_ANTIGRAVITY_CREDENTIALS_URL, headers={"User-Agent": "Hive/1.0"}
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
content = resp.read().decode("utf-8")
|
||||
id_match = re.search(r'ANTIGRAVITY_CLIENT_ID\s*=\s*"([^"]+)"', content)
|
||||
secret_match = re.search(r'ANTIGRAVITY_CLIENT_SECRET\s*=\s*"([^"]+)"', content)
|
||||
client_id = id_match.group(1) if id_match else None
|
||||
client_secret = secret_match.group(1) if secret_match else None
|
||||
if client_id and client_secret:
|
||||
_antigravity_credentials_cache = (client_id, client_secret)
|
||||
return client_id, client_secret
|
||||
except Exception as e:
|
||||
logger.debug("Failed to fetch Antigravity credentials from public source: %s", e)
|
||||
return None, None
|
||||
|
||||
|
||||
def get_antigravity_client_id() -> str:
|
||||
"""Return the Antigravity OAuth application client ID.
|
||||
|
||||
Checked in order:
|
||||
1. ``ANTIGRAVITY_CLIENT_ID`` environment variable
|
||||
2. ``llm.antigravity_client_id`` in ~/.hive/configuration.json
|
||||
3. Fetch from public source (opencode-antigravity-auth project on GitHub)
|
||||
"""
|
||||
env = os.environ.get("ANTIGRAVITY_CLIENT_ID")
|
||||
if env:
|
||||
return env
|
||||
cfg_val = get_hive_config().get("llm", {}).get("antigravity_client_id")
|
||||
if cfg_val:
|
||||
return cfg_val
|
||||
# Fetch from public source
|
||||
client_id, _ = _fetch_antigravity_credentials()
|
||||
if client_id:
|
||||
return client_id
|
||||
raise RuntimeError("Could not obtain Antigravity OAuth client ID")
|
||||
|
||||
|
||||
def get_antigravity_client_secret() -> str | None:
|
||||
"""Return the Antigravity OAuth client secret.
|
||||
|
||||
Checked in order:
|
||||
1. ``ANTIGRAVITY_CLIENT_SECRET`` environment variable
|
||||
2. ``llm.antigravity_client_secret`` in ~/.hive/configuration.json
|
||||
3. Fetch from public source (opencode-antigravity-auth project on GitHub)
|
||||
|
||||
Returns None when not found — token refresh will be skipped and
|
||||
the caller must use whatever access token is already available.
|
||||
"""
|
||||
env = os.environ.get("ANTIGRAVITY_CLIENT_SECRET")
|
||||
if env:
|
||||
return env
|
||||
cfg_val = get_hive_config().get("llm", {}).get("antigravity_client_secret") or None
|
||||
if cfg_val:
|
||||
return cfg_val
|
||||
# Fetch from public source
|
||||
_, secret = _fetch_antigravity_credentials()
|
||||
return secret
|
||||
|
||||
|
||||
def get_gcu_enabled() -> bool:
|
||||
"""Return whether GCU (browser automation) is enabled in user config."""
|
||||
return get_hive_config().get("gcu_enabled", True)
|
||||
@@ -142,7 +386,14 @@ def get_api_base() -> str | None:
|
||||
if llm.get("use_kimi_code_subscription"):
|
||||
# Kimi Code uses an Anthropic-compatible endpoint (no /v1 suffix).
|
||||
return "https://api.kimi.com/coding"
|
||||
return llm.get("api_base")
|
||||
if llm.get("use_antigravity_subscription"):
|
||||
# Antigravity uses AntigravityProvider directly — no api_base needed.
|
||||
return None
|
||||
if llm.get("api_base"):
|
||||
return llm["api_base"]
|
||||
if str(llm.get("provider", "")).lower() == "openrouter":
|
||||
return OPENROUTER_API_BASE
|
||||
return None
|
||||
|
||||
|
||||
def get_llm_extra_kwargs() -> dict[str, Any]:
|
||||
@@ -183,6 +434,11 @@ def get_llm_extra_kwargs() -> dict[str, Any]:
|
||||
"store": False,
|
||||
"allowed_openai_params": ["store"],
|
||||
}
|
||||
if llm.get("provider") == "ollama":
|
||||
# Pass num_ctx to Ollama so it doesn't silently truncate the ~9.5k Queen prompt.
|
||||
# Ollama's default num_ctx is only 2048. We set it to 16384 here so LiteLLM
|
||||
# passes it through as a provider-specific option.
|
||||
return {"num_ctx": llm.get("num_ctx", 16384)}
|
||||
return {}
|
||||
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from __future__ import annotations
|
||||
|
||||
import getpass
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
@@ -37,6 +38,8 @@ from typing import TYPE_CHECKING, Any
|
||||
if TYPE_CHECKING:
|
||||
from framework.graph import NodeSpec
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ANSI colors for terminal output
|
||||
class Colors:
|
||||
@@ -365,8 +368,11 @@ class CredentialSetupSession:
|
||||
self._print("")
|
||||
try:
|
||||
api_key = self.password_fn(f"Paste your {cred.env_var}: ").strip()
|
||||
except (EOFError, OSError) as exc:
|
||||
logger.debug("Password input unavailable, falling back to plain input: %s", exc)
|
||||
api_key = self._input(f"Paste your {cred.env_var}: ").strip()
|
||||
except Exception:
|
||||
# Fallback to regular input if password input fails
|
||||
logger.warning("Unexpected error reading password input", exc_info=True)
|
||||
api_key = self._input(f"Paste your {cred.env_var}: ").strip()
|
||||
|
||||
if not api_key:
|
||||
@@ -403,7 +409,11 @@ class CredentialSetupSession:
|
||||
|
||||
try:
|
||||
aden_key = self.password_fn("Paste your ADEN_API_KEY: ").strip()
|
||||
except (EOFError, OSError) as exc:
|
||||
logger.debug("Password input unavailable for ADEN_API_KEY: %s", exc)
|
||||
aden_key = self._input("Paste your ADEN_API_KEY: ").strip()
|
||||
except Exception:
|
||||
logger.warning("Unexpected error reading ADEN_API_KEY input", exc_info=True)
|
||||
aden_key = self._input("Paste your ADEN_API_KEY: ").strip()
|
||||
|
||||
if not aden_key:
|
||||
@@ -433,8 +443,10 @@ class CredentialSetupSession:
|
||||
value = store.get_key(cred_id, cred.credential_key)
|
||||
if value:
|
||||
os.environ[cred.env_var] = value
|
||||
except (KeyError, OSError) as exc:
|
||||
logger.debug("Could not export credential to env: %s", exc)
|
||||
except Exception:
|
||||
pass
|
||||
logger.warning("Unexpected error exporting credential to env", exc_info=True)
|
||||
return True
|
||||
else:
|
||||
self._print(
|
||||
@@ -457,9 +469,12 @@ class CredentialSetupSession:
|
||||
"message": result.message,
|
||||
"details": result.details,
|
||||
}
|
||||
except Exception:
|
||||
except ImportError:
|
||||
# No health checker available
|
||||
return None
|
||||
except Exception:
|
||||
logger.warning("Health check failed for %s", cred.credential_name, exc_info=True)
|
||||
return None
|
||||
|
||||
def _store_credential(self, cred: MissingCredential, value: str) -> None:
|
||||
"""Store credential in encrypted store and export to env."""
|
||||
@@ -561,7 +576,11 @@ def _load_nodes_from_python_agent(agent_path: Path) -> list:
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return getattr(module, "nodes", [])
|
||||
except (ImportError, OSError) as exc:
|
||||
logger.debug("Could not load agent module: %s", exc)
|
||||
return []
|
||||
except Exception:
|
||||
logger.warning("Unexpected error loading agent module", exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
@@ -588,7 +607,11 @@ def _load_nodes_from_json_agent(agent_json: Path) -> list:
|
||||
)
|
||||
)
|
||||
return nodes
|
||||
except (json.JSONDecodeError, KeyError, OSError) as exc:
|
||||
logger.debug("Could not load JSON agent: %s", exc)
|
||||
return []
|
||||
except Exception:
|
||||
logger.warning("Unexpected error loading JSON agent", exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
|
||||
@@ -51,6 +51,16 @@ def ensure_credential_key_env() -> None:
|
||||
if found and value:
|
||||
os.environ[var_name] = value
|
||||
logger.debug("Loaded %s from shell config", var_name)
|
||||
# Also load the currently configured LLM env var even if it's not in CREDENTIAL_SPECS.
|
||||
# This keeps quickstart-written keys available to fresh processes on Unix shells.
|
||||
from framework.config import get_hive_config
|
||||
|
||||
llm_env_var = str(get_hive_config().get("llm", {}).get("api_key_env_var", "")).strip()
|
||||
if llm_env_var and not os.environ.get(llm_env_var):
|
||||
found, value = check_env_var_in_shell_config(llm_env_var)
|
||||
if found and value:
|
||||
os.environ[llm_env_var] = value
|
||||
logger.debug("Loaded configured LLM env var %s from shell config", llm_env_var)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
"""CLI command for the LLM debug log viewer."""
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
_SCRIPT = Path(__file__).resolve().parents[3] / "scripts" / "llm_debug_log_visualizer.py"
|
||||
|
||||
|
||||
def register_debugger_commands(subparsers: argparse._SubParsersAction) -> None:
|
||||
"""Register the ``hive debugger`` command."""
|
||||
parser = subparsers.add_parser(
|
||||
"debugger",
|
||||
help="Open the LLM debug log viewer",
|
||||
description=(
|
||||
"Start a local server that lets you browse LLM debug sessions "
|
||||
"recorded in ~/.hive/llm_logs. Sessions are loaded on demand so "
|
||||
"the browser stays responsive."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--session",
|
||||
help="Execution ID to select initially.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Port for the local server (0 = auto-pick a free port).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
help="Directory containing JSONL log files (default: ~/.hive/llm_logs).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--limit-files",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum number of newest log files to scan (default: 200).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
help="Write a static HTML file instead of starting a server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-open",
|
||||
action="store_true",
|
||||
help="Start the server but do not open a browser.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-tests",
|
||||
action="store_true",
|
||||
help="Show test/mock sessions (hidden by default).",
|
||||
)
|
||||
parser.set_defaults(func=cmd_debugger)
|
||||
|
||||
|
||||
def cmd_debugger(args: argparse.Namespace) -> int:
|
||||
"""Launch the LLM debug log visualizer."""
|
||||
cmd: list[str] = [sys.executable, str(_SCRIPT)]
|
||||
if args.session:
|
||||
cmd += ["--session", args.session]
|
||||
if args.port:
|
||||
cmd += ["--port", str(args.port)]
|
||||
if args.logs_dir:
|
||||
cmd += ["--logs-dir", args.logs_dir]
|
||||
if args.limit_files is not None:
|
||||
cmd += ["--limit-files", str(args.limit_files)]
|
||||
if args.output:
|
||||
cmd += ["--output", args.output]
|
||||
if args.no_open:
|
||||
cmd.append("--no-open")
|
||||
if args.include_tests:
|
||||
cmd.append("--include-tests")
|
||||
return subprocess.call(cmd)
|
||||
@@ -33,10 +33,20 @@ class Message:
|
||||
is_transition_marker: bool = False
|
||||
# True when this message is real human input (from /chat), not a system prompt
|
||||
is_client_input: bool = False
|
||||
# Optional image content blocks (e.g. from browser_screenshot)
|
||||
image_content: list[dict[str, Any]] | None = None
|
||||
# True when message contains an activated skill body (AS-10: never prune)
|
||||
is_skill_content: bool = False
|
||||
|
||||
def to_llm_dict(self) -> dict[str, Any]:
|
||||
"""Convert to OpenAI-format message dict."""
|
||||
if self.role == "user":
|
||||
if self.image_content:
|
||||
blocks: list[dict[str, Any]] = []
|
||||
if self.content:
|
||||
blocks.append({"type": "text", "text": self.content})
|
||||
blocks.extend(self.image_content)
|
||||
return {"role": "user", "content": blocks}
|
||||
return {"role": "user", "content": self.content}
|
||||
|
||||
if self.role == "assistant":
|
||||
@@ -47,6 +57,15 @@ class Message:
|
||||
|
||||
# role == "tool"
|
||||
content = f"ERROR: {self.content}" if self.is_error else self.content
|
||||
if self.image_content:
|
||||
# Multimodal tool result: text + image content blocks
|
||||
blocks: list[dict[str, Any]] = [{"type": "text", "text": content}]
|
||||
blocks.extend(self.image_content)
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": self.tool_use_id,
|
||||
"content": blocks,
|
||||
}
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": self.tool_use_id,
|
||||
@@ -72,6 +91,8 @@ class Message:
|
||||
d["is_transition_marker"] = self.is_transition_marker
|
||||
if self.is_client_input:
|
||||
d["is_client_input"] = self.is_client_input
|
||||
if self.image_content is not None:
|
||||
d["image_content"] = self.image_content
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
@@ -87,6 +108,7 @@ class Message:
|
||||
phase_id=data.get("phase_id"),
|
||||
is_transition_marker=data.get("is_transition_marker", False),
|
||||
is_client_input=data.get("is_client_input", False),
|
||||
image_content=data.get("image_content"),
|
||||
)
|
||||
|
||||
|
||||
@@ -373,6 +395,7 @@ class NodeConversation:
|
||||
*,
|
||||
is_transition_marker: bool = False,
|
||||
is_client_input: bool = False,
|
||||
image_content: list[dict[str, Any]] | None = None,
|
||||
) -> Message:
|
||||
msg = Message(
|
||||
seq=self._next_seq,
|
||||
@@ -381,6 +404,7 @@ class NodeConversation:
|
||||
phase_id=self._current_phase,
|
||||
is_transition_marker=is_transition_marker,
|
||||
is_client_input=is_client_input,
|
||||
image_content=image_content,
|
||||
)
|
||||
self._messages.append(msg)
|
||||
self._next_seq += 1
|
||||
@@ -409,6 +433,8 @@ class NodeConversation:
|
||||
tool_use_id: str,
|
||||
content: str,
|
||||
is_error: bool = False,
|
||||
image_content: list[dict[str, Any]] | None = None,
|
||||
is_skill_content: bool = False,
|
||||
) -> Message:
|
||||
msg = Message(
|
||||
seq=self._next_seq,
|
||||
@@ -417,6 +443,8 @@ class NodeConversation:
|
||||
tool_use_id=tool_use_id,
|
||||
is_error=is_error,
|
||||
phase_id=self._current_phase,
|
||||
image_content=image_content,
|
||||
is_skill_content=is_skill_content,
|
||||
)
|
||||
self._messages.append(msg)
|
||||
self._next_seq += 1
|
||||
@@ -610,6 +638,8 @@ class NodeConversation:
|
||||
continue
|
||||
if msg.is_error:
|
||||
continue # never prune errors
|
||||
if msg.is_skill_content:
|
||||
continue # never prune activated skill instructions (AS-10)
|
||||
if msg.content.startswith("[Pruned tool result"):
|
||||
continue # already pruned
|
||||
# Tiny results (set_output acks, confirmations) — pruning
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
"""EventLoopNode subpackage — modular components of the event loop orchestrator.
|
||||
|
||||
All public symbols are re-exported by the parent ``event_loop_node.py`` for
|
||||
backward compatibility. Internal consumers may import directly from these
|
||||
submodules for clarity.
|
||||
"""
|
||||
@@ -0,0 +1,652 @@
|
||||
"""Conversation compaction pipeline.
|
||||
|
||||
Implements the multi-level compaction strategy:
|
||||
1. Prune old tool results
|
||||
2. Structure-preserving compaction (spillover)
|
||||
3. LLM summary compaction (with recursive splitting)
|
||||
4. Emergency deterministic summary (no LLM)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.graph.event_loop.event_publishing import publish_context_usage
|
||||
from framework.graph.event_loop.types import LoopConfig, OutputAccumulator
|
||||
from framework.graph.node import NodeContext
|
||||
from framework.runtime.event_bus import EventBus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Limits for LLM compaction
|
||||
LLM_COMPACT_CHAR_LIMIT: int = 240_000
|
||||
LLM_COMPACT_MAX_DEPTH: int = 10
|
||||
|
||||
|
||||
async def compact(
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
accumulator: OutputAccumulator | None,
|
||||
*,
|
||||
config: LoopConfig,
|
||||
event_bus: EventBus | None,
|
||||
char_limit: int = LLM_COMPACT_CHAR_LIMIT,
|
||||
max_depth: int = LLM_COMPACT_MAX_DEPTH,
|
||||
) -> None:
|
||||
"""Run the full compaction pipeline if conversation needs compaction.
|
||||
|
||||
Pipeline stages (in order, short-circuits when budget is restored):
|
||||
1. Prune old tool results
|
||||
2. Structure-preserving compaction (free, no LLM)
|
||||
3. LLM summary compaction (recursive split if too large)
|
||||
4. Emergency deterministic summary (fallback)
|
||||
"""
|
||||
ratio_before = conversation.usage_ratio()
|
||||
phase_grad = getattr(ctx, "continuous_mode", False)
|
||||
pre_inventory: list[dict[str, Any]] | None = None
|
||||
|
||||
if ratio_before >= 1.0:
|
||||
pre_inventory = build_message_inventory(conversation)
|
||||
|
||||
# --- Step 1: Prune old tool results (free, fast) ---
|
||||
protect = max(2000, config.max_context_tokens // 12)
|
||||
pruned = await conversation.prune_old_tool_results(
|
||||
protect_tokens=protect,
|
||||
min_prune_tokens=max(1000, protect // 3),
|
||||
)
|
||||
if pruned > 0:
|
||||
logger.info(
|
||||
"Pruned %d old tool results: %.0f%% -> %.0f%%",
|
||||
pruned,
|
||||
ratio_before * 100,
|
||||
conversation.usage_ratio() * 100,
|
||||
)
|
||||
if not conversation.needs_compaction():
|
||||
await log_compaction(
|
||||
ctx,
|
||||
conversation,
|
||||
ratio_before,
|
||||
event_bus,
|
||||
pre_inventory=pre_inventory,
|
||||
)
|
||||
return
|
||||
|
||||
# --- Step 2: Standard structure-preserving compaction (free, no LLM) ---
|
||||
spill_dir = config.spillover_dir
|
||||
if spill_dir:
|
||||
await conversation.compact_preserving_structure(
|
||||
spillover_dir=spill_dir,
|
||||
keep_recent=4,
|
||||
phase_graduated=phase_grad,
|
||||
)
|
||||
if not conversation.needs_compaction():
|
||||
await log_compaction(
|
||||
ctx,
|
||||
conversation,
|
||||
ratio_before,
|
||||
event_bus,
|
||||
pre_inventory=pre_inventory,
|
||||
)
|
||||
return
|
||||
|
||||
# --- Step 3: LLM summary compaction ---
|
||||
if ctx.llm is not None:
|
||||
logger.info(
|
||||
"LLM summary compaction triggered (%.0f%% usage)",
|
||||
conversation.usage_ratio() * 100,
|
||||
)
|
||||
try:
|
||||
summary = await llm_compact(
|
||||
ctx,
|
||||
list(conversation.messages),
|
||||
accumulator,
|
||||
char_limit=char_limit,
|
||||
max_depth=max_depth,
|
||||
max_context_tokens=config.max_context_tokens,
|
||||
)
|
||||
await conversation.compact(
|
||||
summary,
|
||||
keep_recent=2,
|
||||
phase_graduated=phase_grad,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("LLM compaction failed: %s", e)
|
||||
|
||||
if not conversation.needs_compaction():
|
||||
await log_compaction(
|
||||
ctx,
|
||||
conversation,
|
||||
ratio_before,
|
||||
event_bus,
|
||||
pre_inventory=pre_inventory,
|
||||
)
|
||||
return
|
||||
|
||||
# --- Step 4: Emergency deterministic summary (LLM failed/unavailable) ---
|
||||
logger.warning(
|
||||
"Emergency compaction (%.0f%% usage)",
|
||||
conversation.usage_ratio() * 100,
|
||||
)
|
||||
summary = build_emergency_summary(ctx, accumulator, conversation, config)
|
||||
await conversation.compact(
|
||||
summary,
|
||||
keep_recent=1,
|
||||
phase_graduated=phase_grad,
|
||||
)
|
||||
await log_compaction(
|
||||
ctx,
|
||||
conversation,
|
||||
ratio_before,
|
||||
event_bus,
|
||||
pre_inventory=pre_inventory,
|
||||
)
|
||||
|
||||
|
||||
# --- LLM compaction with binary-search splitting ----------------------
|
||||
|
||||
|
||||
async def llm_compact(
|
||||
ctx: NodeContext,
|
||||
messages: list,
|
||||
accumulator: OutputAccumulator | None = None,
|
||||
_depth: int = 0,
|
||||
*,
|
||||
char_limit: int = LLM_COMPACT_CHAR_LIMIT,
|
||||
max_depth: int = LLM_COMPACT_MAX_DEPTH,
|
||||
max_context_tokens: int = 128_000,
|
||||
) -> str:
|
||||
"""Summarise *messages* with LLM, splitting recursively if too large.
|
||||
|
||||
If the formatted text exceeds ``LLM_COMPACT_CHAR_LIMIT`` or the LLM
|
||||
rejects the call with a context-length error, the messages are split
|
||||
in half and each half is summarised independently. Tool history is
|
||||
appended once at the top-level call (``_depth == 0``).
|
||||
"""
|
||||
from framework.graph.conversation import extract_tool_call_history
|
||||
from framework.graph.event_loop.tool_result_handler import is_context_too_large_error
|
||||
|
||||
if _depth > max_depth:
|
||||
raise RuntimeError(f"LLM compaction recursion limit ({max_depth})")
|
||||
|
||||
formatted = format_messages_for_summary(messages)
|
||||
|
||||
# Proactive split: avoid wasting an API call on oversized input
|
||||
if len(formatted) > char_limit and len(messages) > 1:
|
||||
summary = await _llm_compact_split(
|
||||
ctx,
|
||||
messages,
|
||||
accumulator,
|
||||
_depth,
|
||||
char_limit=char_limit,
|
||||
max_depth=max_depth,
|
||||
max_context_tokens=max_context_tokens,
|
||||
)
|
||||
else:
|
||||
prompt = build_llm_compaction_prompt(
|
||||
ctx,
|
||||
accumulator,
|
||||
formatted,
|
||||
max_context_tokens=max_context_tokens,
|
||||
)
|
||||
summary_budget = max(1024, max_context_tokens // 2)
|
||||
try:
|
||||
response = await ctx.llm.acomplete(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
system=(
|
||||
"You are a conversation compactor for an AI agent. "
|
||||
"Write a detailed summary that allows the agent to "
|
||||
"continue its work. Preserve user-stated rules, "
|
||||
"constraints, and account/identity preferences verbatim."
|
||||
),
|
||||
max_tokens=summary_budget,
|
||||
)
|
||||
summary = response.content
|
||||
except Exception as e:
|
||||
if is_context_too_large_error(e) and len(messages) > 1:
|
||||
logger.info(
|
||||
"LLM context too large (depth=%d, msgs=%d) — splitting",
|
||||
_depth,
|
||||
len(messages),
|
||||
)
|
||||
summary = await _llm_compact_split(
|
||||
ctx,
|
||||
messages,
|
||||
accumulator,
|
||||
_depth,
|
||||
char_limit=char_limit,
|
||||
max_depth=max_depth,
|
||||
max_context_tokens=max_context_tokens,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
# Append tool history at top level only
|
||||
if _depth == 0:
|
||||
tool_history = extract_tool_call_history(messages)
|
||||
if tool_history and "TOOLS ALREADY CALLED" not in summary:
|
||||
summary += "\n\n" + tool_history
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
async def _llm_compact_split(
|
||||
ctx: NodeContext,
|
||||
messages: list,
|
||||
accumulator: OutputAccumulator | None,
|
||||
_depth: int,
|
||||
*,
|
||||
char_limit: int = LLM_COMPACT_CHAR_LIMIT,
|
||||
max_depth: int = LLM_COMPACT_MAX_DEPTH,
|
||||
max_context_tokens: int = 128_000,
|
||||
) -> str:
|
||||
"""Split messages in half and summarise each half independently."""
|
||||
mid = max(1, len(messages) // 2)
|
||||
s1 = await llm_compact(
|
||||
ctx,
|
||||
messages[:mid],
|
||||
None,
|
||||
_depth + 1,
|
||||
char_limit=char_limit,
|
||||
max_depth=max_depth,
|
||||
max_context_tokens=max_context_tokens,
|
||||
)
|
||||
s2 = await llm_compact(
|
||||
ctx,
|
||||
messages[mid:],
|
||||
accumulator,
|
||||
_depth + 1,
|
||||
char_limit=char_limit,
|
||||
max_depth=max_depth,
|
||||
max_context_tokens=max_context_tokens,
|
||||
)
|
||||
return s1 + "\n\n" + s2
|
||||
|
||||
|
||||
# --- Compaction helpers ------------------------------------------------
|
||||
|
||||
|
||||
def format_messages_for_summary(messages: list) -> str:
|
||||
"""Format messages as text for LLM summarisation."""
|
||||
lines: list[str] = []
|
||||
for m in messages:
|
||||
if m.role == "tool":
|
||||
content = m.content[:500]
|
||||
if len(m.content) > 500:
|
||||
content += "..."
|
||||
lines.append(f"[tool result]: {content}")
|
||||
elif m.role == "assistant" and m.tool_calls:
|
||||
names = [tc.get("function", {}).get("name", "?") for tc in m.tool_calls]
|
||||
text = m.content[:200] if m.content else ""
|
||||
lines.append(f"[assistant (calls: {', '.join(names)})]: {text}")
|
||||
else:
|
||||
lines.append(f"[{m.role}]: {m.content}")
|
||||
return "\n\n".join(lines)
|
||||
|
||||
|
||||
def build_llm_compaction_prompt(
|
||||
ctx: NodeContext,
|
||||
accumulator: OutputAccumulator | None,
|
||||
formatted_messages: str,
|
||||
*,
|
||||
max_context_tokens: int = 128_000,
|
||||
) -> str:
|
||||
"""Build prompt for LLM compaction targeting 50% of token budget."""
|
||||
spec = ctx.node_spec
|
||||
ctx_lines = [f"NODE: {spec.name} (id={spec.id})"]
|
||||
if spec.description:
|
||||
ctx_lines.append(f"PURPOSE: {spec.description}")
|
||||
if spec.success_criteria:
|
||||
ctx_lines.append(f"SUCCESS CRITERIA: {spec.success_criteria}")
|
||||
|
||||
if accumulator:
|
||||
acc = accumulator.to_dict()
|
||||
done = {k: v for k, v in acc.items() if v is not None}
|
||||
todo = [k for k, v in acc.items() if v is None]
|
||||
if done:
|
||||
ctx_lines.append(
|
||||
"OUTPUTS ALREADY SET:\n"
|
||||
+ "\n".join(f" {k}: {str(v)[:150]}" for k, v in done.items())
|
||||
)
|
||||
if todo:
|
||||
ctx_lines.append(f"OUTPUTS STILL NEEDED: {', '.join(todo)}")
|
||||
elif spec.output_keys:
|
||||
ctx_lines.append(f"OUTPUTS STILL NEEDED: {', '.join(spec.output_keys)}")
|
||||
|
||||
target_tokens = max_context_tokens // 2
|
||||
target_chars = target_tokens * 4
|
||||
node_ctx = "\n".join(ctx_lines)
|
||||
|
||||
return (
|
||||
"You are compacting an AI agent's conversation history. "
|
||||
"The agent is still working and needs to continue.\n\n"
|
||||
f"AGENT CONTEXT:\n{node_ctx}\n\n"
|
||||
f"CONVERSATION MESSAGES:\n{formatted_messages}\n\n"
|
||||
"INSTRUCTIONS:\n"
|
||||
f"Write a summary of approximately {target_chars} characters "
|
||||
f"(~{target_tokens} tokens).\n"
|
||||
"1. Preserve ALL user-stated rules, constraints, and preferences "
|
||||
"verbatim.\n"
|
||||
"2. Preserve key decisions made and results obtained.\n"
|
||||
"3. Preserve in-progress work state so the agent can continue.\n"
|
||||
"4. Be detailed enough that the agent can resume without "
|
||||
"re-doing work.\n"
|
||||
)
|
||||
|
||||
|
||||
def build_message_inventory(conversation: NodeConversation) -> list[dict[str, Any]]:
|
||||
"""Build a per-message size inventory for debug logging."""
|
||||
inventory: list[dict[str, Any]] = []
|
||||
for message in conversation.messages:
|
||||
content_chars = len(message.content)
|
||||
tool_call_args_chars = 0
|
||||
tool_name = None
|
||||
if message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
args = tool_call.get("function", {}).get("arguments", "")
|
||||
tool_call_args_chars += (
|
||||
len(args) if isinstance(args, str) else len(json.dumps(args))
|
||||
)
|
||||
names = [
|
||||
tool_call.get("function", {}).get("name", "?") for tool_call in message.tool_calls
|
||||
]
|
||||
tool_name = ", ".join(names)
|
||||
elif message.role == "tool" and message.tool_use_id:
|
||||
for previous in conversation.messages:
|
||||
if previous.tool_calls:
|
||||
for tool_call in previous.tool_calls:
|
||||
if tool_call.get("id") == message.tool_use_id:
|
||||
tool_name = tool_call.get("function", {}).get("name", "?")
|
||||
break
|
||||
if tool_name:
|
||||
break
|
||||
entry: dict[str, Any] = {
|
||||
"seq": message.seq,
|
||||
"role": message.role,
|
||||
"content_chars": content_chars,
|
||||
}
|
||||
if tool_call_args_chars:
|
||||
entry["tool_call_args_chars"] = tool_call_args_chars
|
||||
if tool_name:
|
||||
entry["tool"] = tool_name
|
||||
if message.is_error:
|
||||
entry["is_error"] = True
|
||||
if message.phase_id:
|
||||
entry["phase"] = message.phase_id
|
||||
if content_chars > 2000:
|
||||
entry["preview"] = message.content[:200] + "…"
|
||||
inventory.append(entry)
|
||||
return inventory
|
||||
|
||||
|
||||
def write_compaction_debug_log(
|
||||
ctx: NodeContext,
|
||||
before_pct: int,
|
||||
after_pct: int,
|
||||
level: str,
|
||||
inventory: list[dict[str, Any]] | None,
|
||||
) -> None:
|
||||
"""Write detailed compaction analysis to ~/.hive/compaction_log/."""
|
||||
log_dir = Path.home() / ".hive" / "compaction_log"
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ts = datetime.now(UTC).strftime("%Y%m%dT%H%M%S_%f")
|
||||
node_label = ctx.node_id.replace("/", "_")
|
||||
log_path = log_dir / f"{ts}_{node_label}.md"
|
||||
|
||||
lines: list[str] = [
|
||||
f"# Compaction Debug — {ctx.node_id}",
|
||||
f"**Time:** {datetime.now(UTC).isoformat()}",
|
||||
f"**Node:** {ctx.node_spec.name} (`{ctx.node_id}`)",
|
||||
]
|
||||
if ctx.stream_id:
|
||||
lines.append(f"**Stream:** {ctx.stream_id}")
|
||||
lines.append(f"**Level:** {level}")
|
||||
lines.append(f"**Usage:** {before_pct}% → {after_pct}%")
|
||||
lines.append("")
|
||||
|
||||
if inventory:
|
||||
total_chars = sum(
|
||||
entry.get("content_chars", 0) + entry.get("tool_call_args_chars", 0)
|
||||
for entry in inventory
|
||||
)
|
||||
lines.append(
|
||||
"## Pre-Compaction Message Inventory "
|
||||
f"({len(inventory)} messages, {total_chars:,} total chars)"
|
||||
)
|
||||
lines.append("")
|
||||
ranked = sorted(
|
||||
inventory,
|
||||
key=lambda entry: entry.get("content_chars", 0) + entry.get("tool_call_args_chars", 0),
|
||||
reverse=True,
|
||||
)
|
||||
lines.append("| # | seq | role | tool | chars | % of total | flags |")
|
||||
lines.append("|---|-----|------|------|------:|------------|-------|")
|
||||
for i, entry in enumerate(ranked, 1):
|
||||
chars = entry.get("content_chars", 0) + entry.get("tool_call_args_chars", 0)
|
||||
pct = (chars / total_chars * 100) if total_chars else 0
|
||||
tool = entry.get("tool", "")
|
||||
flags: list[str] = []
|
||||
if entry.get("is_error"):
|
||||
flags.append("error")
|
||||
if entry.get("phase"):
|
||||
flags.append(f"phase={entry['phase']}")
|
||||
lines.append(
|
||||
f"| {i} | {entry['seq']} | {entry['role']} | {tool} "
|
||||
f"| {chars:,} | {pct:.1f}% | {', '.join(flags)} |"
|
||||
)
|
||||
|
||||
large = [entry for entry in ranked if entry.get("preview")]
|
||||
if large:
|
||||
lines.append("")
|
||||
lines.append("### Large message previews")
|
||||
for entry in large:
|
||||
lines.append(
|
||||
f"\n**seq={entry['seq']}** ({entry['role']}, {entry.get('tool', '')}):"
|
||||
)
|
||||
lines.append(f"```\n{entry['preview']}\n```")
|
||||
lines.append("")
|
||||
|
||||
try:
|
||||
log_path.write_text("\n".join(lines), encoding="utf-8")
|
||||
logger.debug("Compaction debug log written to %s", log_path)
|
||||
except OSError:
|
||||
logger.debug("Failed to write compaction debug log to %s", log_path)
|
||||
|
||||
|
||||
async def log_compaction(
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
ratio_before: float,
|
||||
event_bus: EventBus | None,
|
||||
*,
|
||||
pre_inventory: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
"""Log compaction result to runtime logger and event bus."""
|
||||
ratio_after = conversation.usage_ratio()
|
||||
before_pct = round(ratio_before * 100)
|
||||
after_pct = round(ratio_after * 100)
|
||||
|
||||
# Determine label from what happened
|
||||
if after_pct >= before_pct - 1:
|
||||
level = "prune_only"
|
||||
elif ratio_after <= 0.6:
|
||||
level = "llm"
|
||||
else:
|
||||
level = "structural"
|
||||
|
||||
logger.info(
|
||||
"Compaction complete (%s): %d%% -> %d%%",
|
||||
level,
|
||||
before_pct,
|
||||
after_pct,
|
||||
)
|
||||
|
||||
if ctx.runtime_logger:
|
||||
ctx.runtime_logger.log_step(
|
||||
node_id=ctx.node_id,
|
||||
node_type="event_loop",
|
||||
step_index=-1,
|
||||
llm_text=f"Context compacted ({level}): {before_pct}% \u2192 {after_pct}%",
|
||||
verdict="COMPACTION",
|
||||
verdict_feedback=f"level={level} before={before_pct}% after={after_pct}%",
|
||||
)
|
||||
|
||||
if event_bus:
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
|
||||
event_data: dict[str, Any] = {
|
||||
"level": level,
|
||||
"usage_before": before_pct,
|
||||
"usage_after": after_pct,
|
||||
}
|
||||
if pre_inventory is not None:
|
||||
event_data["message_inventory"] = pre_inventory
|
||||
await event_bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.CONTEXT_COMPACTED,
|
||||
stream_id=ctx.stream_id or ctx.node_id,
|
||||
node_id=ctx.node_id,
|
||||
data=event_data,
|
||||
)
|
||||
)
|
||||
|
||||
await publish_context_usage(event_bus, ctx, conversation, "post_compaction")
|
||||
|
||||
if os.environ.get("HIVE_COMPACTION_DEBUG"):
|
||||
write_compaction_debug_log(ctx, before_pct, after_pct, level, pre_inventory)
|
||||
|
||||
|
||||
def build_emergency_summary(
|
||||
ctx: NodeContext,
|
||||
accumulator: OutputAccumulator | None = None,
|
||||
conversation: NodeConversation | None = None,
|
||||
config: LoopConfig | None = None,
|
||||
) -> str:
|
||||
"""Build a structured emergency compaction summary.
|
||||
|
||||
Unlike normal/aggressive compaction which uses an LLM summary,
|
||||
emergency compaction cannot afford an LLM call (context is already
|
||||
way over budget). Instead, build a deterministic summary from the
|
||||
node's known state so the LLM can continue working after
|
||||
compaction without losing track of its task and inputs.
|
||||
"""
|
||||
parts = [
|
||||
"EMERGENCY COMPACTION — previous conversation was too large "
|
||||
"and has been replaced with this summary.\n"
|
||||
]
|
||||
|
||||
# 1. Node identity
|
||||
spec = ctx.node_spec
|
||||
parts.append(f"NODE: {spec.name} (id={spec.id})")
|
||||
if spec.description:
|
||||
parts.append(f"PURPOSE: {spec.description}")
|
||||
|
||||
# 2. Inputs the node received
|
||||
input_lines = []
|
||||
for key in spec.input_keys:
|
||||
value = ctx.input_data.get(key) or ctx.memory.read(key)
|
||||
if value is not None:
|
||||
# Truncate long values but keep them recognisable
|
||||
v_str = str(value)
|
||||
if len(v_str) > 200:
|
||||
v_str = v_str[:200] + "…"
|
||||
input_lines.append(f" {key}: {v_str}")
|
||||
if input_lines:
|
||||
parts.append("INPUTS:\n" + "\n".join(input_lines))
|
||||
|
||||
# 3. Output accumulator state (what's been set so far)
|
||||
if accumulator:
|
||||
acc_state = accumulator.to_dict()
|
||||
set_keys = {k: v for k, v in acc_state.items() if v is not None}
|
||||
missing = [k for k, v in acc_state.items() if v is None]
|
||||
if set_keys:
|
||||
lines = [f" {k}: {str(v)[:150]}" for k, v in set_keys.items()]
|
||||
parts.append("OUTPUTS ALREADY SET:\n" + "\n".join(lines))
|
||||
if missing:
|
||||
parts.append(f"OUTPUTS STILL NEEDED: {', '.join(missing)}")
|
||||
elif spec.output_keys:
|
||||
parts.append(f"OUTPUTS STILL NEEDED: {', '.join(spec.output_keys)}")
|
||||
|
||||
# 4. Available tools reminder
|
||||
if spec.tools:
|
||||
parts.append(f"AVAILABLE TOOLS: {', '.join(spec.tools)}")
|
||||
|
||||
# 5. Spillover files — list actual files so the LLM can load
|
||||
# them immediately instead of having to call list_data_files first.
|
||||
# Inline adapt.md (agent memory) directly — it contains user rules
|
||||
# and identity preferences that must survive emergency compaction.
|
||||
spillover_dir = config.spillover_dir if config else None
|
||||
if spillover_dir:
|
||||
try:
|
||||
from pathlib import Path
|
||||
|
||||
data_dir = Path(spillover_dir)
|
||||
if data_dir.is_dir():
|
||||
# Inline adapt.md content directly
|
||||
adapt_path = data_dir / "adapt.md"
|
||||
if adapt_path.is_file():
|
||||
adapt_text = adapt_path.read_text(encoding="utf-8").strip()
|
||||
if adapt_text:
|
||||
parts.append(f"AGENT MEMORY (adapt.md):\n{adapt_text}")
|
||||
|
||||
all_files = sorted(
|
||||
f.name for f in data_dir.iterdir() if f.is_file() and f.name != "adapt.md"
|
||||
)
|
||||
# Separate conversation history files from regular data files
|
||||
conv_files = [f for f in all_files if re.match(r"conversation_\d+\.md$", f)]
|
||||
data_files = [f for f in all_files if f not in conv_files]
|
||||
|
||||
if conv_files:
|
||||
conv_list = "\n".join(
|
||||
f" - {f} (full path: {data_dir / f})" for f in conv_files
|
||||
)
|
||||
parts.append(
|
||||
"CONVERSATION HISTORY (freeform messages saved during compaction — "
|
||||
"use load_data('<filename>') to review earlier dialogue):\n" + conv_list
|
||||
)
|
||||
if data_files:
|
||||
file_list = "\n".join(
|
||||
f" - {f} (full path: {data_dir / f})" for f in data_files[:30]
|
||||
)
|
||||
parts.append("DATA FILES (use load_data('<filename>') to read):\n" + file_list)
|
||||
if not all_files:
|
||||
parts.append(
|
||||
"NOTE: Large tool results may have been saved to files. "
|
||||
"Use list_directory to check the data directory."
|
||||
)
|
||||
except Exception:
|
||||
parts.append(
|
||||
"NOTE: Large tool results were saved to files. "
|
||||
"Use read_file(path='<path>') to read them."
|
||||
)
|
||||
|
||||
# 6. Tool call history (prevent re-calling tools)
|
||||
if conversation is not None:
|
||||
tool_history = _extract_tool_call_history(conversation)
|
||||
if tool_history:
|
||||
parts.append(tool_history)
|
||||
|
||||
parts.append(
|
||||
"\nContinue working towards setting the remaining outputs. "
|
||||
"Use your tools and the inputs above."
|
||||
)
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
def _extract_tool_call_history(conversation: NodeConversation) -> str:
|
||||
"""Extract tool call history from conversation messages.
|
||||
|
||||
This is the instance-level variant that operates on a NodeConversation
|
||||
directly (vs. the module-level extract_tool_call_history in conversation.py
|
||||
which works on raw message lists).
|
||||
"""
|
||||
from framework.graph.conversation import extract_tool_call_history
|
||||
|
||||
return extract_tool_call_history(list(conversation.messages))
|
||||
@@ -0,0 +1,239 @@
|
||||
"""Cursor persistence, queue draining, and pause detection.
|
||||
|
||||
Handles the checkpoint/resume cycle: restoring state from a previous
|
||||
conversation store, writing cursor data, and managing injection/trigger
|
||||
queues between iterations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from framework.graph.conversation import ConversationStore, NodeConversation
|
||||
from framework.graph.event_loop.types import LoopConfig, OutputAccumulator, TriggerEvent
|
||||
from framework.graph.node import NodeContext
|
||||
from framework.llm.capabilities import supports_image_tool_results
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RestoredState:
|
||||
"""State recovered from a previous checkpoint."""
|
||||
|
||||
conversation: NodeConversation
|
||||
accumulator: OutputAccumulator
|
||||
start_iteration: int
|
||||
recent_responses: list[str]
|
||||
recent_tool_fingerprints: list[list[tuple[str, str]]]
|
||||
|
||||
|
||||
async def restore(
|
||||
conversation_store: ConversationStore | None,
|
||||
ctx: NodeContext,
|
||||
config: LoopConfig,
|
||||
) -> RestoredState | None:
|
||||
"""Attempt to restore from a previous checkpoint.
|
||||
|
||||
Returns a ``RestoredState`` with conversation, accumulator, iteration
|
||||
counter, and stall/doom-loop detection state — everything needed to
|
||||
resume exactly where execution stopped.
|
||||
"""
|
||||
if conversation_store is None:
|
||||
return None
|
||||
|
||||
# In isolated mode, filter parts by phase_id so the node only sees
|
||||
# its own messages in the shared flat conversation store. In
|
||||
# continuous mode (or when _restore is called for timer-resume)
|
||||
# load all parts — the full conversation threads across nodes.
|
||||
_is_continuous = getattr(ctx, "continuous_mode", False)
|
||||
phase_filter = None if _is_continuous else ctx.node_id
|
||||
conversation = await NodeConversation.restore(
|
||||
conversation_store,
|
||||
phase_id=phase_filter,
|
||||
)
|
||||
if conversation is None:
|
||||
return None
|
||||
|
||||
accumulator = await OutputAccumulator.restore(conversation_store)
|
||||
accumulator.spillover_dir = config.spillover_dir
|
||||
accumulator.max_value_chars = config.max_output_value_chars
|
||||
|
||||
cursor = await conversation_store.read_cursor()
|
||||
start_iteration = cursor.get("iteration", 0) + 1 if cursor else 0
|
||||
|
||||
# Restore stall/doom-loop detection state
|
||||
recent_responses: list[str] = cursor.get("recent_responses", []) if cursor else []
|
||||
raw_fps = cursor.get("recent_tool_fingerprints", []) if cursor else []
|
||||
recent_tool_fingerprints: list[list[tuple[str, str]]] = [
|
||||
[tuple(pair) for pair in fps] # type: ignore[misc]
|
||||
for fps in raw_fps
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Restored event loop: iteration={start_iteration}, "
|
||||
f"messages={conversation.message_count}, "
|
||||
f"outputs={list(accumulator.values.keys())}, "
|
||||
f"stall_window={len(recent_responses)}, "
|
||||
f"doom_window={len(recent_tool_fingerprints)}"
|
||||
)
|
||||
return RestoredState(
|
||||
conversation=conversation,
|
||||
accumulator=accumulator,
|
||||
start_iteration=start_iteration,
|
||||
recent_responses=recent_responses,
|
||||
recent_tool_fingerprints=recent_tool_fingerprints,
|
||||
)
|
||||
|
||||
|
||||
async def write_cursor(
|
||||
conversation_store: ConversationStore | None,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
accumulator: OutputAccumulator,
|
||||
iteration: int,
|
||||
*,
|
||||
recent_responses: list[str] | None = None,
|
||||
recent_tool_fingerprints: list[list[tuple[str, str]]] | None = None,
|
||||
) -> None:
|
||||
"""Write checkpoint cursor for crash recovery.
|
||||
|
||||
Persists iteration counter, accumulator outputs, and stall/doom-loop
|
||||
detection state so that resume picks up exactly where execution stopped.
|
||||
"""
|
||||
if conversation_store:
|
||||
cursor = await conversation_store.read_cursor() or {}
|
||||
cursor.update(
|
||||
{
|
||||
"iteration": iteration,
|
||||
"node_id": ctx.node_id,
|
||||
"next_seq": conversation.next_seq,
|
||||
"outputs": accumulator.to_dict(),
|
||||
}
|
||||
)
|
||||
# Persist stall/doom-loop detection state for reliable resume
|
||||
if recent_responses is not None:
|
||||
cursor["recent_responses"] = recent_responses
|
||||
if recent_tool_fingerprints is not None:
|
||||
# Convert list[list[tuple]] → list[list[list]] for JSON
|
||||
cursor["recent_tool_fingerprints"] = [
|
||||
[list(pair) for pair in fps] for fps in recent_tool_fingerprints
|
||||
]
|
||||
await conversation_store.write_cursor(cursor)
|
||||
|
||||
|
||||
async def drain_injection_queue(
|
||||
queue: asyncio.Queue,
|
||||
conversation: NodeConversation,
|
||||
*,
|
||||
ctx: NodeContext,
|
||||
describe_images_as_text_fn: (
|
||||
Callable[[list[dict[str, Any]]], Awaitable[str | None]] | None
|
||||
) = None,
|
||||
) -> int:
|
||||
"""Drain all pending injected events as user messages. Returns count."""
|
||||
count = 0
|
||||
while not queue.empty():
|
||||
try:
|
||||
content, is_client_input, image_content = queue.get_nowait()
|
||||
logger.info(
|
||||
"[drain] injected message (client_input=%s, images=%d): %s",
|
||||
is_client_input,
|
||||
len(image_content) if image_content else 0,
|
||||
content[:200] if content else "(empty)",
|
||||
)
|
||||
if image_content and ctx.llm and not supports_image_tool_results(ctx.llm.model):
|
||||
logger.info(
|
||||
"Model '%s' does not support images; attempting vision fallback",
|
||||
ctx.llm.model,
|
||||
)
|
||||
if describe_images_as_text_fn is not None:
|
||||
description = await describe_images_as_text_fn(image_content)
|
||||
if description:
|
||||
content = f"{content}\n\n{description}" if content else description
|
||||
logger.info("[drain] image described as text via vision fallback")
|
||||
else:
|
||||
logger.info("[drain] no vision fallback available; images dropped")
|
||||
image_content = None
|
||||
# Real user input is stored as-is; external events get a prefix
|
||||
if is_client_input:
|
||||
await conversation.add_user_message(
|
||||
content,
|
||||
is_client_input=True,
|
||||
image_content=image_content,
|
||||
)
|
||||
else:
|
||||
await conversation.add_user_message(f"[External event]: {content}")
|
||||
count += 1
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
return count
|
||||
|
||||
|
||||
async def drain_trigger_queue(
|
||||
queue: asyncio.Queue,
|
||||
conversation: NodeConversation,
|
||||
) -> int:
|
||||
"""Drain all pending trigger events as a single batched user message.
|
||||
|
||||
Multiple triggers are merged so the LLM sees them atomically and can
|
||||
reason about all pending triggers before acting.
|
||||
"""
|
||||
triggers: list[TriggerEvent] = []
|
||||
while not queue.empty():
|
||||
try:
|
||||
triggers.append(queue.get_nowait())
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
if not triggers:
|
||||
return 0
|
||||
|
||||
parts: list[str] = []
|
||||
for t in triggers:
|
||||
task = t.payload.get("task", "")
|
||||
task_line = f"\nTask: {task}" if task else ""
|
||||
payload_str = json.dumps(t.payload, default=str)
|
||||
parts.append(f"[TRIGGER: {t.trigger_type}/{t.source_id}]{task_line}\n{payload_str}")
|
||||
|
||||
combined = "\n\n".join(parts)
|
||||
logger.info("[drain] %d trigger(s): %s", len(triggers), combined[:200])
|
||||
await conversation.add_user_message(combined)
|
||||
return len(triggers)
|
||||
|
||||
|
||||
async def check_pause(
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
iteration: int,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if pause has been requested. Returns True if paused.
|
||||
|
||||
Note: This check happens BEFORE starting iteration N, after completing N-1.
|
||||
If paused, the node exits having completed {iteration} iterations (0 to iteration-1).
|
||||
"""
|
||||
# Check executor-level pause event (for /pause command, Ctrl+Z)
|
||||
if ctx.pause_event and ctx.pause_event.is_set():
|
||||
completed = iteration # 0-indexed: iteration=3 means 3 iterations completed (0,1,2)
|
||||
logger.info(f"⏸ Pausing after {completed} iteration(s) completed (executor-level)")
|
||||
return True
|
||||
|
||||
# Check context-level pause flags (legacy/alternative methods)
|
||||
pause_requested = ctx.input_data.get("pause_requested", False)
|
||||
if not pause_requested:
|
||||
try:
|
||||
pause_requested = ctx.memory.read("pause_requested") or False
|
||||
except (PermissionError, KeyError):
|
||||
pause_requested = False
|
||||
if pause_requested:
|
||||
completed = iteration
|
||||
logger.info(f"⏸ Pausing after {completed} iteration(s) completed (context-level)")
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -0,0 +1,360 @@
|
||||
"""EventBus publishing helpers for the event loop.
|
||||
|
||||
Thin wrappers around EventBus.emit_*() calls that check for bus existence
|
||||
before publishing. Extracted to reduce noise in the main orchestrator.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.graph.event_loop.types import HookContext
|
||||
from framework.graph.node import NodeContext
|
||||
from framework.runtime.event_bus import EventBus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def publish_loop_started(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
max_iterations: int,
|
||||
execution_id: str = "",
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_node_loop_started(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
max_iterations=max_iterations,
|
||||
execution_id=execution_id,
|
||||
)
|
||||
|
||||
|
||||
async def generate_action_plan(
|
||||
event_bus: EventBus | None,
|
||||
ctx: NodeContext,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
execution_id: str,
|
||||
) -> None:
|
||||
"""Generate a brief action plan via LLM and emit it as an SSE event.
|
||||
|
||||
Runs as a fire-and-forget task so it never blocks the main loop.
|
||||
"""
|
||||
try:
|
||||
system_prompt = ctx.node_spec.system_prompt or ""
|
||||
# Trim to keep the prompt small
|
||||
prompt_summary = system_prompt[:500]
|
||||
if len(system_prompt) > 500:
|
||||
prompt_summary += "..."
|
||||
|
||||
tool_names = [t.name for t in ctx.available_tools]
|
||||
output_keys = ctx.node_spec.output_keys or []
|
||||
|
||||
prompt = (
|
||||
f'You are about to work on a task as node "{node_id}".\n\n'
|
||||
f"System prompt:\n{prompt_summary}\n\n"
|
||||
f"Tools available: {tool_names}\n"
|
||||
f"Required outputs: {output_keys}\n\n"
|
||||
f"Write a brief action plan (2-5 bullet points) describing "
|
||||
f"what you will do to complete this task. Be specific and concise.\n"
|
||||
f"Return ONLY the plan text, no preamble."
|
||||
)
|
||||
|
||||
response = await ctx.llm.acomplete(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
max_tokens=1024,
|
||||
)
|
||||
|
||||
plan = response.content.strip()
|
||||
if plan and event_bus:
|
||||
await event_bus.emit_node_action_plan(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
plan=plan,
|
||||
execution_id=execution_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Action plan generation failed for node '%s': %s", node_id, e)
|
||||
|
||||
|
||||
async def publish_iteration(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
iteration: int,
|
||||
execution_id: str = "",
|
||||
extra_data: dict | None = None,
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_node_loop_iteration(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
iteration=iteration,
|
||||
execution_id=execution_id,
|
||||
extra_data=extra_data,
|
||||
)
|
||||
|
||||
|
||||
async def publish_llm_turn_complete(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
stop_reason: str,
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cached_tokens: int = 0,
|
||||
execution_id: str = "",
|
||||
iteration: int | None = None,
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_llm_turn_complete(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
stop_reason=stop_reason,
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
execution_id=execution_id,
|
||||
iteration=iteration,
|
||||
)
|
||||
|
||||
|
||||
def log_skip_judge(
|
||||
ctx: NodeContext,
|
||||
node_id: str,
|
||||
iteration: int,
|
||||
feedback: str,
|
||||
tool_calls: list[dict],
|
||||
llm_text: str,
|
||||
turn_tokens: dict[str, int],
|
||||
iter_start: float,
|
||||
) -> None:
|
||||
"""Log a CONTINUE step that skips judge evaluation (e.g., waiting for input)."""
|
||||
if ctx.runtime_logger:
|
||||
ctx.runtime_logger.log_step(
|
||||
node_id=node_id,
|
||||
node_type="event_loop",
|
||||
step_index=iteration,
|
||||
verdict="CONTINUE",
|
||||
verdict_feedback=feedback,
|
||||
tool_calls=tool_calls,
|
||||
llm_text=llm_text,
|
||||
input_tokens=turn_tokens.get("input", 0),
|
||||
output_tokens=turn_tokens.get("output", 0),
|
||||
latency_ms=int((time.time() - iter_start) * 1000),
|
||||
)
|
||||
|
||||
|
||||
async def publish_loop_completed(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
iterations: int,
|
||||
execution_id: str = "",
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_node_loop_completed(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
iterations=iterations,
|
||||
execution_id=execution_id,
|
||||
)
|
||||
|
||||
|
||||
async def publish_context_usage(
|
||||
event_bus: EventBus | None,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
trigger: str,
|
||||
) -> None:
|
||||
"""Emit a CONTEXT_USAGE_UPDATED event with current context window state."""
|
||||
if not event_bus:
|
||||
return
|
||||
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
|
||||
estimated = conversation.estimate_tokens()
|
||||
max_tokens = conversation._max_context_tokens
|
||||
ratio = estimated / max_tokens if max_tokens > 0 else 0.0
|
||||
await event_bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.CONTEXT_USAGE_UPDATED,
|
||||
stream_id=ctx.stream_id or ctx.node_id,
|
||||
node_id=ctx.node_id,
|
||||
data={
|
||||
"usage_ratio": round(ratio, 4),
|
||||
"usage_pct": round(ratio * 100),
|
||||
"message_count": conversation.message_count,
|
||||
"estimated_tokens": estimated,
|
||||
"max_context_tokens": max_tokens,
|
||||
"trigger": trigger,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def publish_stalled(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
execution_id: str = "",
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_node_stalled(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
reason="Consecutive similar responses detected",
|
||||
execution_id=execution_id,
|
||||
)
|
||||
|
||||
|
||||
async def publish_text_delta(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
content: str,
|
||||
snapshot: str,
|
||||
ctx: NodeContext,
|
||||
execution_id: str = "",
|
||||
iteration: int | None = None,
|
||||
inner_turn: int = 0,
|
||||
) -> None:
|
||||
if event_bus:
|
||||
if ctx.node_spec.client_facing:
|
||||
await event_bus.emit_client_output_delta(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
content=content,
|
||||
snapshot=snapshot,
|
||||
execution_id=execution_id,
|
||||
iteration=iteration,
|
||||
inner_turn=inner_turn,
|
||||
)
|
||||
else:
|
||||
await event_bus.emit_llm_text_delta(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
content=content,
|
||||
snapshot=snapshot,
|
||||
execution_id=execution_id,
|
||||
inner_turn=inner_turn,
|
||||
)
|
||||
|
||||
|
||||
async def publish_tool_started(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
tool_use_id: str,
|
||||
tool_name: str,
|
||||
tool_input: dict,
|
||||
execution_id: str = "",
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_tool_call_started(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
tool_use_id=tool_use_id,
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
execution_id=execution_id,
|
||||
)
|
||||
|
||||
|
||||
async def publish_tool_completed(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
tool_use_id: str,
|
||||
tool_name: str,
|
||||
result: str,
|
||||
is_error: bool,
|
||||
execution_id: str = "",
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_tool_call_completed(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
tool_use_id=tool_use_id,
|
||||
tool_name=tool_name,
|
||||
result=result,
|
||||
is_error=is_error,
|
||||
execution_id=execution_id,
|
||||
)
|
||||
|
||||
|
||||
async def publish_judge_verdict(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
action: str,
|
||||
feedback: str = "",
|
||||
judge_type: str = "implicit",
|
||||
iteration: int = 0,
|
||||
execution_id: str = "",
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_judge_verdict(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
action=action,
|
||||
feedback=feedback,
|
||||
judge_type=judge_type,
|
||||
iteration=iteration,
|
||||
execution_id=execution_id,
|
||||
)
|
||||
|
||||
|
||||
async def publish_output_key_set(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
key: str,
|
||||
execution_id: str = "",
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_output_key_set(
|
||||
stream_id=stream_id, node_id=node_id, key=key, execution_id=execution_id
|
||||
)
|
||||
|
||||
|
||||
async def run_hooks(
|
||||
hooks_config: dict[str, list],
|
||||
event: str,
|
||||
conversation: NodeConversation,
|
||||
trigger: str | None = None,
|
||||
) -> None:
|
||||
"""Run all registered hooks for *event*, applying their results.
|
||||
|
||||
Each hook receives a HookContext and may return a HookResult that:
|
||||
- replaces the system prompt (result.system_prompt)
|
||||
- injects an extra user message (result.inject)
|
||||
Hooks run in registration order; each sees the prompt as left by the
|
||||
previous hook.
|
||||
"""
|
||||
hook_list = hooks_config.get(event, [])
|
||||
if not hook_list:
|
||||
return
|
||||
for hook in hook_list:
|
||||
ctx = HookContext(
|
||||
event=event,
|
||||
trigger=trigger,
|
||||
system_prompt=conversation.system_prompt,
|
||||
)
|
||||
try:
|
||||
result = await hook(ctx)
|
||||
except Exception:
|
||||
logger.warning("Hook '%s' raised an exception", event, exc_info=True)
|
||||
continue
|
||||
if result is None:
|
||||
continue
|
||||
if result.system_prompt:
|
||||
conversation.update_system_prompt(result.system_prompt)
|
||||
if result.inject:
|
||||
await conversation.add_user_message(result.inject)
|
||||
@@ -0,0 +1,175 @@
|
||||
"""Judge evaluation pipeline for the event loop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.graph.event_loop.types import JudgeProtocol, JudgeVerdict, OutputAccumulator
|
||||
from framework.graph.node import NodeContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SubagentJudge:
|
||||
"""Judge for subagent execution."""
|
||||
|
||||
def __init__(self, task: str, max_iterations: int = 10):
|
||||
self._task = task
|
||||
self._max_iterations = max_iterations
|
||||
|
||||
async def evaluate(self, context: dict[str, object]) -> JudgeVerdict:
|
||||
missing = context.get("missing_keys", [])
|
||||
if not isinstance(missing, list) or not missing:
|
||||
return JudgeVerdict(action="ACCEPT", feedback="")
|
||||
|
||||
iteration = context.get("iteration", 0)
|
||||
if not isinstance(iteration, int):
|
||||
iteration = 0
|
||||
remaining = self._max_iterations - iteration - 1
|
||||
|
||||
if remaining <= 3:
|
||||
urgency = (
|
||||
f"URGENT: Only {remaining} iterations left. "
|
||||
f"Stop all other work and call set_output NOW for: {missing}"
|
||||
)
|
||||
elif remaining <= self._max_iterations // 2:
|
||||
urgency = (
|
||||
f"WARNING: {remaining} iterations remaining. "
|
||||
f"You must call set_output for: {missing}"
|
||||
)
|
||||
else:
|
||||
urgency = f"Missing output keys: {missing}. Use set_output to provide them."
|
||||
|
||||
return JudgeVerdict(action="RETRY", feedback=f"Your task: {self._task}\n{urgency}")
|
||||
|
||||
|
||||
async def judge_turn(
|
||||
*,
|
||||
mark_complete_flag: bool,
|
||||
judge: JudgeProtocol | None,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
accumulator: OutputAccumulator,
|
||||
assistant_text: str,
|
||||
tool_results: list[dict[str, object]],
|
||||
iteration: int,
|
||||
get_missing_output_keys_fn: Callable[
|
||||
[OutputAccumulator, list[str] | None, list[str] | None],
|
||||
list[str],
|
||||
],
|
||||
max_context_tokens: int,
|
||||
) -> JudgeVerdict:
|
||||
"""Evaluate the current state using judge or implicit logic.
|
||||
|
||||
Evaluation levels (in order):
|
||||
0. Short-circuits: mark_complete, skip_judge, tool-continue.
|
||||
1. Custom judge (JudgeProtocol) — full authority when set.
|
||||
2. Implicit judge — output-key check + optional conversation-aware
|
||||
quality gate (when ``success_criteria`` is defined).
|
||||
|
||||
Returns a JudgeVerdict. ``feedback=None`` means no real evaluation
|
||||
happened (skip_judge, tool-continue); the caller must not inject a
|
||||
feedback message. Any non-None feedback (including ``""``) means a
|
||||
real evaluation occurred and will be logged into the conversation.
|
||||
"""
|
||||
# --- Level 0: short-circuits (no evaluation) -----------------------
|
||||
|
||||
if mark_complete_flag:
|
||||
return JudgeVerdict(action="ACCEPT")
|
||||
|
||||
if ctx.node_spec.skip_judge:
|
||||
return JudgeVerdict(action="RETRY") # feedback=None → not logged
|
||||
|
||||
# --- Level 1: custom judge -----------------------------------------
|
||||
|
||||
if judge is not None:
|
||||
context = {
|
||||
"assistant_text": assistant_text,
|
||||
"tool_calls": tool_results,
|
||||
"output_accumulator": accumulator.to_dict(),
|
||||
"accumulator": accumulator,
|
||||
"iteration": iteration,
|
||||
"conversation_summary": conversation.export_summary(),
|
||||
"output_keys": ctx.node_spec.output_keys,
|
||||
"missing_keys": get_missing_output_keys_fn(
|
||||
accumulator, ctx.node_spec.output_keys, ctx.node_spec.nullable_output_keys
|
||||
),
|
||||
}
|
||||
verdict = await judge.evaluate(context)
|
||||
# Ensure evaluated RETRY always carries feedback for logging.
|
||||
if verdict.action == "RETRY" and not verdict.feedback:
|
||||
return JudgeVerdict(action="RETRY", feedback="Custom judge returned RETRY.")
|
||||
return verdict
|
||||
|
||||
# --- Level 2: implicit judge ---------------------------------------
|
||||
|
||||
# Real tool calls were made — let the agent keep working.
|
||||
if tool_results:
|
||||
return JudgeVerdict(action="RETRY") # feedback=None → not logged
|
||||
|
||||
missing = get_missing_output_keys_fn(
|
||||
accumulator, ctx.node_spec.output_keys, ctx.node_spec.nullable_output_keys
|
||||
)
|
||||
|
||||
if missing:
|
||||
return JudgeVerdict(
|
||||
action="RETRY",
|
||||
feedback=(
|
||||
f"Task incomplete. Required outputs not yet produced: {missing}. "
|
||||
f"Follow your system prompt instructions to complete the work."
|
||||
),
|
||||
)
|
||||
|
||||
# All output keys present — run safety checks before accepting.
|
||||
|
||||
output_keys = ctx.node_spec.output_keys or []
|
||||
nullable_keys = set(ctx.node_spec.nullable_output_keys or [])
|
||||
|
||||
# All-nullable with nothing set → node produced nothing useful.
|
||||
all_nullable = output_keys and nullable_keys >= set(output_keys)
|
||||
none_set = not any(accumulator.get(k) is not None for k in output_keys)
|
||||
if all_nullable and none_set:
|
||||
return JudgeVerdict(
|
||||
action="RETRY",
|
||||
feedback=(
|
||||
f"No output keys have been set yet. "
|
||||
f"Use set_output to set at least one of: {output_keys}"
|
||||
),
|
||||
)
|
||||
|
||||
# Client-facing with no output keys → continuous interaction node.
|
||||
# Inject tool-use pressure instead of auto-accepting.
|
||||
if not output_keys and ctx.node_spec.client_facing:
|
||||
return JudgeVerdict(
|
||||
action="RETRY",
|
||||
feedback=(
|
||||
"STOP describing what you will do. "
|
||||
"You have FULL access to all tools — file creation, "
|
||||
"shell commands, MCP tools — and you CAN call them "
|
||||
"directly in your response. Respond ONLY with tool "
|
||||
"calls, no prose. Execute the task now."
|
||||
),
|
||||
)
|
||||
|
||||
# Level 2b: conversation-aware quality check (if success_criteria set)
|
||||
if ctx.node_spec.success_criteria and ctx.llm:
|
||||
from framework.graph.conversation_judge import evaluate_phase_completion
|
||||
|
||||
verdict = await evaluate_phase_completion(
|
||||
llm=ctx.llm,
|
||||
conversation=conversation,
|
||||
phase_name=ctx.node_spec.name,
|
||||
phase_description=ctx.node_spec.description,
|
||||
success_criteria=ctx.node_spec.success_criteria,
|
||||
accumulator_state=accumulator.to_dict(),
|
||||
max_context_tokens=max_context_tokens,
|
||||
)
|
||||
if verdict.action != "ACCEPT":
|
||||
return JudgeVerdict(
|
||||
action=verdict.action,
|
||||
feedback=verdict.feedback or "Phase criteria not met.",
|
||||
)
|
||||
|
||||
return JudgeVerdict(action="ACCEPT", feedback="")
|
||||
@@ -0,0 +1,106 @@
|
||||
"""Stall and doom-loop detection for the event loop.
|
||||
|
||||
Pure functions with no class dependencies — safe to call from any context.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
|
||||
def ngram_similarity(s1: str, s2: str, n: int = 2) -> float:
|
||||
"""Jaccard similarity of n-gram sets.
|
||||
|
||||
Returns 0.0-1.0, where 1.0 is exact match.
|
||||
Fast: O(len(s) + len(s2)) using set operations.
|
||||
"""
|
||||
|
||||
def _ngrams(s: str) -> set[str]:
|
||||
return {s[i : i + n] for i in range(len(s) - n + 1) if s.strip()}
|
||||
|
||||
if not s1 or not s2:
|
||||
return 0.0
|
||||
|
||||
ngrams1, ngrams2 = _ngrams(s1.lower()), _ngrams(s2.lower())
|
||||
if not ngrams1 or not ngrams2:
|
||||
return 0.0
|
||||
|
||||
intersection = len(ngrams1 & ngrams2)
|
||||
union = len(ngrams1 | ngrams2)
|
||||
return intersection / union if union else 0.0
|
||||
|
||||
|
||||
def is_stalled(
|
||||
recent_responses: list[str],
|
||||
threshold: int,
|
||||
similarity_threshold: float,
|
||||
) -> bool:
|
||||
"""Detect stall using n-gram similarity.
|
||||
|
||||
Detects when ALL N consecutive responses are mutually similar
|
||||
(>= threshold). A single dissimilar response resets the signal.
|
||||
This catches phrases like "I'm still stuck" vs "I'm stuck"
|
||||
without false-positives on "attempt 1" vs "attempt 2".
|
||||
"""
|
||||
if len(recent_responses) < threshold:
|
||||
return False
|
||||
if not recent_responses[0]:
|
||||
return False
|
||||
|
||||
# Every consecutive pair must be similar
|
||||
for i in range(1, len(recent_responses)):
|
||||
if ngram_similarity(recent_responses[i], recent_responses[i - 1]) < similarity_threshold:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def fingerprint_tool_calls(
|
||||
tool_results: list[dict],
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Create deterministic fingerprints for a turn's tool calls.
|
||||
|
||||
Each fingerprint is (tool_name, canonical_args_json). Order-sensitive
|
||||
so [search("a"), fetch("b")] != [fetch("b"), search("a")].
|
||||
"""
|
||||
fingerprints = []
|
||||
for tr in tool_results:
|
||||
name = tr.get("tool_name", "")
|
||||
args = tr.get("tool_input", {})
|
||||
try:
|
||||
canonical = json.dumps(args, sort_keys=True, default=str)
|
||||
except (TypeError, ValueError):
|
||||
canonical = str(args)
|
||||
fingerprints.append((name, canonical))
|
||||
return fingerprints
|
||||
|
||||
|
||||
def is_tool_doom_loop(
|
||||
recent_tool_fingerprints: list[list[tuple[str, str]]],
|
||||
threshold: int,
|
||||
enabled: bool = True,
|
||||
) -> tuple[bool, str]:
|
||||
"""Detect doom loop via exact fingerprint match.
|
||||
|
||||
Detects when N consecutive turns invoke the same tools with
|
||||
identical (canonicalized) arguments. Different arguments mean
|
||||
different work, so only exact matches count.
|
||||
|
||||
Returns (is_doom_loop, description).
|
||||
"""
|
||||
if not enabled:
|
||||
return False, ""
|
||||
if len(recent_tool_fingerprints) < threshold:
|
||||
return False, ""
|
||||
first = recent_tool_fingerprints[0]
|
||||
if not first:
|
||||
return False, ""
|
||||
|
||||
# All turns in the window must match the first exactly
|
||||
if all(fp == first for fp in recent_tool_fingerprints[1:]):
|
||||
tool_names = [name for name, _ in first]
|
||||
desc = (
|
||||
f"Doom loop detected: {len(recent_tool_fingerprints)} "
|
||||
f"identical consecutive tool calls ({', '.join(tool_names)})"
|
||||
)
|
||||
return True, desc
|
||||
return False, ""
|
||||
@@ -0,0 +1,412 @@
|
||||
"""Subagent execution for the event loop.
|
||||
|
||||
Handles the full subagent lifecycle: validation, context setup, tool filtering,
|
||||
conversation store derivation, execution, and cleanup. Also includes the
|
||||
_EscalationReceiver helper used for subagent → queen escalation routing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from framework.graph.conversation import ConversationStore
|
||||
from framework.graph.event_loop.judge_pipeline import SubagentJudge
|
||||
from framework.graph.event_loop.types import LoopConfig, OutputAccumulator
|
||||
from framework.graph.node import NodeContext, SharedMemory
|
||||
from framework.llm.provider import ToolResult, ToolUse
|
||||
from framework.runtime.event_bus import EventBus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.graph.event_loop_node import EventLoopNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EscalationReceiver:
|
||||
"""Temporary receiver registered in node_registry for subagent escalation routing.
|
||||
|
||||
When a subagent calls ``report_to_parent(wait_for_response=True)``, the callback
|
||||
creates one of these, registers it under a unique escalation ID in the executor's
|
||||
``node_registry``, and awaits ``wait()``. The TUI / runner calls
|
||||
``inject_input(escalation_id, content)`` which the ``ExecutionStream`` routes here
|
||||
via ``inject_event()`` — matching the same ``hasattr(node, "inject_event")`` check
|
||||
used for regular ``EventLoopNode`` instances.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._event = asyncio.Event()
|
||||
self._response: str | None = None
|
||||
self._awaiting_input = True # So inject_worker_message() can prefer us
|
||||
|
||||
async def inject_event(
|
||||
self,
|
||||
content: str,
|
||||
*,
|
||||
is_client_input: bool = False,
|
||||
image_content: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
"""Called by ExecutionStream.inject_input() when the user responds."""
|
||||
self._response = content
|
||||
self._event.set()
|
||||
|
||||
async def wait(self) -> str | None:
|
||||
"""Block until inject_event() delivers the user's response."""
|
||||
await self._event.wait()
|
||||
return self._response
|
||||
|
||||
|
||||
async def execute_subagent(
|
||||
ctx: NodeContext,
|
||||
agent_id: str,
|
||||
task: str,
|
||||
*,
|
||||
config: LoopConfig,
|
||||
event_loop_node_cls: type[EventLoopNode],
|
||||
escalation_receiver_cls: type[EscalationReceiver],
|
||||
accumulator: OutputAccumulator | None = None,
|
||||
event_bus: EventBus | None = None,
|
||||
tool_executor: Callable[[ToolUse], ToolResult | Awaitable[ToolResult]] | None = None,
|
||||
conversation_store: ConversationStore | None = None,
|
||||
subagent_instance_counter: dict[str, int] | None = None,
|
||||
) -> ToolResult:
|
||||
"""Execute a subagent and return the result as a ToolResult.
|
||||
|
||||
The subagent:
|
||||
- Gets a fresh conversation with just the task
|
||||
- Has read-only access to the parent's readable memory
|
||||
- Cannot delegate to its own subagents (prevents recursion)
|
||||
- Returns its output in structured JSON format
|
||||
|
||||
Args:
|
||||
ctx: Parent node's context (for memory, tools, LLM access).
|
||||
agent_id: The node ID of the subagent to invoke.
|
||||
task: The task description to give the subagent.
|
||||
accumulator: Parent's OutputAccumulator.
|
||||
event_bus: EventBus for lifecycle events.
|
||||
config: LoopConfig for iteration/tool limits.
|
||||
tool_executor: Tool executor callable.
|
||||
conversation_store: Parent conversation store (for deriving subagent store).
|
||||
subagent_instance_counter: Mutable counter dict for unique subagent paths.
|
||||
|
||||
Returns:
|
||||
ToolResult with structured JSON output.
|
||||
"""
|
||||
# Log subagent invocation start
|
||||
logger.info(
|
||||
"\n" + "=" * 60 + "\n"
|
||||
"🤖 SUBAGENT INVOCATION\n"
|
||||
"=" * 60 + "\n"
|
||||
"Parent Node: %s\n"
|
||||
"Subagent ID: %s\n"
|
||||
"Task: %s\n" + "=" * 60,
|
||||
ctx.node_id,
|
||||
agent_id,
|
||||
task[:500] + "..." if len(task) > 500 else task,
|
||||
)
|
||||
|
||||
# 1. Validate agent exists in registry
|
||||
if agent_id not in ctx.node_registry:
|
||||
return ToolResult(
|
||||
tool_use_id="",
|
||||
content=json.dumps(
|
||||
{
|
||||
"message": f"Sub-agent '{agent_id}' not found in registry",
|
||||
"data": None,
|
||||
"metadata": {"agent_id": agent_id, "success": False, "error": "not_found"},
|
||||
}
|
||||
),
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
subagent_spec = ctx.node_registry[agent_id]
|
||||
|
||||
# 2. Create read-only memory snapshot
|
||||
parent_data = ctx.memory.read_all()
|
||||
|
||||
# Merge in-flight outputs from the parent's accumulator.
|
||||
if accumulator:
|
||||
for key, value in accumulator.to_dict().items():
|
||||
if key not in parent_data:
|
||||
parent_data[key] = value
|
||||
|
||||
subagent_memory = SharedMemory()
|
||||
for key, value in parent_data.items():
|
||||
subagent_memory.write(key, value, validate=False)
|
||||
|
||||
read_keys = set(parent_data.keys()) | set(subagent_spec.input_keys or [])
|
||||
scoped_memory = subagent_memory.with_permissions(
|
||||
read_keys=list(read_keys),
|
||||
write_keys=[], # Read-only!
|
||||
)
|
||||
|
||||
# 2b. Compute instance counter early so the callback and child context
|
||||
# share the same stable node_id for this subagent invocation.
|
||||
if subagent_instance_counter is not None:
|
||||
subagent_instance_counter.setdefault(agent_id, 0)
|
||||
subagent_instance_counter[agent_id] += 1
|
||||
subagent_instance = str(subagent_instance_counter[agent_id])
|
||||
else:
|
||||
subagent_instance = "1"
|
||||
|
||||
if subagent_instance == "1":
|
||||
sa_node_id = f"{ctx.node_id}:subagent:{agent_id}"
|
||||
else:
|
||||
sa_node_id = f"{ctx.node_id}:subagent:{agent_id}:{subagent_instance}"
|
||||
|
||||
# 2c. Set up report callback (one-way channel to parent / event bus)
|
||||
subagent_reports: list[dict] = []
|
||||
|
||||
async def _report_callback(
|
||||
message: str,
|
||||
data: dict | None = None,
|
||||
*,
|
||||
wait_for_response: bool = False,
|
||||
) -> str | None:
|
||||
subagent_reports.append({"message": message, "data": data, "timestamp": time.time()})
|
||||
if event_bus:
|
||||
await event_bus.emit_subagent_report(
|
||||
stream_id=ctx.node_id,
|
||||
node_id=sa_node_id,
|
||||
subagent_id=agent_id,
|
||||
message=message,
|
||||
data=data,
|
||||
execution_id=ctx.execution_id,
|
||||
)
|
||||
|
||||
if not wait_for_response:
|
||||
return None
|
||||
|
||||
if not event_bus:
|
||||
logger.warning(
|
||||
"Subagent '%s' requested user response but no event_bus available",
|
||||
agent_id,
|
||||
)
|
||||
return None
|
||||
|
||||
# Create isolated receiver and register for input routing
|
||||
import uuid
|
||||
|
||||
escalation_id = f"{ctx.node_id}:escalation:{uuid.uuid4().hex[:8]}"
|
||||
receiver = escalation_receiver_cls()
|
||||
registry = ctx.shared_node_registry
|
||||
|
||||
registry[escalation_id] = receiver
|
||||
try:
|
||||
await event_bus.emit_escalation_requested(
|
||||
stream_id=ctx.stream_id or ctx.node_id,
|
||||
node_id=escalation_id,
|
||||
reason=f"Subagent report (wait_for_response) from {agent_id}",
|
||||
context=message,
|
||||
execution_id=ctx.execution_id,
|
||||
)
|
||||
# Block until queen responds
|
||||
return await receiver.wait()
|
||||
finally:
|
||||
registry.pop(escalation_id, None)
|
||||
|
||||
# 3. Filter tools for subagent
|
||||
subagent_tool_names = set(subagent_spec.tools or [])
|
||||
tool_source = ctx.all_tools if ctx.all_tools else ctx.available_tools
|
||||
|
||||
# GCU auto-population
|
||||
if subagent_spec.node_type == "gcu" and not subagent_tool_names:
|
||||
subagent_tools = [t for t in tool_source if t.name != "delegate_to_sub_agent"]
|
||||
else:
|
||||
subagent_tools = [
|
||||
t
|
||||
for t in tool_source
|
||||
if t.name in subagent_tool_names and t.name != "delegate_to_sub_agent"
|
||||
]
|
||||
|
||||
missing = subagent_tool_names - {t.name for t in subagent_tools}
|
||||
if missing:
|
||||
logger.warning(
|
||||
"Subagent '%s' requested tools not found in catalog: %s",
|
||||
agent_id,
|
||||
sorted(missing),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"📦 Subagent '%s' configuration:\n"
|
||||
" - System prompt: %s\n"
|
||||
" - Tools available (%d): %s\n"
|
||||
" - Memory keys inherited: %s",
|
||||
agent_id,
|
||||
(subagent_spec.system_prompt[:200] + "...")
|
||||
if subagent_spec.system_prompt and len(subagent_spec.system_prompt) > 200
|
||||
else subagent_spec.system_prompt,
|
||||
len(subagent_tools),
|
||||
[t.name for t in subagent_tools],
|
||||
list(parent_data.keys()),
|
||||
)
|
||||
|
||||
# 4. Build subagent context
|
||||
max_iter = min(config.max_iterations, 10)
|
||||
subagent_ctx = NodeContext(
|
||||
runtime=ctx.runtime,
|
||||
node_id=sa_node_id,
|
||||
node_spec=subagent_spec,
|
||||
memory=scoped_memory,
|
||||
input_data={"task": task, **parent_data},
|
||||
llm=ctx.llm,
|
||||
available_tools=subagent_tools,
|
||||
goal_context=(
|
||||
f"Your specific task: {task}\n\n"
|
||||
f"COMPLETION REQUIREMENTS:\n"
|
||||
f"When your task is done, you MUST call set_output() "
|
||||
f"for each required key: {subagent_spec.output_keys}\n"
|
||||
f"Alternatively, call report_to_parent(mark_complete=true) "
|
||||
f"with your findings in message/data.\n"
|
||||
f"You have a maximum of {max_iter} turns to complete this task."
|
||||
),
|
||||
goal=ctx.goal,
|
||||
max_tokens=ctx.max_tokens,
|
||||
runtime_logger=ctx.runtime_logger,
|
||||
is_subagent_mode=True, # Prevents nested delegation
|
||||
report_callback=_report_callback,
|
||||
node_registry={}, # Empty - no nested subagents
|
||||
shared_node_registry=ctx.shared_node_registry, # For escalation routing
|
||||
)
|
||||
|
||||
# 5. Create and execute subagent EventLoopNode
|
||||
subagent_conv_store = None
|
||||
if conversation_store is not None:
|
||||
from framework.storage.conversation_store import FileConversationStore
|
||||
|
||||
parent_base = getattr(conversation_store, "_base", None)
|
||||
if parent_base is not None:
|
||||
conversations_dir = parent_base.parent
|
||||
subagent_dir_name = f"{agent_id}-{subagent_instance}"
|
||||
subagent_store_path = conversations_dir / subagent_dir_name
|
||||
subagent_conv_store = FileConversationStore(base_path=subagent_store_path)
|
||||
|
||||
# Derive a subagent-scoped spillover dir
|
||||
subagent_spillover = None
|
||||
if config.spillover_dir:
|
||||
subagent_spillover = str(Path(config.spillover_dir) / agent_id / subagent_instance)
|
||||
|
||||
subagent_node = event_loop_node_cls(
|
||||
event_bus=event_bus,
|
||||
judge=SubagentJudge(task=task, max_iterations=max_iter),
|
||||
config=LoopConfig(
|
||||
max_iterations=max_iter,
|
||||
max_tool_calls_per_turn=config.max_tool_calls_per_turn,
|
||||
tool_call_overflow_margin=config.tool_call_overflow_margin,
|
||||
max_context_tokens=config.max_context_tokens,
|
||||
stall_detection_threshold=config.stall_detection_threshold,
|
||||
max_tool_result_chars=config.max_tool_result_chars,
|
||||
spillover_dir=subagent_spillover,
|
||||
),
|
||||
tool_executor=tool_executor,
|
||||
conversation_store=subagent_conv_store,
|
||||
)
|
||||
|
||||
# Inject a unique GCU browser profile for this subagent
|
||||
_profile_token = None
|
||||
try:
|
||||
from gcu.browser.session import set_active_profile as _set_gcu_profile
|
||||
|
||||
_profile_token = _set_gcu_profile(f"{agent_id}-{subagent_instance}")
|
||||
except ImportError:
|
||||
pass # GCU tools not installed; no-op
|
||||
|
||||
try:
|
||||
logger.info("🚀 Starting subagent '%s' execution...", agent_id)
|
||||
start_time = time.time()
|
||||
result = await subagent_node.execute(subagent_ctx)
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
separator = "-" * 60
|
||||
logger.info(
|
||||
"\n%s\n"
|
||||
"✅ SUBAGENT '%s' COMPLETED\n"
|
||||
"%s\n"
|
||||
"Success: %s\n"
|
||||
"Latency: %dms\n"
|
||||
"Tokens used: %s\n"
|
||||
"Output keys: %s\n"
|
||||
"%s",
|
||||
separator,
|
||||
agent_id,
|
||||
separator,
|
||||
result.success,
|
||||
latency_ms,
|
||||
result.tokens_used,
|
||||
list(result.output.keys()) if result.output else [],
|
||||
separator,
|
||||
)
|
||||
|
||||
result_json = {
|
||||
"message": (
|
||||
f"Sub-agent '{agent_id}' completed successfully"
|
||||
if result.success
|
||||
else f"Sub-agent '{agent_id}' failed: {result.error}"
|
||||
),
|
||||
"data": result.output,
|
||||
"reports": subagent_reports if subagent_reports else None,
|
||||
"metadata": {
|
||||
"agent_id": agent_id,
|
||||
"success": result.success,
|
||||
"tokens_used": result.tokens_used,
|
||||
"latency_ms": latency_ms,
|
||||
"report_count": len(subagent_reports),
|
||||
},
|
||||
}
|
||||
|
||||
return ToolResult(
|
||||
tool_use_id="",
|
||||
content=json.dumps(result_json, indent=2, default=str),
|
||||
is_error=not result.success,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"\n" + "!" * 60 + "\n❌ SUBAGENT '%s' FAILED\nError: %s\n" + "!" * 60,
|
||||
agent_id,
|
||||
str(e),
|
||||
)
|
||||
result_json = {
|
||||
"message": f"Sub-agent '{agent_id}' raised exception: {e}",
|
||||
"data": None,
|
||||
"metadata": {
|
||||
"agent_id": agent_id,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
},
|
||||
}
|
||||
return ToolResult(
|
||||
tool_use_id="",
|
||||
content=json.dumps(result_json, indent=2),
|
||||
is_error=True,
|
||||
)
|
||||
finally:
|
||||
# Restore the GCU profile context
|
||||
if _profile_token is not None:
|
||||
from gcu.browser.session import _active_profile as _gcu_profile_var
|
||||
|
||||
_gcu_profile_var.reset(_profile_token)
|
||||
|
||||
# Stop the browser session for this subagent's profile
|
||||
if tool_executor is not None:
|
||||
_subagent_profile = f"{agent_id}-{subagent_instance}"
|
||||
try:
|
||||
_stop_use = ToolUse(
|
||||
id="gcu-cleanup",
|
||||
name="browser_stop",
|
||||
input={"profile": _subagent_profile},
|
||||
)
|
||||
_stop_result = tool_executor(_stop_use)
|
||||
if asyncio.iscoroutine(_stop_result) or asyncio.isfuture(_stop_result):
|
||||
await _stop_result
|
||||
except Exception as _gcu_exc:
|
||||
logger.warning(
|
||||
"GCU browser_stop failed for profile %r: %s",
|
||||
_subagent_profile,
|
||||
_gcu_exc,
|
||||
)
|
||||
@@ -0,0 +1,369 @@
|
||||
"""Synthetic tool builders for the event loop.
|
||||
|
||||
Factory functions that create ``Tool`` definitions for framework-level
|
||||
synthetic tools (set_output, ask_user, escalate, delegate, report_to_parent).
|
||||
Also includes the ``handle_set_output`` validation logic.
|
||||
|
||||
All functions are pure — they receive explicit parameters and return
|
||||
``Tool`` or ``ToolResult`` objects with no side effects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from framework.llm.provider import Tool, ToolResult
|
||||
|
||||
|
||||
def build_ask_user_tool() -> Tool:
|
||||
"""Build the synthetic ask_user tool for explicit user-input requests.
|
||||
|
||||
Client-facing nodes call ask_user() when they need to pause and wait
|
||||
for user input. Text-only turns WITHOUT ask_user flow through without
|
||||
blocking, allowing progress updates and summaries to stream freely.
|
||||
"""
|
||||
return Tool(
|
||||
name="ask_user",
|
||||
description=(
|
||||
"You MUST call this tool whenever you need the user's response. "
|
||||
"Always call it after greeting the user, asking a question, or "
|
||||
"requesting approval. Do NOT call it for status updates or "
|
||||
"summaries that don't require a response. "
|
||||
"Always include 2-3 predefined options. The UI automatically "
|
||||
"appends an 'Other' free-text input after your options, so NEVER "
|
||||
"include catch-all options like 'Custom idea', 'Something else', "
|
||||
"'Other', or 'None of the above' — the UI handles that. "
|
||||
"When the question primarily needs a typed answer but you must "
|
||||
"include options, make one option signal that typing is expected "
|
||||
"(e.g. 'I\\'ll type my response'). This helps users discover the "
|
||||
"free-text input. "
|
||||
"The ONLY exception: omit options when the question demands a "
|
||||
"free-form answer the user must type out (e.g. 'Describe your "
|
||||
"agent idea', 'Paste the error message'). "
|
||||
'{"question": "What would you like to do?", "options": '
|
||||
'["Build a new agent", "Modify existing agent", "Run tests"]} '
|
||||
"Free-form example: "
|
||||
'{"question": "Describe the agent you want to build."}'
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "The question or prompt shown to the user.",
|
||||
},
|
||||
"options": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"2-3 specific predefined choices. Include in most cases. "
|
||||
'Example: ["Option A", "Option B", "Option C"]. '
|
||||
"The UI always appends an 'Other' free-text input, so "
|
||||
"do NOT include catch-alls like 'Custom idea' or 'Other'. "
|
||||
"Omit ONLY when the user must type a free-form answer."
|
||||
),
|
||||
"minItems": 2,
|
||||
"maxItems": 3,
|
||||
},
|
||||
},
|
||||
"required": ["question"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def build_ask_user_multiple_tool() -> Tool:
|
||||
"""Build the synthetic ask_user_multiple tool for batched questions.
|
||||
|
||||
Queen-only tool that presents multiple questions at once so the user
|
||||
can answer them all in a single interaction rather than one at a time.
|
||||
"""
|
||||
return Tool(
|
||||
name="ask_user_multiple",
|
||||
description=(
|
||||
"Ask the user multiple questions at once. Use this instead of "
|
||||
"ask_user when you have 2 or more questions to ask in the same "
|
||||
"turn — it lets the user answer everything in one go rather than "
|
||||
"going back and forth. Each question can have its own predefined "
|
||||
"options (2-3 choices) or be free-form. The UI renders all "
|
||||
"questions together with a single Submit button. "
|
||||
"ALWAYS prefer this over ask_user when you have multiple things "
|
||||
"to clarify. "
|
||||
"IMPORTANT: Do NOT repeat the questions in your text response — "
|
||||
"the widget renders them. Keep your text to a brief intro only. "
|
||||
'{"questions": ['
|
||||
' {"id": "scope", "prompt": "What scope?", "options": ["Full", "Partial"]},'
|
||||
' {"id": "format", "prompt": "Output format?", "options": ["PDF", "CSV", "JSON"]},'
|
||||
' {"id": "details", "prompt": "Any special requirements?"}'
|
||||
"]}"
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"questions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Short identifier for this question (used in the response)."
|
||||
),
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "The question text shown to the user.",
|
||||
},
|
||||
"options": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"2-3 predefined choices. The UI appends an "
|
||||
"'Other' free-text input automatically. "
|
||||
"Omit only when the user must type a free-form answer."
|
||||
),
|
||||
"minItems": 2,
|
||||
"maxItems": 3,
|
||||
},
|
||||
},
|
||||
"required": ["id", "prompt"],
|
||||
},
|
||||
"minItems": 2,
|
||||
"maxItems": 8,
|
||||
"description": "List of questions to present to the user.",
|
||||
},
|
||||
},
|
||||
"required": ["questions"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def build_set_output_tool(output_keys: list[str] | None) -> Tool | None:
|
||||
"""Build the synthetic set_output tool for explicit output declaration."""
|
||||
if not output_keys:
|
||||
return None
|
||||
return Tool(
|
||||
name="set_output",
|
||||
description=(
|
||||
"Set an output value for this node. Call once per output key. "
|
||||
"Use this for brief notes, counts, status, and file references — "
|
||||
"NOT for large data payloads. When a tool result was saved to a "
|
||||
"data file, pass the filename as the value "
|
||||
"(e.g. 'google_sheets_get_values_1.txt') so the next phase can "
|
||||
"load the full data. Values exceeding ~2000 characters are "
|
||||
"auto-saved to data files. "
|
||||
f"Valid keys: {output_keys}"
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": f"Output key. Must be one of: {output_keys}",
|
||||
"enum": output_keys,
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The output value — a brief note, count, status, "
|
||||
"or data filename reference."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["key", "value"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def build_escalate_tool() -> Tool:
|
||||
"""Build the synthetic escalate tool for worker -> queen handoff."""
|
||||
return Tool(
|
||||
name="escalate",
|
||||
description=(
|
||||
"Escalate to the queen when requesting user input, "
|
||||
"blocked by errors, missing "
|
||||
"credentials, or ambiguous constraints that require supervisor "
|
||||
"guidance. Include a concise reason and optional context. "
|
||||
"The node will pause until the queen injects guidance."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Short reason for escalation (e.g. 'Tool repeatedly failing')."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": "Optional diagnostic details for the queen.",
|
||||
},
|
||||
},
|
||||
"required": ["reason"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def build_delegate_tool(sub_agents: list[str], node_registry: dict[str, Any]) -> Tool | None:
|
||||
"""Build the synthetic delegate_to_sub_agent tool for subagent invocation.
|
||||
|
||||
Args:
|
||||
sub_agents: List of node IDs that can be invoked as subagents.
|
||||
node_registry: Map of node_id -> NodeSpec for looking up subagent descriptions.
|
||||
|
||||
Returns:
|
||||
Tool definition if sub_agents is non-empty, None otherwise.
|
||||
"""
|
||||
if not sub_agents:
|
||||
return None
|
||||
|
||||
agent_descriptions = []
|
||||
for agent_id in sub_agents:
|
||||
spec = node_registry.get(agent_id)
|
||||
if spec:
|
||||
desc = getattr(spec, "description", "(no description)")
|
||||
agent_descriptions.append(f"- {agent_id}: {desc}")
|
||||
else:
|
||||
agent_descriptions.append(f"- {agent_id}: (not found in registry)")
|
||||
|
||||
return Tool(
|
||||
name="delegate_to_sub_agent",
|
||||
description=(
|
||||
"Delegate a task to a specialized sub-agent. The sub-agent runs "
|
||||
"autonomously with read-only access to current memory and returns "
|
||||
"its result. Use this to parallelize work or leverage specialized capabilities.\n\n"
|
||||
"Available sub-agents:\n" + "\n".join(agent_descriptions)
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": f"The sub-agent to invoke. Must be one of: {sub_agents}",
|
||||
"enum": sub_agents,
|
||||
},
|
||||
"task": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The task description for the sub-agent to execute. "
|
||||
"Be specific about what you want the sub-agent to do and "
|
||||
"what information to return."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_id", "task"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def build_report_to_parent_tool() -> Tool:
|
||||
"""Build the synthetic report_to_parent tool for sub-agent progress reports.
|
||||
|
||||
Sub-agents call this to send one-way progress updates, partial findings,
|
||||
or status reports to the parent node (and external observers via event bus)
|
||||
without blocking execution.
|
||||
|
||||
When ``wait_for_response`` is True, the sub-agent blocks until the parent
|
||||
relays the user's response — used for escalation (e.g. login pages, CAPTCHAs).
|
||||
|
||||
When ``mark_complete`` is True, the sub-agent terminates immediately after
|
||||
sending the report — no need to call set_output for each output key.
|
||||
"""
|
||||
return Tool(
|
||||
name="report_to_parent",
|
||||
description=(
|
||||
"Send a report to the parent agent. By default this is fire-and-forget: "
|
||||
"the parent receives the report but does not respond. "
|
||||
"Set wait_for_response=true to BLOCK until the user replies — use this "
|
||||
"when you need human intervention (e.g. login pages, CAPTCHAs, "
|
||||
"authentication walls). The user's response is returned as the tool result. "
|
||||
"Set mark_complete=true to finish your task and terminate immediately "
|
||||
"after sending the report — use this when your findings are in the "
|
||||
"message/data fields and you don't need to call set_output."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "A human-readable status or progress message.",
|
||||
},
|
||||
"data": {
|
||||
"type": "object",
|
||||
"description": "Optional structured data to include with the report.",
|
||||
},
|
||||
"wait_for_response": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"If true, block execution until the user responds. "
|
||||
"Use for escalation scenarios requiring human intervention."
|
||||
),
|
||||
"default": False,
|
||||
},
|
||||
"mark_complete": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"If true, terminate the sub-agent immediately after sending "
|
||||
"this report. The report message and data are delivered to the "
|
||||
"parent as the final result. No set_output calls are needed."
|
||||
),
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
"required": ["message"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def handle_set_output(
|
||||
tool_input: dict[str, Any],
|
||||
output_keys: list[str] | None,
|
||||
) -> ToolResult:
|
||||
"""Handle set_output tool call. Returns ToolResult (sync)."""
|
||||
import logging
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
key = tool_input.get("key", "")
|
||||
value = tool_input.get("value", "")
|
||||
valid_keys = output_keys or []
|
||||
|
||||
# Recover from truncated JSON (max_tokens hit mid-argument).
|
||||
# The _raw key is set by litellm when json.loads fails.
|
||||
if not key and "_raw" in tool_input:
|
||||
raw = tool_input["_raw"]
|
||||
key_match = re.search(r'"key"\s*:\s*"(\w+)"', raw)
|
||||
if key_match:
|
||||
key = key_match.group(1)
|
||||
val_match = re.search(r'"value"\s*:\s*"', raw)
|
||||
if val_match:
|
||||
start = val_match.end()
|
||||
value = raw[start:].rstrip()
|
||||
for suffix in ('"}\n', '"}', '"'):
|
||||
if value.endswith(suffix):
|
||||
value = value[: -len(suffix)]
|
||||
break
|
||||
if key:
|
||||
logger.warning(
|
||||
"Recovered set_output args from truncated JSON: key=%s, value_len=%d",
|
||||
key,
|
||||
len(value),
|
||||
)
|
||||
# Re-inject so the caller sees proper key/value
|
||||
tool_input["key"] = key
|
||||
tool_input["value"] = value
|
||||
|
||||
if key not in valid_keys:
|
||||
return ToolResult(
|
||||
tool_use_id="",
|
||||
content=f"Invalid output key '{key}'. Valid keys: {valid_keys}",
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
return ToolResult(
|
||||
tool_use_id="",
|
||||
content=f"Output '{key}' set successfully.",
|
||||
is_error=False,
|
||||
)
|
||||
@@ -0,0 +1,542 @@
|
||||
"""Tool result handling: truncation, spillover, JSON preview, and execution.
|
||||
|
||||
Manages tool result size limits, file spillover for large results, and
|
||||
smart JSON previews. Also includes transient error classification and
|
||||
the context-window-exceeded error detector.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.llm.provider import ToolResult, ToolUse
|
||||
from framework.llm.stream_events import ToolCallEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pattern for detecting context-window-exceeded errors across LLM providers.
|
||||
_CONTEXT_TOO_LARGE_RE = re.compile(
|
||||
r"context.{0,20}(length|window|limit|size)|"
|
||||
r"too.{0,10}(long|large|many.{0,10}tokens)|"
|
||||
r"(exceed|exceeds|exceeded).{0,30}(limit|window|context|tokens)|"
|
||||
r"maximum.{0,20}token|prompt.{0,20}too.{0,10}long",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def is_context_too_large_error(exc: BaseException) -> bool:
|
||||
"""Detect whether an exception indicates the LLM input was too large."""
|
||||
cls = type(exc).__name__
|
||||
if "ContextWindow" in cls:
|
||||
return True
|
||||
return bool(_CONTEXT_TOO_LARGE_RE.search(str(exc)))
|
||||
|
||||
|
||||
def is_transient_error(exc: BaseException) -> bool:
|
||||
"""Classify whether an exception is transient (retryable) vs permanent.
|
||||
|
||||
Transient: network errors, rate limits, server errors, timeouts.
|
||||
Permanent: auth errors, bad requests, context window exceeded.
|
||||
"""
|
||||
try:
|
||||
from litellm.exceptions import (
|
||||
APIConnectionError,
|
||||
BadGatewayError,
|
||||
InternalServerError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
)
|
||||
|
||||
transient_types: tuple[type[BaseException], ...] = (
|
||||
RateLimitError,
|
||||
APIConnectionError,
|
||||
InternalServerError,
|
||||
BadGatewayError,
|
||||
ServiceUnavailableError,
|
||||
TimeoutError,
|
||||
ConnectionError,
|
||||
OSError,
|
||||
)
|
||||
except ImportError:
|
||||
transient_types = (TimeoutError, ConnectionError, OSError)
|
||||
|
||||
if isinstance(exc, transient_types):
|
||||
return True
|
||||
|
||||
# RuntimeError from StreamErrorEvent with "Stream error:" prefix
|
||||
if isinstance(exc, RuntimeError):
|
||||
error_str = str(exc).lower()
|
||||
transient_keywords = [
|
||||
"rate limit",
|
||||
"429",
|
||||
"timeout",
|
||||
"connection",
|
||||
"internal server",
|
||||
"502",
|
||||
"503",
|
||||
"504",
|
||||
"service unavailable",
|
||||
"bad gateway",
|
||||
"overloaded",
|
||||
"failed to parse tool call",
|
||||
]
|
||||
return any(kw in error_str for kw in transient_keywords)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def extract_json_metadata(parsed: Any, *, _depth: int = 0, _max_depth: int = 3) -> str:
|
||||
"""Return a concise structural summary of parsed JSON.
|
||||
|
||||
Reports key names, value types, and — crucially — array lengths so
|
||||
the LLM knows how much data exists beyond the preview.
|
||||
|
||||
Returns an empty string for simple scalars.
|
||||
"""
|
||||
if _depth >= _max_depth:
|
||||
if isinstance(parsed, dict):
|
||||
return f"dict with {len(parsed)} keys"
|
||||
if isinstance(parsed, list):
|
||||
return f"list of {len(parsed)} items"
|
||||
return type(parsed).__name__
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
if not parsed:
|
||||
return "empty dict"
|
||||
lines: list[str] = []
|
||||
indent = " " * (_depth + 1)
|
||||
for key, value in list(parsed.items())[:20]:
|
||||
if isinstance(value, list):
|
||||
line = f'{indent}"{key}": list of {len(value)} items'
|
||||
if value:
|
||||
first = value[0]
|
||||
if isinstance(first, dict):
|
||||
sample_keys = list(first.keys())[:10]
|
||||
line += f" (each item: dict with keys {sample_keys})"
|
||||
elif isinstance(first, list):
|
||||
line += f" (each item: list of {len(first)} elements)"
|
||||
lines.append(line)
|
||||
elif isinstance(value, dict):
|
||||
child = extract_json_metadata(value, _depth=_depth + 1, _max_depth=_max_depth)
|
||||
lines.append(f'{indent}"{key}": {child}')
|
||||
else:
|
||||
lines.append(f'{indent}"{key}": {type(value).__name__}')
|
||||
if len(parsed) > 20:
|
||||
lines.append(f"{indent}... and {len(parsed) - 20} more keys")
|
||||
return "\n".join(lines)
|
||||
|
||||
if isinstance(parsed, list):
|
||||
if not parsed:
|
||||
return "empty list"
|
||||
desc = f"list of {len(parsed)} items"
|
||||
first = parsed[0]
|
||||
if isinstance(first, dict):
|
||||
sample_keys = list(first.keys())[:10]
|
||||
desc += f" (each item: dict with keys {sample_keys})"
|
||||
elif isinstance(first, list):
|
||||
desc += f" (each item: list of {len(first)} elements)"
|
||||
return desc
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def build_json_preview(parsed: Any, *, max_chars: int = 5000) -> str | None:
|
||||
"""Build a smart preview of parsed JSON, truncating large arrays.
|
||||
|
||||
Shows first 3 + last 1 items of large arrays with explicit count
|
||||
markers so the LLM cannot mistake the preview for the full dataset.
|
||||
|
||||
Returns ``None`` if no truncation was needed (no large arrays).
|
||||
"""
|
||||
_LARGE_ARRAY_THRESHOLD = 10
|
||||
|
||||
def _truncate_arrays(obj: Any) -> tuple[Any, bool]:
|
||||
"""Return (truncated_copy, was_truncated)."""
|
||||
if isinstance(obj, list) and len(obj) > _LARGE_ARRAY_THRESHOLD:
|
||||
n = len(obj)
|
||||
head = obj[:3]
|
||||
tail = obj[-1:]
|
||||
marker = f"... ({n - 4} more items omitted, {n} total) ..."
|
||||
return head + [marker] + tail, True
|
||||
if isinstance(obj, dict):
|
||||
changed = False
|
||||
out: dict[str, Any] = {}
|
||||
for k, v in obj.items():
|
||||
new_v, did = _truncate_arrays(v)
|
||||
out[k] = new_v
|
||||
changed = changed or did
|
||||
return (out, True) if changed else (obj, False)
|
||||
return obj, False
|
||||
|
||||
preview_obj, was_truncated = _truncate_arrays(parsed)
|
||||
if not was_truncated:
|
||||
return None # No large arrays — caller should use raw slicing
|
||||
|
||||
try:
|
||||
result = json.dumps(preview_obj, indent=2, ensure_ascii=False)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
if len(result) > max_chars:
|
||||
# Even 3+1 items too big — try just 1 item
|
||||
def _minimal_arrays(obj: Any) -> Any:
|
||||
if isinstance(obj, list) and len(obj) > _LARGE_ARRAY_THRESHOLD:
|
||||
n = len(obj)
|
||||
return obj[:1] + [f"... ({n - 1} more items omitted, {n} total) ..."]
|
||||
if isinstance(obj, dict):
|
||||
return {k: _minimal_arrays(v) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
preview_obj = _minimal_arrays(parsed)
|
||||
try:
|
||||
result = json.dumps(preview_obj, indent=2, ensure_ascii=False)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
if len(result) > max_chars:
|
||||
result = result[:max_chars] + "…"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def truncate_tool_result(
|
||||
result: ToolResult,
|
||||
tool_name: str,
|
||||
*,
|
||||
max_tool_result_chars: int,
|
||||
spillover_dir: str | None,
|
||||
next_spill_filename_fn: Any, # Callable[[str], str]
|
||||
) -> ToolResult:
|
||||
"""Persist tool result to file and optionally truncate for context.
|
||||
|
||||
When *spillover_dir* is configured, EVERY non-error tool result is
|
||||
saved to a file (short filename like ``web_search_1.txt``). A
|
||||
``[Saved to '...']`` annotation is appended so the reference
|
||||
survives pruning and compaction.
|
||||
|
||||
- Small results (≤ limit): full content kept + file annotation
|
||||
- Large results (> limit): preview + file reference
|
||||
- Errors: pass through unchanged
|
||||
- load_data results: truncate with pagination hint (no re-spill)
|
||||
"""
|
||||
limit = max_tool_result_chars
|
||||
|
||||
# Errors always pass through unchanged
|
||||
if result.is_error:
|
||||
return result
|
||||
|
||||
# load_data reads FROM spilled files — never re-spill (circular).
|
||||
# Just truncate with a pagination hint if the result is too large.
|
||||
if tool_name == "load_data":
|
||||
if limit <= 0 or len(result.content) <= limit:
|
||||
return result # Small load_data result — pass through as-is
|
||||
# Large load_data result — truncate with smart preview
|
||||
PREVIEW_CAP = min(5000, max(limit - 500, limit // 2))
|
||||
|
||||
metadata_str = ""
|
||||
smart_preview: str | None = None
|
||||
try:
|
||||
parsed_ld = json.loads(result.content)
|
||||
metadata_str = extract_json_metadata(parsed_ld)
|
||||
smart_preview = build_json_preview(parsed_ld, max_chars=PREVIEW_CAP)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
if smart_preview is not None:
|
||||
preview_block = smart_preview
|
||||
else:
|
||||
preview_block = result.content[:PREVIEW_CAP] + "…"
|
||||
|
||||
header = (
|
||||
f"[{tool_name} result: {len(result.content):,} chars — "
|
||||
f"too large for context. Use offset_bytes/limit_bytes "
|
||||
f"parameters to read smaller chunks.]"
|
||||
)
|
||||
if metadata_str:
|
||||
header += f"\n\nData structure:\n{metadata_str}"
|
||||
header += (
|
||||
"\n\nWARNING: This is an INCOMPLETE preview. Do NOT draw conclusions or counts from it."
|
||||
)
|
||||
|
||||
truncated = f"{header}\n\nPreview (small sample only):\n{preview_block}"
|
||||
logger.info(
|
||||
"%s result truncated: %d → %d chars (use offset/limit to paginate)",
|
||||
tool_name,
|
||||
len(result.content),
|
||||
len(truncated),
|
||||
)
|
||||
return ToolResult(
|
||||
tool_use_id=result.tool_use_id,
|
||||
content=truncated,
|
||||
is_error=False,
|
||||
image_content=result.image_content,
|
||||
is_skill_content=result.is_skill_content,
|
||||
)
|
||||
|
||||
spill_dir = spillover_dir
|
||||
if spill_dir:
|
||||
spill_path = Path(spill_dir)
|
||||
spill_path.mkdir(parents=True, exist_ok=True)
|
||||
filename = next_spill_filename_fn(tool_name)
|
||||
|
||||
# Pretty-print JSON content so load_data's line-based
|
||||
# pagination works correctly.
|
||||
write_content = result.content
|
||||
parsed_json: Any = None # track for metadata extraction
|
||||
try:
|
||||
parsed_json = json.loads(result.content)
|
||||
write_content = json.dumps(parsed_json, indent=2, ensure_ascii=False)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass # Not JSON — write as-is
|
||||
|
||||
(spill_path / filename).write_text(write_content, encoding="utf-8")
|
||||
|
||||
if limit > 0 and len(result.content) > limit:
|
||||
# Large result: build a small, metadata-rich preview so the
|
||||
# LLM cannot mistake it for the complete dataset.
|
||||
PREVIEW_CAP = 5000
|
||||
|
||||
# Extract structural metadata (array lengths, key names)
|
||||
metadata_str = ""
|
||||
smart_preview: str | None = None
|
||||
if parsed_json is not None:
|
||||
metadata_str = extract_json_metadata(parsed_json)
|
||||
smart_preview = build_json_preview(parsed_json, max_chars=PREVIEW_CAP)
|
||||
|
||||
if smart_preview is not None:
|
||||
preview_block = smart_preview
|
||||
else:
|
||||
preview_block = result.content[:PREVIEW_CAP] + "…"
|
||||
|
||||
# Assemble header with structural info + warning
|
||||
header = (
|
||||
f"[Result from {tool_name}: {len(result.content):,} chars — "
|
||||
f"too large for context, saved to '{filename}'.]\n"
|
||||
)
|
||||
if metadata_str:
|
||||
header += f"\nData structure:\n{metadata_str}"
|
||||
header += (
|
||||
f"\n\nWARNING: The preview below is INCOMPLETE. "
|
||||
f"Do NOT draw conclusions or counts from it. "
|
||||
f"Use load_data(filename='{filename}') to read the "
|
||||
f"full data before analysis."
|
||||
)
|
||||
|
||||
content = f"{header}\n\nPreview (small sample only):\n{preview_block}"
|
||||
logger.info(
|
||||
"Tool result spilled to file: %s (%d chars → %s)",
|
||||
tool_name,
|
||||
len(result.content),
|
||||
filename,
|
||||
)
|
||||
else:
|
||||
# Small result: keep full content + annotation
|
||||
content = f"{result.content}\n\n[Saved to '{filename}']"
|
||||
logger.info(
|
||||
"Tool result saved to file: %s (%d chars → %s)",
|
||||
tool_name,
|
||||
len(result.content),
|
||||
filename,
|
||||
)
|
||||
|
||||
return ToolResult(
|
||||
tool_use_id=result.tool_use_id,
|
||||
content=content,
|
||||
is_error=False,
|
||||
image_content=result.image_content,
|
||||
is_skill_content=result.is_skill_content,
|
||||
)
|
||||
|
||||
# No spillover_dir — truncate in-place if needed
|
||||
if limit > 0 and len(result.content) > limit:
|
||||
PREVIEW_CAP = min(5000, max(limit - 500, limit // 2))
|
||||
|
||||
metadata_str = ""
|
||||
smart_preview: str | None = None
|
||||
try:
|
||||
parsed_inline = json.loads(result.content)
|
||||
metadata_str = extract_json_metadata(parsed_inline)
|
||||
smart_preview = build_json_preview(parsed_inline, max_chars=PREVIEW_CAP)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
if smart_preview is not None:
|
||||
preview_block = smart_preview
|
||||
else:
|
||||
preview_block = result.content[:PREVIEW_CAP] + "…"
|
||||
|
||||
header = (
|
||||
f"[Result from {tool_name}: {len(result.content):,} chars — "
|
||||
f"truncated to fit context budget.]"
|
||||
)
|
||||
if metadata_str:
|
||||
header += f"\n\nData structure:\n{metadata_str}"
|
||||
header += (
|
||||
"\n\nWARNING: This is an INCOMPLETE preview. "
|
||||
"Do NOT draw conclusions or counts from the preview alone."
|
||||
)
|
||||
|
||||
truncated = f"{header}\n\n{preview_block}"
|
||||
logger.info(
|
||||
"Tool result truncated in-place: %s (%d → %d chars)",
|
||||
tool_name,
|
||||
len(result.content),
|
||||
len(truncated),
|
||||
)
|
||||
return ToolResult(
|
||||
tool_use_id=result.tool_use_id,
|
||||
content=truncated,
|
||||
is_error=False,
|
||||
image_content=result.image_content,
|
||||
is_skill_content=result.is_skill_content,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def execute_tool(
|
||||
tool_executor: Any, # Callable[[ToolUse], ToolResult | Awaitable[ToolResult]] | None
|
||||
tc: ToolCallEvent,
|
||||
timeout: float,
|
||||
skill_dirs: list[str] | None = None,
|
||||
) -> ToolResult:
|
||||
"""Execute a tool call, handling both sync and async executors.
|
||||
|
||||
Applies ``tool_call_timeout_seconds`` to prevent hung MCP servers
|
||||
from blocking the event loop indefinitely. The initial executor
|
||||
call is offloaded to a thread pool so that sync executors don't
|
||||
freeze the event loop.
|
||||
"""
|
||||
if tool_executor is None:
|
||||
return ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=f"No tool executor configured for '{tc.tool_name}'",
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
skill_dirs = skill_dirs or []
|
||||
skill_read_tools = {"view_file", "load_data", "read_file"}
|
||||
if tc.tool_name in skill_read_tools and skill_dirs:
|
||||
raw_path = tc.tool_input.get("path", "")
|
||||
if raw_path:
|
||||
resolved = Path(raw_path).resolve(strict=False)
|
||||
resolved_roots = [Path(skill_dir).resolve(strict=False) for skill_dir in skill_dirs]
|
||||
if any(resolved.is_relative_to(root) for root in resolved_roots):
|
||||
try:
|
||||
content = resolved.read_text(encoding="utf-8")
|
||||
except Exception as exc:
|
||||
return ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=f"Could not read skill resource '{raw_path}': {exc}",
|
||||
is_error=True,
|
||||
)
|
||||
return ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=content,
|
||||
is_skill_content=resolved.name == "SKILL.md",
|
||||
)
|
||||
|
||||
tool_use = ToolUse(id=tc.tool_use_id, name=tc.tool_name, input=tc.tool_input)
|
||||
|
||||
async def _run() -> ToolResult:
|
||||
# Offload the executor call to a thread. Sync MCP executors
|
||||
# block on future.result() — running in a thread keeps the
|
||||
# event loop free so asyncio.wait_for can fire the timeout.
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(None, tool_executor, tool_use)
|
||||
# Async executors return a coroutine — await it on the loop
|
||||
if asyncio.iscoroutine(result) or asyncio.isfuture(result):
|
||||
result = await result
|
||||
return result
|
||||
|
||||
try:
|
||||
if timeout > 0:
|
||||
result = await asyncio.wait_for(_run(), timeout=timeout)
|
||||
else:
|
||||
result = await _run()
|
||||
except TimeoutError:
|
||||
logger.warning("Tool '%s' timed out after %.0fs", tc.tool_name, timeout)
|
||||
return ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=(
|
||||
f"Tool '{tc.tool_name}' timed out after {timeout:.0f}s. "
|
||||
"The operation took too long and was cancelled. "
|
||||
"Try a simpler request or a different approach."
|
||||
),
|
||||
is_error=True,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def record_learning(key: str, value: Any, spillover_dir: str | None) -> None:
|
||||
"""Append a set_output value to adapt.md as a learning entry.
|
||||
|
||||
Called at set_output time — the moment knowledge is produced — so that
|
||||
adapt.md accumulates the agent's outputs across the session. Since
|
||||
adapt.md is injected into the system prompt, these persist through
|
||||
any compaction.
|
||||
"""
|
||||
if not spillover_dir:
|
||||
return
|
||||
try:
|
||||
adapt_path = Path(spillover_dir) / "adapt.md"
|
||||
adapt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
content = adapt_path.read_text(encoding="utf-8") if adapt_path.exists() else ""
|
||||
|
||||
if "## Outputs" not in content:
|
||||
content += "\n\n## Outputs\n"
|
||||
|
||||
# Truncate long values for memory (full value is in shared memory)
|
||||
v_str = str(value)
|
||||
if len(v_str) > 500:
|
||||
v_str = v_str[:500] + "…"
|
||||
|
||||
entry = f"- {key}: {v_str}\n"
|
||||
|
||||
# Replace existing entry for same key (update, not duplicate)
|
||||
lines = content.splitlines(keepends=True)
|
||||
replaced = False
|
||||
for i, line in enumerate(lines):
|
||||
if line.startswith(f"- {key}:"):
|
||||
lines[i] = entry
|
||||
replaced = True
|
||||
break
|
||||
if replaced:
|
||||
content = "".join(lines)
|
||||
else:
|
||||
content += entry
|
||||
|
||||
adapt_path.write_text(content, encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to record learning for key=%s: %s", key, e)
|
||||
|
||||
|
||||
def next_spill_filename(tool_name: str, counter: int) -> str:
|
||||
"""Return a short, monotonic filename for a tool result spill."""
|
||||
# Shorten common tool name prefixes to save tokens
|
||||
short = tool_name.removeprefix("tool_").removeprefix("mcp_")
|
||||
return f"{short}_{counter}.txt"
|
||||
|
||||
|
||||
def restore_spill_counter(spillover_dir: str | None) -> int:
|
||||
"""Scan spillover_dir for existing spill files and return the max counter.
|
||||
|
||||
Returns the highest spill number found (or 0 if none).
|
||||
"""
|
||||
if not spillover_dir:
|
||||
return 0
|
||||
spill_path = Path(spillover_dir)
|
||||
if not spill_path.is_dir():
|
||||
return 0
|
||||
max_n = 0
|
||||
for f in spill_path.iterdir():
|
||||
if not f.is_file():
|
||||
continue
|
||||
m = re.search(r"_(\d+)\.txt$", f.name)
|
||||
if m:
|
||||
max_n = max(max_n, int(m.group(1)))
|
||||
return max_n
|
||||
@@ -0,0 +1,190 @@
|
||||
"""Shared types and state containers for the event loop package."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from framework.graph.conversation import ConversationStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TriggerEvent:
|
||||
"""A framework-level trigger signal (timer tick or webhook hit)."""
|
||||
|
||||
trigger_type: str
|
||||
source_id: str
|
||||
payload: dict[str, Any] = field(default_factory=dict)
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class JudgeVerdict:
|
||||
"""Result of judge evaluation for the event loop."""
|
||||
|
||||
action: Literal["ACCEPT", "RETRY", "ESCALATE"]
|
||||
# None = no evaluation happened (skip_judge, tool-continue); not logged.
|
||||
# "" = evaluated but no feedback; logged with default text.
|
||||
# "..." = evaluated with feedback; logged as-is.
|
||||
feedback: str | None = None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class JudgeProtocol(Protocol):
|
||||
"""Protocol for event-loop judges."""
|
||||
|
||||
async def evaluate(self, context: dict[str, Any]) -> JudgeVerdict: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoopConfig:
|
||||
"""Configuration for the event loop."""
|
||||
|
||||
max_iterations: int = 50
|
||||
max_tool_calls_per_turn: int = 30
|
||||
judge_every_n_turns: int = 1
|
||||
stall_detection_threshold: int = 3
|
||||
stall_similarity_threshold: float = 0.85
|
||||
max_context_tokens: int = 32_000
|
||||
store_prefix: str = ""
|
||||
|
||||
# Overflow margin for max_tool_calls_per_turn. Tool calls are only
|
||||
# discarded when the count exceeds max_tool_calls_per_turn * (1 + margin).
|
||||
tool_call_overflow_margin: float = 0.5
|
||||
|
||||
# Tool result context management.
|
||||
max_tool_result_chars: int = 30_000
|
||||
spillover_dir: str | None = None
|
||||
|
||||
# set_output value spilling.
|
||||
max_output_value_chars: int = 2_000
|
||||
|
||||
# Stream retry.
|
||||
max_stream_retries: int = 3
|
||||
stream_retry_backoff_base: float = 2.0
|
||||
stream_retry_max_delay: float = 60.0
|
||||
|
||||
# Tool doom loop detection.
|
||||
tool_doom_loop_threshold: int = 3
|
||||
|
||||
# Client-facing auto-block grace period.
|
||||
cf_grace_turns: int = 1
|
||||
tool_doom_loop_enabled: bool = True
|
||||
|
||||
# Per-tool-call timeout.
|
||||
tool_call_timeout_seconds: float = 60.0
|
||||
|
||||
# Subagent delegation timeout.
|
||||
subagent_timeout_seconds: float = 600.0
|
||||
|
||||
# Lifecycle hooks.
|
||||
hooks: dict[str, list] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.hooks is None:
|
||||
object.__setattr__(self, "hooks", {})
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookContext:
|
||||
"""Context passed to every lifecycle hook."""
|
||||
|
||||
event: str
|
||||
trigger: str | None
|
||||
system_prompt: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookResult:
|
||||
"""What a hook may return to modify node state."""
|
||||
|
||||
system_prompt: str | None = None
|
||||
inject: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputAccumulator:
|
||||
"""Accumulates output key-value pairs with optional write-through persistence."""
|
||||
|
||||
values: dict[str, Any] = field(default_factory=dict)
|
||||
store: ConversationStore | None = None
|
||||
spillover_dir: str | None = None
|
||||
max_value_chars: int = 0
|
||||
|
||||
async def set(self, key: str, value: Any) -> None:
|
||||
"""Set a key-value pair, auto-spilling large values to files."""
|
||||
value = self._auto_spill(key, value)
|
||||
self.values[key] = value
|
||||
if self.store:
|
||||
cursor = await self.store.read_cursor() or {}
|
||||
outputs = cursor.get("outputs", {})
|
||||
outputs[key] = value
|
||||
cursor["outputs"] = outputs
|
||||
await self.store.write_cursor(cursor)
|
||||
|
||||
def _auto_spill(self, key: str, value: Any) -> Any:
|
||||
"""Save large values to a file and return a reference string."""
|
||||
if self.max_value_chars <= 0 or not self.spillover_dir:
|
||||
return value
|
||||
|
||||
val_str = json.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value
|
||||
if len(val_str) <= self.max_value_chars:
|
||||
return value
|
||||
|
||||
spill_path = Path(self.spillover_dir)
|
||||
spill_path.mkdir(parents=True, exist_ok=True)
|
||||
ext = ".json" if isinstance(value, (dict, list)) else ".txt"
|
||||
filename = f"output_{key}{ext}"
|
||||
write_content = (
|
||||
json.dumps(value, indent=2, ensure_ascii=False)
|
||||
if isinstance(value, (dict, list))
|
||||
else str(value)
|
||||
)
|
||||
(spill_path / filename).write_text(write_content, encoding="utf-8")
|
||||
file_size = (spill_path / filename).stat().st_size
|
||||
logger.info(
|
||||
"set_output value auto-spilled: key=%s, %d chars -> %s (%d bytes)",
|
||||
key,
|
||||
len(val_str),
|
||||
filename,
|
||||
file_size,
|
||||
)
|
||||
return (
|
||||
f"[Saved to '{filename}' ({file_size:,} bytes). "
|
||||
f"Use load_data(filename='{filename}') "
|
||||
f"to access full data.]"
|
||||
)
|
||||
|
||||
def get(self, key: str) -> Any | None:
|
||||
return self.values.get(key)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return dict(self.values)
|
||||
|
||||
def has_all_keys(self, required: list[str]) -> bool:
|
||||
return all(key in self.values and self.values[key] is not None for key in required)
|
||||
|
||||
@classmethod
|
||||
async def restore(cls, store: ConversationStore) -> OutputAccumulator:
|
||||
cursor = await store.read_cursor()
|
||||
values = {}
|
||||
if cursor and "outputs" in cursor:
|
||||
values = cursor["outputs"]
|
||||
return cls(values=values, store=store)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"HookContext",
|
||||
"HookResult",
|
||||
"JudgeProtocol",
|
||||
"JudgeVerdict",
|
||||
"LoopConfig",
|
||||
"OutputAccumulator",
|
||||
"TriggerEvent",
|
||||
]
|
||||
+523
-2191
File diff suppressed because it is too large
Load Diff
@@ -154,6 +154,9 @@ class GraphExecutor:
|
||||
iteration_metadata_provider: Callable | None = None,
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
context_warn_ratio: float | None = None,
|
||||
batch_init_nudge: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the executor.
|
||||
@@ -181,6 +184,9 @@ class GraphExecutor:
|
||||
system prompt (for phase switching)
|
||||
skills_catalog_prompt: Available skills catalog for system prompt
|
||||
protocols_prompt: Default skill operational protocols for system prompt
|
||||
skill_dirs: Skill base directories for Tier 3 resource access
|
||||
context_warn_ratio: Token usage ratio to trigger DS-13 preservation warning
|
||||
batch_init_nudge: System prompt nudge for DS-12 batch auto-detection
|
||||
"""
|
||||
self.runtime = runtime
|
||||
self.llm = llm
|
||||
@@ -204,6 +210,9 @@ class GraphExecutor:
|
||||
self.iteration_metadata_provider = iteration_metadata_provider
|
||||
self.skills_catalog_prompt = skills_catalog_prompt
|
||||
self.protocols_prompt = protocols_prompt
|
||||
self.skill_dirs: list[str] = skill_dirs or []
|
||||
self.context_warn_ratio: float | None = context_warn_ratio
|
||||
self.batch_init_nudge: str | None = batch_init_nudge
|
||||
|
||||
if protocols_prompt:
|
||||
self.logger.info(
|
||||
@@ -1845,6 +1854,9 @@ class GraphExecutor:
|
||||
|
||||
existing_underscore = [k for k in memory._data if k.startswith("_")]
|
||||
extra_keys = set(_skill_keys) | set(existing_underscore)
|
||||
# Only inject into read_keys when it was already non-empty — an empty
|
||||
# read_keys means "allow all reads" and injecting skill keys would
|
||||
# inadvertently restrict reads to skill keys only.
|
||||
for k in extra_keys:
|
||||
if read_keys and k not in read_keys:
|
||||
read_keys.append(k)
|
||||
@@ -1899,6 +1911,9 @@ class GraphExecutor:
|
||||
iteration_metadata_provider=self.iteration_metadata_provider,
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
skill_dirs=self.skill_dirs,
|
||||
default_skill_warn_ratio=self.context_warn_ratio,
|
||||
default_skill_batch_nudge=self.batch_init_nudge,
|
||||
)
|
||||
|
||||
VALID_NODE_TYPES = {
|
||||
|
||||
@@ -43,8 +43,11 @@ Follow these rules for reliable, efficient browser interaction.
|
||||
`browser_snapshot` separately after every action.
|
||||
Only call `browser_snapshot` when you need a fresh view without
|
||||
performing an action, or after setting `auto_snapshot=false`.
|
||||
- Do NOT use `browser_screenshot` for reading text content
|
||||
— it produces huge base64 images with no searchable text.
|
||||
- Do NOT use `browser_screenshot` to read text — use
|
||||
`browser_snapshot` for that (compact, searchable, fast).
|
||||
- DO use `browser_screenshot` when you need visual context:
|
||||
charts, images, canvas elements, layout verification, or when
|
||||
the snapshot doesn't capture what you need.
|
||||
- Only fall back to `browser_get_text` for extracting specific
|
||||
small elements by CSS selector.
|
||||
|
||||
|
||||
@@ -167,14 +167,6 @@ class Goal(BaseModel):
|
||||
|
||||
return met_weight >= total_weight * 0.9 # 90% threshold
|
||||
|
||||
def check_constraint(self, constraint_id: str, value: Any) -> bool:
|
||||
"""Check if a specific constraint is satisfied."""
|
||||
for c in self.constraints:
|
||||
if c.id == constraint_id:
|
||||
# This would be expanded with actual evaluation logic
|
||||
return True
|
||||
return True
|
||||
|
||||
def to_prompt_context(self) -> str:
|
||||
"""Generate context string for LLM prompts.
|
||||
|
||||
|
||||
@@ -568,6 +568,11 @@ class NodeContext:
|
||||
# Skill system prompts — injected by the skill discovery pipeline
|
||||
skills_catalog_prompt: str = "" # Available skills XML catalog
|
||||
protocols_prompt: str = "" # Default skill operational protocols
|
||||
skill_dirs: list[str] = field(default_factory=list) # Skill base dirs for resource access
|
||||
# DS-12: batch auto-detection nudge appended to system prompt when input looks like a batch
|
||||
default_skill_batch_nudge: str | None = None
|
||||
# DS-13: token usage ratio at which to inject a context preservation warning
|
||||
default_skill_warn_ratio: float | None = None
|
||||
|
||||
# Per-iteration metadata provider — when set, EventLoopNode merges
|
||||
# the returned dict into node_loop_iteration event data. Used by
|
||||
|
||||
@@ -152,6 +152,8 @@ def compose_system_prompt(
|
||||
accounts_prompt: str | None = None,
|
||||
skills_catalog_prompt: str | None = None,
|
||||
protocols_prompt: str | None = None,
|
||||
execution_preamble: str | None = None,
|
||||
node_type_preamble: str | None = None,
|
||||
) -> str:
|
||||
"""Compose the multi-layer system prompt.
|
||||
|
||||
@@ -162,6 +164,10 @@ def compose_system_prompt(
|
||||
accounts_prompt: Connected accounts block (sits between identity and narrative).
|
||||
skills_catalog_prompt: Available skills catalog XML (Agent Skills standard).
|
||||
protocols_prompt: Default skill operational protocols section.
|
||||
execution_preamble: EXECUTION_SCOPE_PREAMBLE for worker nodes
|
||||
(prepended before focus so the LLM knows its pipeline scope).
|
||||
node_type_preamble: Node-type-specific preamble, e.g. GCU browser
|
||||
best-practices prompt (prepended before focus).
|
||||
|
||||
Returns:
|
||||
Composed system prompt with all layers present, plus current datetime.
|
||||
@@ -188,6 +194,15 @@ def compose_system_prompt(
|
||||
if narrative:
|
||||
parts.append(f"\n--- Context (what has happened so far) ---\n{narrative}")
|
||||
|
||||
# Execution scope preamble (worker nodes — tells the LLM it is one
|
||||
# step in a multi-node pipeline and should not overreach)
|
||||
if execution_preamble:
|
||||
parts.append(f"\n{execution_preamble}")
|
||||
|
||||
# Node-type preamble (e.g. GCU browser best-practices)
|
||||
if node_type_preamble:
|
||||
parts.append(f"\n{node_type_preamble}")
|
||||
|
||||
# Layer 3: Focus (current phase directive)
|
||||
if focus_prompt:
|
||||
parts.append(f"\n--- Current Focus ---\n{focus_prompt}")
|
||||
|
||||
@@ -228,10 +228,6 @@ class SafeEvalVisitor(ast.NodeVisitor):
|
||||
|
||||
return func(*args, **keywords)
|
||||
|
||||
def visit_Index(self, node: ast.Index) -> Any:
|
||||
# Python < 3.9
|
||||
return self.visit(node.value)
|
||||
|
||||
|
||||
def safe_eval(expr: str, context: dict[str, Any] | None = None) -> Any:
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,706 @@
|
||||
"""Antigravity (Google internal Cloud Code Assist) LLM provider.
|
||||
|
||||
Antigravity is Google's unified gateway API that routes requests to Gemini,
|
||||
Claude, and GPT-OSS models through a single Gemini-style interface. It is
|
||||
NOT the public ``generativelanguage.googleapis.com`` API.
|
||||
|
||||
Authentication uses Google OAuth2. Token refresh is done directly with the
|
||||
OAuth client secret — no local proxy required.
|
||||
|
||||
Credential sources (checked in order):
|
||||
1. ``~/.hive/antigravity-accounts.json`` (native OAuth implementation)
|
||||
2. Antigravity IDE SQLite state DB (macOS / Linux)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator, Callable, Iterator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
StreamEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
# Fallback order: daily sandbox → autopush sandbox → production
|
||||
_ENDPOINTS = [
|
||||
"https://daily-cloudcode-pa.sandbox.googleapis.com",
|
||||
"https://autopush-cloudcode-pa.sandbox.googleapis.com",
|
||||
"https://cloudcode-pa.googleapis.com",
|
||||
]
|
||||
_DEFAULT_PROJECT_ID = "rising-fact-p41fc"
|
||||
_TOKEN_REFRESH_BUFFER_SECS = 60
|
||||
|
||||
# Credentials file in ~/.hive/ (native implementation)
|
||||
_ACCOUNTS_FILE = Path.home() / ".hive" / "antigravity-accounts.json"
|
||||
_IDE_STATE_DB_MAC = (
|
||||
Path.home()
|
||||
/ "Library"
|
||||
/ "Application Support"
|
||||
/ "Antigravity"
|
||||
/ "User"
|
||||
/ "globalStorage"
|
||||
/ "state.vscdb"
|
||||
)
|
||||
_IDE_STATE_DB_LINUX = (
|
||||
Path.home() / ".config" / "Antigravity" / "User" / "globalStorage" / "state.vscdb"
|
||||
)
|
||||
_IDE_STATE_DB_KEY = "antigravityUnifiedStateSync.oauthToken"
|
||||
|
||||
_BASE_HEADERS: dict[str, str] = {
|
||||
# Mimic the Antigravity Electron app so the API accepts the request.
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 "
|
||||
"(KHTML, like Gecko) Antigravity/1.18.3 Chrome/138.0.7204.235 "
|
||||
"Electron/37.3.1 Safari/537.36"
|
||||
),
|
||||
"X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
|
||||
"Client-Metadata": '{"ideType":"ANTIGRAVITY","platform":"MACOS","pluginType":"GEMINI"}',
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Credential loading helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _load_from_json_file() -> tuple[str | None, str | None, str, float]:
|
||||
"""Read credentials from JSON accounts file.
|
||||
|
||||
Reads from ~/.hive/antigravity-accounts.json.
|
||||
|
||||
Returns ``(access_token | None, refresh_token | None, project_id, expires_at)``.
|
||||
``expires_at`` is a Unix timestamp (seconds); 0.0 means unknown.
|
||||
"""
|
||||
if not _ACCOUNTS_FILE.exists():
|
||||
return None, None, _DEFAULT_PROJECT_ID, 0.0
|
||||
try:
|
||||
with open(_ACCOUNTS_FILE, encoding="utf-8") as fh:
|
||||
data = json.load(fh)
|
||||
except (OSError, json.JSONDecodeError) as exc:
|
||||
logger.debug("Failed to read Antigravity accounts file: %s", exc)
|
||||
return None, None, _DEFAULT_PROJECT_ID, 0.0
|
||||
|
||||
accounts = data.get("accounts", [])
|
||||
if not accounts:
|
||||
return None, None, _DEFAULT_PROJECT_ID, 0.0
|
||||
|
||||
account = next((a for a in accounts if a.get("enabled", True) is not False), accounts[0])
|
||||
schema_version = data.get("schemaVersion", 1)
|
||||
|
||||
if schema_version >= 4:
|
||||
# V4 schema: refresh = "refreshToken|projectId[|managedProjectId]"
|
||||
refresh_str = account.get("refresh", "")
|
||||
parts = refresh_str.split("|") if refresh_str else []
|
||||
refresh_token: str | None = parts[0] if parts else None
|
||||
project_id = parts[1] if len(parts) >= 2 and parts[1] else _DEFAULT_PROJECT_ID
|
||||
|
||||
access_token: str | None = account.get("access")
|
||||
expires_ms: int = account.get("expires", 0)
|
||||
expires_at = float(expires_ms) / 1000.0 if expires_ms else 0.0
|
||||
|
||||
# Treat near-expiry tokens as absent so _ensure_token() triggers a refresh.
|
||||
if access_token and expires_at and time.time() >= expires_at - _TOKEN_REFRESH_BUFFER_SECS:
|
||||
access_token = None
|
||||
expires_at = 0.0
|
||||
|
||||
return access_token, refresh_token, project_id, expires_at
|
||||
else:
|
||||
# V1–V3 schema: plain accessToken / refreshToken fields
|
||||
access_token = account.get("accessToken")
|
||||
refresh_token = account.get("refreshToken")
|
||||
# Estimate expiry from last_refresh + 1 h
|
||||
last_refresh_str: str | None = data.get("last_refresh")
|
||||
expires_at = 0.0
|
||||
if last_refresh_str:
|
||||
try:
|
||||
from datetime import datetime # noqa: PLC0415
|
||||
|
||||
ts = datetime.fromisoformat(last_refresh_str.replace("Z", "+00:00")).timestamp()
|
||||
expires_at = ts + 3600.0
|
||||
if time.time() >= expires_at - _TOKEN_REFRESH_BUFFER_SECS:
|
||||
access_token = None
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
return access_token, refresh_token, _DEFAULT_PROJECT_ID, expires_at
|
||||
|
||||
|
||||
def _load_from_ide_db() -> tuple[str | None, str | None, float]:
|
||||
"""Extract ``(access_token, refresh_token, expires_at)`` from the IDE SQLite DB."""
|
||||
import base64 # noqa: PLC0415
|
||||
import sqlite3 # noqa: PLC0415
|
||||
|
||||
for db_path in (_IDE_STATE_DB_MAC, _IDE_STATE_DB_LINUX):
|
||||
if not db_path.exists():
|
||||
continue
|
||||
try:
|
||||
con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
|
||||
try:
|
||||
row = con.execute(
|
||||
"SELECT value FROM ItemTable WHERE key = ?",
|
||||
(_IDE_STATE_DB_KEY,),
|
||||
).fetchone()
|
||||
finally:
|
||||
con.close()
|
||||
if not row:
|
||||
continue
|
||||
|
||||
blob = base64.b64decode(row[0])
|
||||
candidates = re.findall(rb"[A-Za-z0-9+/=_\-]{40,}", blob)
|
||||
access_token: str | None = None
|
||||
refresh_token: str | None = None
|
||||
for candidate in candidates:
|
||||
try:
|
||||
padded = candidate + b"=" * (-len(candidate) % 4)
|
||||
inner = base64.urlsafe_b64decode(padded)
|
||||
except Exception:
|
||||
continue
|
||||
if not access_token:
|
||||
m = re.search(rb"ya29\.[A-Za-z0-9_\-\.]+", inner)
|
||||
if m:
|
||||
access_token = m.group(0).decode("ascii")
|
||||
if not refresh_token:
|
||||
m = re.search(rb"1//[A-Za-z0-9_\-\.]+", inner)
|
||||
if m:
|
||||
refresh_token = m.group(0).decode("ascii")
|
||||
if access_token and refresh_token:
|
||||
break
|
||||
|
||||
if access_token:
|
||||
# Estimate expiry from DB mtime (IDE refreshes while running)
|
||||
mtime = db_path.stat().st_mtime
|
||||
expires_at = mtime + 3600.0
|
||||
return access_token, refresh_token, expires_at
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to read Antigravity IDE state DB: %s", exc)
|
||||
continue
|
||||
return None, None, 0.0
|
||||
|
||||
|
||||
def _do_token_refresh(refresh_token: str) -> tuple[str, float] | None:
|
||||
"""POST to Google OAuth endpoint and return ``(new_access_token, expires_at)``.
|
||||
|
||||
The client secret is sourced via ``get_antigravity_client_secret()`` (env var,
|
||||
config file, or npm package fallback). When unavailable the refresh is attempted
|
||||
without it — Google will reject it for web-app clients, but the npm fallback in
|
||||
``get_antigravity_client_secret()`` should ensure the secret is found at runtime.
|
||||
|
||||
Returns None when the HTTP request fails.
|
||||
"""
|
||||
from framework.config import get_antigravity_client_secret # noqa: PLC0415
|
||||
|
||||
client_secret = get_antigravity_client_secret()
|
||||
if not client_secret:
|
||||
logger.debug(
|
||||
"Antigravity client secret not configured — attempting refresh without it. "
|
||||
"Set ANTIGRAVITY_CLIENT_SECRET or run quickstart to configure."
|
||||
)
|
||||
|
||||
import urllib.error # noqa: PLC0415
|
||||
import urllib.parse # noqa: PLC0415
|
||||
import urllib.request # noqa: PLC0415
|
||||
|
||||
from framework.config import get_antigravity_client_id # noqa: PLC0415
|
||||
|
||||
params: dict[str, str] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": get_antigravity_client_id(),
|
||||
}
|
||||
if client_secret:
|
||||
params["client_secret"] = client_secret
|
||||
body = urllib.parse.urlencode(params).encode("utf-8")
|
||||
|
||||
req = urllib.request.Request(
|
||||
_TOKEN_URL,
|
||||
data=body,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
method="POST",
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=15) as resp: # noqa: S310
|
||||
payload = json.loads(resp.read())
|
||||
access_token: str = payload["access_token"]
|
||||
expires_in: int = payload.get("expires_in", 3600)
|
||||
logger.debug("Antigravity token refreshed successfully")
|
||||
return access_token, time.time() + expires_in
|
||||
except Exception as exc:
|
||||
logger.debug("Antigravity token refresh failed: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message conversion helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _clean_tool_name(name: str) -> str:
|
||||
"""Sanitize a tool name for the Antigravity function-calling schema."""
|
||||
name = re.sub(r"[/\s]", "_", name)
|
||||
if name and not (name[0].isalpha() or name[0] == "_"):
|
||||
name = "_" + name
|
||||
return name[:64]
|
||||
|
||||
|
||||
def _to_gemini_contents(
|
||||
messages: list[dict[str, Any]],
|
||||
thought_sigs: dict[str, str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert OpenAI-format messages to Gemini-style ``contents`` array."""
|
||||
# Pre-build a map tool_call_id → function_name from assistant messages.
|
||||
# Tool result messages (role="tool") only carry tool_call_id, not the name,
|
||||
# but Gemini requires functionResponse.name to match the functionCall.name.
|
||||
tc_id_to_name: dict[str, str] = {}
|
||||
for msg in messages:
|
||||
if msg.get("role") == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
tc_id = tc.get("id")
|
||||
fn_name = tc.get("function", {}).get("name", "")
|
||||
if tc_id and fn_name:
|
||||
tc_id_to_name[tc_id] = fn_name
|
||||
|
||||
contents: list[dict[str, Any]] = []
|
||||
# Consecutive tool-result messages must be batched into one user turn.
|
||||
pending_tool_parts: list[dict[str, Any]] = []
|
||||
|
||||
def _flush_tool_results() -> None:
|
||||
if pending_tool_parts:
|
||||
contents.append({"role": "user", "parts": list(pending_tool_parts)})
|
||||
pending_tool_parts.clear()
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content")
|
||||
|
||||
if role == "system":
|
||||
continue # Handled via systemInstruction, not in contents.
|
||||
|
||||
if role == "tool":
|
||||
# OpenAI tool result → Gemini functionResponse part.
|
||||
result_str = content if isinstance(content, str) else str(content or "")
|
||||
tc_id = msg.get("tool_call_id", "")
|
||||
# Look up function name from the pre-built map; fall back to msg.name.
|
||||
fn_name = tc_id_to_name.get(tc_id) or msg.get("name", "")
|
||||
pending_tool_parts.append(
|
||||
{
|
||||
"functionResponse": {
|
||||
"name": fn_name,
|
||||
"id": tc_id,
|
||||
"response": {"content": result_str},
|
||||
}
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
_flush_tool_results()
|
||||
|
||||
gemini_role = "model" if role == "assistant" else "user"
|
||||
parts: list[dict[str, Any]] = []
|
||||
|
||||
if isinstance(content, str) and content:
|
||||
parts.append({"text": content})
|
||||
elif isinstance(content, list):
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "text":
|
||||
text = block.get("text", "")
|
||||
if text:
|
||||
parts.append({"text": text})
|
||||
# Other block types (image_url etc.) skipped.
|
||||
|
||||
# Assistant messages may carry OpenAI-style tool_calls.
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
fn = tc.get("function", {})
|
||||
try:
|
||||
args = json.loads(fn.get("arguments", "{}") or "{}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
tc_id = tc.get("id", str(uuid.uuid4()))
|
||||
fc_part: dict[str, Any] = {
|
||||
"functionCall": {
|
||||
"name": fn.get("name", ""),
|
||||
"args": args,
|
||||
"id": tc_id,
|
||||
}
|
||||
}
|
||||
if thought_sigs:
|
||||
sig = thought_sigs.get(tc_id, "")
|
||||
if sig:
|
||||
fc_part["thoughtSignature"] = sig # part-level, not inside functionCall
|
||||
parts.append(fc_part)
|
||||
|
||||
if parts:
|
||||
contents.append({"role": gemini_role, "parts": parts})
|
||||
|
||||
_flush_tool_results()
|
||||
|
||||
# Gemini requires the first turn to be a user turn. Drop any leading
|
||||
# model messages so the API doesn't reject with a 400.
|
||||
while contents and contents[0].get("role") == "model":
|
||||
contents.pop(0)
|
||||
|
||||
return contents
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response parsing helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _map_finish_reason(reason: str) -> str:
|
||||
return {"STOP": "stop", "MAX_TOKENS": "max_tokens", "OTHER": "tool_use"}.get(
|
||||
(reason or "").upper(), "stop"
|
||||
)
|
||||
|
||||
|
||||
def _parse_complete_response(raw: dict[str, Any], model: str) -> LLMResponse:
|
||||
"""Parse a non-streaming Antigravity response dict → LLMResponse."""
|
||||
payload: dict[str, Any] = raw.get("response", raw)
|
||||
candidates: list[dict[str, Any]] = payload.get("candidates", [])
|
||||
usage: dict[str, Any] = payload.get("usageMetadata", {})
|
||||
|
||||
text_parts: list[str] = []
|
||||
if candidates:
|
||||
for part in candidates[0].get("content", {}).get("parts", []):
|
||||
if "text" in part and not part.get("thought"):
|
||||
text_parts.append(part["text"])
|
||||
|
||||
return LLMResponse(
|
||||
content="".join(text_parts),
|
||||
model=payload.get("modelVersion", model),
|
||||
input_tokens=usage.get("promptTokenCount", 0),
|
||||
output_tokens=usage.get("candidatesTokenCount", 0),
|
||||
stop_reason=_map_finish_reason(candidates[0].get("finishReason", "") if candidates else ""),
|
||||
raw_response=raw,
|
||||
)
|
||||
|
||||
|
||||
def _parse_sse_stream(
|
||||
response: Any,
|
||||
model: str,
|
||||
on_thought_signature: Callable[[str, str], None] | None = None,
|
||||
) -> Iterator[StreamEvent]:
|
||||
"""Parse Antigravity SSE response line-by-line → StreamEvents.
|
||||
|
||||
Each SSE line looks like::
|
||||
|
||||
data: {"response": {"candidates": [...], "usageMetadata": {...}}, "traceId": "..."}
|
||||
"""
|
||||
accumulated = ""
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
finish_reason = ""
|
||||
|
||||
for raw_line in response:
|
||||
line: str = raw_line.decode("utf-8", errors="replace").rstrip("\r\n")
|
||||
if not line.startswith("data:"):
|
||||
continue
|
||||
data_str = line[5:].strip()
|
||||
if not data_str or data_str == "[DONE]":
|
||||
continue
|
||||
try:
|
||||
data: dict[str, Any] = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# The outer envelope is {"response": {...}, "traceId": "..."}.
|
||||
payload: dict[str, Any] = data.get("response", data)
|
||||
|
||||
usage = payload.get("usageMetadata", {})
|
||||
if usage:
|
||||
input_tokens = usage.get("promptTokenCount", input_tokens)
|
||||
output_tokens = usage.get("candidatesTokenCount", output_tokens)
|
||||
|
||||
for candidate in payload.get("candidates", []):
|
||||
fr = candidate.get("finishReason", "")
|
||||
if fr:
|
||||
finish_reason = fr
|
||||
|
||||
for part in candidate.get("content", {}).get("parts", []):
|
||||
if "text" in part and not part.get("thought"):
|
||||
delta: str = part["text"]
|
||||
accumulated += delta
|
||||
yield TextDeltaEvent(content=delta, snapshot=accumulated)
|
||||
elif "functionCall" in part:
|
||||
fc: dict[str, Any] = part["functionCall"]
|
||||
tool_use_id = fc.get("id") or str(uuid.uuid4())
|
||||
thought_sig = part.get("thoughtSignature", "") # sibling of functionCall
|
||||
if thought_sig and on_thought_signature:
|
||||
on_thought_signature(tool_use_id, thought_sig)
|
||||
args = fc.get("args", {})
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
yield ToolCallEvent(
|
||||
tool_use_id=tool_use_id,
|
||||
tool_name=fc.get("name", ""),
|
||||
tool_input=args,
|
||||
)
|
||||
|
||||
if accumulated:
|
||||
yield TextEndEvent(full_text=accumulated)
|
||||
yield FinishEvent(
|
||||
stop_reason=_map_finish_reason(finish_reason),
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AntigravityProvider(LLMProvider):
|
||||
"""LLM provider for Google's internal Antigravity Code Assist gateway.
|
||||
|
||||
No local proxy required. Handles OAuth token refresh, Gemini-format
|
||||
request/response conversion, and SSE streaming directly.
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = "gemini-3-flash") -> None:
|
||||
# Strip any provider prefix ("openai/gemini-3-flash" → "gemini-3-flash").
|
||||
if "/" in model:
|
||||
model = model.split("/", 1)[1]
|
||||
self.model = model
|
||||
|
||||
self._access_token: str | None = None
|
||||
self._refresh_token: str | None = None
|
||||
self._project_id: str = _DEFAULT_PROJECT_ID
|
||||
self._token_expires_at: float = 0.0
|
||||
self._thought_sigs: dict[str, str] = {} # tool_use_id → thoughtSignature
|
||||
|
||||
self._init_credentials()
|
||||
|
||||
# --- Credential management -------------------------------------------- #
|
||||
|
||||
def _init_credentials(self) -> None:
|
||||
"""Load credentials from the best available source."""
|
||||
access, refresh, project_id, expires_at = _load_from_json_file()
|
||||
if refresh:
|
||||
self._refresh_token = refresh
|
||||
self._project_id = project_id
|
||||
self._access_token = access
|
||||
self._token_expires_at = expires_at
|
||||
return
|
||||
|
||||
# Fall back to IDE state DB.
|
||||
access, refresh, expires_at = _load_from_ide_db()
|
||||
if access:
|
||||
self._access_token = access
|
||||
self._refresh_token = refresh
|
||||
self._token_expires_at = expires_at
|
||||
|
||||
def has_credentials(self) -> bool:
|
||||
"""Return True if any credential is available."""
|
||||
return bool(self._access_token or self._refresh_token)
|
||||
|
||||
def _ensure_token(self) -> str:
|
||||
"""Return a valid access token, refreshing via OAuth if needed."""
|
||||
if (
|
||||
self._access_token
|
||||
and self._token_expires_at
|
||||
and time.time() < self._token_expires_at - _TOKEN_REFRESH_BUFFER_SECS
|
||||
):
|
||||
return self._access_token
|
||||
|
||||
if self._refresh_token:
|
||||
result = _do_token_refresh(self._refresh_token)
|
||||
if result:
|
||||
self._access_token, self._token_expires_at = result
|
||||
return self._access_token
|
||||
|
||||
if self._access_token:
|
||||
logger.warning("Using potentially stale Antigravity access token")
|
||||
return self._access_token
|
||||
|
||||
raise RuntimeError(
|
||||
"No valid Antigravity credentials. "
|
||||
"Run: uv run python core/antigravity_auth.py auth account add"
|
||||
)
|
||||
|
||||
# --- Request building -------------------------------------------------- #
|
||||
|
||||
def _build_body(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool] | None,
|
||||
max_tokens: int,
|
||||
) -> dict[str, Any]:
|
||||
contents = _to_gemini_contents(messages, self._thought_sigs)
|
||||
inner: dict[str, Any] = {
|
||||
"contents": contents,
|
||||
"generationConfig": {"maxOutputTokens": max_tokens},
|
||||
}
|
||||
if system:
|
||||
inner["systemInstruction"] = {"parts": [{"text": system}]}
|
||||
if tools:
|
||||
inner["tools"] = [
|
||||
{
|
||||
"functionDeclarations": [
|
||||
{
|
||||
"name": _clean_tool_name(t.name),
|
||||
"description": t.description,
|
||||
"parameters": t.parameters
|
||||
or {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
}
|
||||
for t in tools
|
||||
]
|
||||
}
|
||||
]
|
||||
return {
|
||||
"project": self._project_id,
|
||||
"model": self.model,
|
||||
"request": inner,
|
||||
"requestType": "agent",
|
||||
"userAgent": "antigravity",
|
||||
"requestId": f"agent-{uuid.uuid4()}",
|
||||
}
|
||||
|
||||
# --- HTTP transport ---------------------------------------------------- #
|
||||
|
||||
def _post(self, body: dict[str, Any], *, streaming: bool) -> Any:
|
||||
"""POST to the Antigravity endpoint, falling back through the endpoint list."""
|
||||
import urllib.error # noqa: PLC0415
|
||||
import urllib.request # noqa: PLC0415
|
||||
|
||||
token = self._ensure_token()
|
||||
body_bytes = json.dumps(body).encode("utf-8")
|
||||
path = (
|
||||
"/v1internal:streamGenerateContent?alt=sse"
|
||||
if streaming
|
||||
else "/v1internal:generateContent"
|
||||
)
|
||||
headers = {
|
||||
**_BASE_HEADERS,
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if streaming:
|
||||
headers["Accept"] = "text/event-stream"
|
||||
|
||||
last_exc: Exception | None = None
|
||||
for base_url in _ENDPOINTS:
|
||||
url = f"{base_url}{path}"
|
||||
req = urllib.request.Request(url, data=body_bytes, headers=headers, method="POST")
|
||||
try:
|
||||
return urllib.request.urlopen(req, timeout=120) # noqa: S310
|
||||
except urllib.error.HTTPError as exc:
|
||||
if exc.code in (401, 403) and self._refresh_token:
|
||||
# Token rejected — refresh once and retry this endpoint.
|
||||
result = _do_token_refresh(self._refresh_token)
|
||||
if result:
|
||||
self._access_token, self._token_expires_at = result
|
||||
headers["Authorization"] = f"Bearer {self._access_token}"
|
||||
req2 = urllib.request.Request(
|
||||
url, data=body_bytes, headers=headers, method="POST"
|
||||
)
|
||||
try:
|
||||
return urllib.request.urlopen(req2, timeout=120) # noqa: S310
|
||||
except urllib.error.HTTPError as exc2:
|
||||
last_exc = exc2
|
||||
continue
|
||||
last_exc = exc
|
||||
continue
|
||||
elif exc.code >= 500:
|
||||
last_exc = exc
|
||||
continue
|
||||
# Include the API response body in the exception for easier debugging.
|
||||
try:
|
||||
err_body = exc.read().decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
err_body = "(unreadable)"
|
||||
raise RuntimeError(f"Antigravity HTTP {exc.code} from {url}: {err_body}") from exc
|
||||
except (urllib.error.URLError, OSError) as exc:
|
||||
last_exc = exc
|
||||
continue
|
||||
|
||||
raise RuntimeError(
|
||||
f"All Antigravity endpoints failed. Last error: {last_exc}"
|
||||
) from last_exc
|
||||
|
||||
# --- LLMProvider interface --------------------------------------------- #
|
||||
|
||||
def complete(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
max_retries: int | None = None,
|
||||
) -> LLMResponse:
|
||||
if json_mode:
|
||||
suffix = "\n\nPlease respond with a valid JSON object."
|
||||
system = (system + suffix) if system else suffix.strip()
|
||||
|
||||
body = self._build_body(messages, system, tools, max_tokens)
|
||||
resp = self._post(body, streaming=False)
|
||||
return _parse_complete_response(json.loads(resp.read()), self.model)
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
import asyncio # noqa: PLC0415
|
||||
import concurrent.futures # noqa: PLC0415
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
queue: asyncio.Queue[StreamEvent | None] = asyncio.Queue()
|
||||
|
||||
def _blocking_work() -> None:
|
||||
try:
|
||||
body = self._build_body(messages, system, tools, max_tokens)
|
||||
http_resp = self._post(body, streaming=True)
|
||||
for event in _parse_sse_stream(
|
||||
http_resp, self.model, self._thought_sigs.__setitem__
|
||||
):
|
||||
loop.call_soon_threadsafe(queue.put_nowait, event)
|
||||
except Exception as exc:
|
||||
logger.error("Antigravity stream error: %s", exc)
|
||||
loop.call_soon_threadsafe(queue.put_nowait, StreamErrorEvent(error=str(exc)))
|
||||
finally:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, None) # sentinel
|
||||
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
fut = loop.run_in_executor(executor, _blocking_work)
|
||||
try:
|
||||
while True:
|
||||
event = await queue.get()
|
||||
if event is None:
|
||||
break
|
||||
yield event
|
||||
finally:
|
||||
await fut
|
||||
executor.shutdown(wait=False)
|
||||
@@ -0,0 +1,106 @@
|
||||
"""Model capability checks for LLM providers.
|
||||
|
||||
Vision support rules are derived from official vendor documentation:
|
||||
- ZAI (z.ai): docs.z.ai/guides/vlm — GLM-4.6V variants are vision; GLM-5/4.6/4.7 are text-only
|
||||
- MiniMax: platform.minimax.io/docs — minimax-vl-01 is vision; M2.x are text-only
|
||||
- DeepSeek: api-docs.deepseek.com — deepseek-vl2 is vision; chat/reasoner are text-only
|
||||
- Cerebras: inference-docs.cerebras.ai — no vision models at all
|
||||
- Groq: console.groq.com/docs/vision — vision capable; treat as supported by default
|
||||
- Ollama/LM Studio/vLLM/llama.cpp: local runners denied by default; model names
|
||||
don't reliably indicate vision support, so users must configure explicitly
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def _model_name(model: str) -> str:
|
||||
"""Return the bare model name after stripping any 'provider/' prefix."""
|
||||
if "/" in model:
|
||||
return model.split("/", 1)[1]
|
||||
return model
|
||||
|
||||
|
||||
# Step 1: explicit vision allow-list — these always support images regardless
|
||||
# of what the provider-level rules say. Checked first so that e.g. glm-4.6v
|
||||
# is allowed even though glm-4.6 is denied.
|
||||
_VISION_ALLOW_BARE_PREFIXES: tuple[str, ...] = (
|
||||
# ZAI/GLM vision models (docs.z.ai/guides/vlm)
|
||||
"glm-4v", # GLM-4V series (legacy)
|
||||
"glm-4.6v", # GLM-4.6V, GLM-4.6V-flash, GLM-4.6V-flashx
|
||||
# DeepSeek vision models
|
||||
"deepseek-vl", # deepseek-vl2, deepseek-vl2-small, deepseek-vl2-tiny
|
||||
# MiniMax vision model
|
||||
"minimax-vl", # minimax-vl-01
|
||||
)
|
||||
|
||||
# Step 2: provider-level deny — every model from this provider is text-only.
|
||||
_TEXT_ONLY_PROVIDER_PREFIXES: tuple[str, ...] = (
|
||||
# Cerebras: inference-docs.cerebras.ai lists only text models
|
||||
"cerebras/",
|
||||
# Local runners: model names don't reliably indicate vision support
|
||||
"ollama/",
|
||||
"ollama_chat/",
|
||||
"lm_studio/",
|
||||
"vllm/",
|
||||
"llamacpp/",
|
||||
)
|
||||
|
||||
# Step 3: per-model deny — text-only models within otherwise mixed providers.
|
||||
# Matched against the bare model name (provider prefix stripped, lower-cased).
|
||||
# The vision allow-list above is checked first, so vision variants of the same
|
||||
# family are already handled before these deny patterns are reached.
|
||||
_TEXT_ONLY_MODEL_BARE_PREFIXES: tuple[str, ...] = (
|
||||
# --- ZAI / GLM family ---
|
||||
# text-only: glm-5, glm-4.6, glm-4.7, glm-4.5, zai-glm-*
|
||||
# vision: glm-4v, glm-4.6v (caught by allow-list above)
|
||||
"glm-5",
|
||||
"glm-4.6", # bare glm-4.6 is text-only; glm-4.6v is caught by allow-list
|
||||
"glm-4.7",
|
||||
"glm-4.5",
|
||||
"zai-glm",
|
||||
# --- DeepSeek ---
|
||||
# text-only: deepseek-chat, deepseek-coder, deepseek-reasoner
|
||||
# vision: deepseek-vl2 (caught by allow-list above)
|
||||
# Note: LiteLLM's deepseek handler may flatten content lists for some models;
|
||||
# VL models are allowed through and rely on LiteLLM's native VL support.
|
||||
"deepseek-chat",
|
||||
"deepseek-coder",
|
||||
"deepseek-reasoner",
|
||||
# --- MiniMax ---
|
||||
# text-only: minimax-m2.*, minimax-text-*, abab* (legacy)
|
||||
# vision: minimax-vl-01 (caught by allow-list above)
|
||||
"minimax-m2",
|
||||
"minimax-text",
|
||||
"abab",
|
||||
)
|
||||
|
||||
|
||||
def supports_image_tool_results(model: str) -> bool:
|
||||
"""Return whether *model* can receive image content in messages.
|
||||
|
||||
Used to gate both user-message images and tool-result image blocks.
|
||||
|
||||
Logic (checked in order):
|
||||
1. Vision allow-list → True (known vision model, skip all denies)
|
||||
2. Provider deny → False (entire provider is text-only)
|
||||
3. Model deny → False (specific text-only model within a mixed provider)
|
||||
4. Default → True (assume capable; unknown providers and models)
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
bare = _model_name(model_lower)
|
||||
|
||||
# 1. Explicit vision allow — takes priority over all denies
|
||||
if any(bare.startswith(p) for p in _VISION_ALLOW_BARE_PREFIXES):
|
||||
return True
|
||||
|
||||
# 2. Provider-level deny (all models from this provider are text-only)
|
||||
if any(model_lower.startswith(p) for p in _TEXT_ONLY_PROVIDER_PREFIXES):
|
||||
return False
|
||||
|
||||
# 3. Per-model deny (text-only variants within mixed-capability families)
|
||||
if any(bare.startswith(p) for p in _TEXT_ONLY_MODEL_BARE_PREFIXES):
|
||||
return False
|
||||
|
||||
# 5. Default: assume vision capable
|
||||
# Covers: OpenAI, Anthropic, Google, Mistral, Kimi, and other hosted providers
|
||||
return True
|
||||
+684
-13
@@ -7,9 +7,13 @@ Groq, and local models.
|
||||
See: https://docs.litellm.ai/docs/providers
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime
|
||||
@@ -44,7 +48,10 @@ def _patch_litellm_anthropic_oauth() -> None:
|
||||
"""
|
||||
try:
|
||||
from litellm.llms.anthropic.common_utils import AnthropicModelInfo
|
||||
from litellm.types.llms.anthropic import ANTHROPIC_OAUTH_TOKEN_PREFIX
|
||||
from litellm.types.llms.anthropic import (
|
||||
ANTHROPIC_OAUTH_BETA_HEADER,
|
||||
ANTHROPIC_OAUTH_TOKEN_PREFIX,
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"Could not apply litellm Anthropic OAuth patch — litellm internals may have "
|
||||
@@ -69,9 +76,27 @@ def _patch_litellm_anthropic_oauth() -> None:
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
)
|
||||
# Check both authorization header and x-api-key for OAuth tokens.
|
||||
# litellm's optionally_handle_anthropic_oauth only checks headers["authorization"],
|
||||
# but hive passes OAuth tokens via api_key — so litellm puts them into x-api-key.
|
||||
# Anthropic rejects OAuth tokens in x-api-key; they must go in Authorization: Bearer.
|
||||
auth = result.get("authorization", "")
|
||||
if auth.startswith(f"Bearer {ANTHROPIC_OAUTH_TOKEN_PREFIX}"):
|
||||
x_api_key = result.get("x-api-key", "")
|
||||
oauth_prefix = f"Bearer {ANTHROPIC_OAUTH_TOKEN_PREFIX}"
|
||||
auth_is_oauth = auth.startswith(oauth_prefix)
|
||||
key_is_oauth = x_api_key.startswith(ANTHROPIC_OAUTH_TOKEN_PREFIX)
|
||||
if auth_is_oauth or key_is_oauth:
|
||||
token = x_api_key if key_is_oauth else auth.removeprefix("Bearer ").strip()
|
||||
result.pop("x-api-key", None)
|
||||
result["authorization"] = f"Bearer {token}"
|
||||
# Merge the OAuth beta header with any existing beta headers.
|
||||
existing_beta = result.get("anthropic-beta", "")
|
||||
beta_parts = (
|
||||
[b.strip() for b in existing_beta.split(",") if b.strip()] if existing_beta else []
|
||||
)
|
||||
if ANTHROPIC_OAUTH_BETA_HEADER not in beta_parts:
|
||||
beta_parts.append(ANTHROPIC_OAUTH_BETA_HEADER)
|
||||
result["anthropic-beta"] = ",".join(beta_parts)
|
||||
return result
|
||||
|
||||
AnthropicModelInfo.validate_environment = _patched_validate_environment
|
||||
@@ -130,11 +155,35 @@ def _patch_litellm_metadata_nonetype() -> None:
|
||||
if litellm is not None:
|
||||
_patch_litellm_anthropic_oauth()
|
||||
_patch_litellm_metadata_nonetype()
|
||||
# Let litellm silently drop params unsupported by the target provider
|
||||
# (e.g. stream_options for Anthropic) instead of forwarding them verbatim.
|
||||
litellm.drop_params = True
|
||||
|
||||
|
||||
def _is_ollama_model(model: str) -> bool:
|
||||
"""Return True for any Ollama model string (ollama/ or ollama_chat/ prefix)."""
|
||||
return model.startswith("ollama/") or model.startswith("ollama_chat/")
|
||||
|
||||
|
||||
def _ensure_ollama_chat_prefix(model: str) -> str:
|
||||
"""Normalise Ollama model strings to use the ollama_chat/ prefix.
|
||||
|
||||
LiteLLM requires the ``ollama_chat/`` prefix (not ``ollama/``) to enable
|
||||
native function-calling support. With ``ollama/``, LiteLLM falls back to
|
||||
JSON-mode tool calls, which the framework cannot parse as real tool calls.
|
||||
|
||||
See: https://docs.litellm.ai/docs/providers/ollama#example-usage---tool-calling
|
||||
"""
|
||||
if model.startswith("ollama/"):
|
||||
return "ollama_chat/" + model[len("ollama/") :]
|
||||
return model
|
||||
|
||||
|
||||
RATE_LIMIT_MAX_RETRIES = 10
|
||||
RATE_LIMIT_BACKOFF_BASE = 2 # seconds
|
||||
RATE_LIMIT_MAX_DELAY = 120 # seconds - cap to prevent absurd waits
|
||||
MINIMAX_API_BASE = "https://api.minimax.io/v1"
|
||||
OPENROUTER_API_BASE = "https://openrouter.ai/api/v1"
|
||||
|
||||
# Providers that accept cache_control on message content blocks.
|
||||
# Anthropic: native ephemeral caching. MiniMax & Z-AI/GLM: pass-through to their APIs.
|
||||
@@ -159,10 +208,69 @@ def _model_supports_cache_control(model: str) -> bool:
|
||||
# enforces a coding-agent whitelist that blocks unknown User-Agents.
|
||||
KIMI_API_BASE = "https://api.kimi.com/coding"
|
||||
|
||||
# Claude Code OAuth subscription: the Anthropic API requires a specific
|
||||
# User-Agent and a billing integrity header for OAuth-authenticated requests.
|
||||
CLAUDE_CODE_VERSION = "2.1.76"
|
||||
CLAUDE_CODE_USER_AGENT = f"claude-code/{CLAUDE_CODE_VERSION}"
|
||||
_CLAUDE_CODE_BILLING_SALT = "59cf53e54c78"
|
||||
|
||||
|
||||
def _sample_js_code_unit(text: str, idx: int) -> str:
|
||||
"""Return the character at UTF-16 code unit index *idx*, matching JS semantics."""
|
||||
encoded = text.encode("utf-16-le")
|
||||
unit_offset = idx * 2
|
||||
if unit_offset + 2 > len(encoded):
|
||||
return "0"
|
||||
code_unit = int.from_bytes(encoded[unit_offset : unit_offset + 2], "little")
|
||||
return chr(code_unit)
|
||||
|
||||
|
||||
def _claude_code_billing_header(messages: list[dict[str, Any]]) -> str:
|
||||
"""Build the billing integrity system block required by Anthropic's OAuth path."""
|
||||
# Find the first user message text
|
||||
first_text = ""
|
||||
for msg in messages:
|
||||
if msg.get("role") != "user":
|
||||
continue
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
first_text = content
|
||||
break
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text" and block.get("text"):
|
||||
first_text = block["text"]
|
||||
break
|
||||
if first_text:
|
||||
break
|
||||
|
||||
sampled = "".join(_sample_js_code_unit(first_text, i) for i in (4, 7, 20))
|
||||
version_hash = hashlib.sha256(
|
||||
f"{_CLAUDE_CODE_BILLING_SALT}{sampled}{CLAUDE_CODE_VERSION}".encode()
|
||||
).hexdigest()
|
||||
entrypoint = os.environ.get("CLAUDE_CODE_ENTRYPOINT", "").strip() or "cli"
|
||||
return (
|
||||
f"x-anthropic-billing-header: cc_version={CLAUDE_CODE_VERSION}.{version_hash[:3]}; "
|
||||
f"cc_entrypoint={entrypoint}; cch=00000;"
|
||||
)
|
||||
|
||||
|
||||
# Empty-stream retries use a short fixed delay, not the rate-limit backoff.
|
||||
# Conversation-structure issues are deterministic — long waits don't help.
|
||||
EMPTY_STREAM_MAX_RETRIES = 3
|
||||
EMPTY_STREAM_RETRY_DELAY = 1.0 # seconds
|
||||
OPENROUTER_TOOL_COMPAT_ERROR_SNIPPETS = (
|
||||
"no endpoints found that support tool use",
|
||||
"no endpoints available that support tool use",
|
||||
"provider routing",
|
||||
)
|
||||
OPENROUTER_TOOL_CALL_RE = re.compile(
|
||||
r"<\|tool_call_start\|>\s*(.*?)\s*<\|tool_call_end\|>",
|
||||
re.DOTALL,
|
||||
)
|
||||
OPENROUTER_TOOL_COMPAT_CACHE_TTL_SECONDS = 3600
|
||||
# OpenRouter routing can change over time, so tool-compat caching must expire.
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE: dict[str, float] = {}
|
||||
|
||||
# Directory for dumping failed requests
|
||||
FAILED_REQUESTS_DIR = Path.home() / ".hive" / "failed_requests"
|
||||
@@ -205,6 +313,24 @@ def _prune_failed_request_dumps(max_files: int = MAX_FAILED_REQUEST_DUMPS) -> No
|
||||
pass # Best-effort — never block the caller
|
||||
|
||||
|
||||
def _remember_openrouter_tool_compat_model(model: str) -> None:
|
||||
"""Cache OpenRouter tool-compat fallback for a bounded time window."""
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE[model] = (
|
||||
time.monotonic() + OPENROUTER_TOOL_COMPAT_CACHE_TTL_SECONDS
|
||||
)
|
||||
|
||||
|
||||
def _is_openrouter_tool_compat_cached(model: str) -> bool:
|
||||
"""Return True when the cached OpenRouter compat entry is still fresh."""
|
||||
expires_at = OPENROUTER_TOOL_COMPAT_MODEL_CACHE.get(model)
|
||||
if expires_at is None:
|
||||
return False
|
||||
if expires_at <= time.monotonic():
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE.pop(model, None)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _dump_failed_request(
|
||||
model: str,
|
||||
kwargs: dict[str, Any],
|
||||
@@ -393,7 +519,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
# Translate kimi/ prefix to anthropic/ so litellm uses the Anthropic
|
||||
# Messages API handler and routes to that endpoint — no special headers needed.
|
||||
_original_model = model
|
||||
if model.lower().startswith("kimi/"):
|
||||
if _is_ollama_model(model):
|
||||
model = _ensure_ollama_chat_prefix(model)
|
||||
elif model.lower().startswith("kimi/"):
|
||||
model = "anthropic/" + model[len("kimi/") :]
|
||||
# Normalise api_base: litellm's Anthropic handler appends /v1/messages,
|
||||
# so the base must be https://api.kimi.com/coding (no /v1 suffix).
|
||||
@@ -408,11 +536,19 @@ class LiteLLMProvider(LLMProvider):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or self._default_api_base_for_model(_original_model)
|
||||
self.extra_kwargs = kwargs
|
||||
# Detect Claude Code OAuth subscription by checking the api_key prefix.
|
||||
self._claude_code_oauth = bool(api_key and api_key.startswith("sk-ant-oat"))
|
||||
if self._claude_code_oauth:
|
||||
# Anthropic requires a specific User-Agent for OAuth requests.
|
||||
eh = self.extra_kwargs.setdefault("extra_headers", {})
|
||||
eh.setdefault("user-agent", CLAUDE_CODE_USER_AGENT)
|
||||
# The Codex ChatGPT backend (chatgpt.com/backend-api/codex) rejects
|
||||
# several standard OpenAI params: max_output_tokens, stream_options.
|
||||
self._codex_backend = bool(
|
||||
self.api_base and "chatgpt.com/backend-api/codex" in self.api_base
|
||||
)
|
||||
# Antigravity routes through a local OpenAI-compatible proxy — no patches needed.
|
||||
self._antigravity = bool(self.api_base and "localhost:8069" in self.api_base)
|
||||
|
||||
if litellm is None:
|
||||
raise ImportError(
|
||||
@@ -431,6 +567,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
model_lower = model.lower()
|
||||
if model_lower.startswith("minimax/") or model_lower.startswith("minimax-"):
|
||||
return MINIMAX_API_BASE
|
||||
if model_lower.startswith("openrouter/"):
|
||||
return OPENROUTER_API_BASE
|
||||
if model_lower.startswith("kimi/"):
|
||||
return KIMI_API_BASE
|
||||
if model_lower.startswith("hive/"):
|
||||
@@ -606,6 +744,10 @@ class LiteLLMProvider(LLMProvider):
|
||||
# Add tools if provided
|
||||
if tools:
|
||||
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
|
||||
if _is_ollama_model(self.model):
|
||||
# Ollama requires explicit tool_choice=auto for function calling
|
||||
# so future readers don't have to guess.
|
||||
kwargs.setdefault("tool_choice", "auto")
|
||||
|
||||
# Add response_format for structured output
|
||||
# LiteLLM passes this through to the underlying provider
|
||||
@@ -773,6 +915,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
return await self._collect_stream_to_response(stream_iter)
|
||||
|
||||
full_messages: list[dict[str, Any]] = []
|
||||
if self._claude_code_oauth:
|
||||
billing = _claude_code_billing_header(messages)
|
||||
full_messages.append({"role": "system", "content": billing})
|
||||
if system:
|
||||
sys_msg: dict[str, Any] = {"role": "system", "content": system}
|
||||
if _model_supports_cache_control(self.model):
|
||||
@@ -800,6 +945,10 @@ class LiteLLMProvider(LLMProvider):
|
||||
kwargs["api_base"] = self.api_base
|
||||
if tools:
|
||||
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
|
||||
if _is_ollama_model(self.model):
|
||||
# Ollama requires explicit tool_choice=auto for function calling
|
||||
# so future readers don't have to guess.
|
||||
kwargs.setdefault("tool_choice", "auto")
|
||||
if response_format:
|
||||
kwargs["response_format"] = response_format
|
||||
|
||||
@@ -834,11 +983,504 @@ class LiteLLMProvider(LLMProvider):
|
||||
},
|
||||
}
|
||||
|
||||
def _is_anthropic_model(self) -> bool:
|
||||
"""Return True when the configured model targets Anthropic."""
|
||||
model = (self.model or "").lower()
|
||||
return model.startswith("anthropic/") or model.startswith("claude-")
|
||||
|
||||
def _is_minimax_model(self) -> bool:
|
||||
"""Return True when the configured model targets MiniMax."""
|
||||
model = (self.model or "").lower()
|
||||
return model.startswith("minimax/") or model.startswith("minimax-")
|
||||
|
||||
def _is_openrouter_model(self) -> bool:
|
||||
"""Return True when the configured model targets OpenRouter."""
|
||||
model = (self.model or "").lower()
|
||||
if model.startswith("openrouter/"):
|
||||
return True
|
||||
api_base = (self.api_base or "").lower()
|
||||
return "openrouter.ai/api/v1" in api_base
|
||||
|
||||
def _should_use_openrouter_tool_compat(
|
||||
self,
|
||||
error: BaseException,
|
||||
tools: list[Tool] | None,
|
||||
) -> bool:
|
||||
"""Return True when OpenRouter rejects native tool use for the model."""
|
||||
if not tools or not self._is_openrouter_model():
|
||||
return False
|
||||
error_text = str(error).lower()
|
||||
return "openrouter" in error_text and any(
|
||||
snippet in error_text for snippet in OPENROUTER_TOOL_COMPAT_ERROR_SNIPPETS
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_object(text: str) -> dict[str, Any] | None:
|
||||
"""Extract the first JSON object from a model response."""
|
||||
candidates = [text.strip()]
|
||||
|
||||
stripped = text.strip()
|
||||
if stripped.startswith("```"):
|
||||
fence_lines = stripped.splitlines()
|
||||
if len(fence_lines) >= 3:
|
||||
candidates.append("\n".join(fence_lines[1:-1]).strip())
|
||||
|
||||
decoder = json.JSONDecoder()
|
||||
for candidate in candidates:
|
||||
if not candidate:
|
||||
continue
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
except json.JSONDecodeError:
|
||||
parsed = None
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
|
||||
for start_idx, char in enumerate(candidate):
|
||||
if char != "{":
|
||||
continue
|
||||
try:
|
||||
parsed, _ = decoder.raw_decode(candidate[start_idx:])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return None
|
||||
|
||||
def _parse_openrouter_tool_compat_response(
|
||||
self,
|
||||
content: str,
|
||||
tools: list[Tool],
|
||||
) -> tuple[str, list[dict[str, Any]]]:
|
||||
"""Parse JSON tool-compat output into assistant text and tool calls."""
|
||||
payload = self._extract_json_object(content)
|
||||
if payload is None:
|
||||
text_tool_content, text_tool_calls = self._parse_openrouter_text_tool_calls(
|
||||
content,
|
||||
tools,
|
||||
)
|
||||
if text_tool_calls:
|
||||
logger.info(
|
||||
"[openrouter-tool-compat] Parsed textual tool-call markers for %s",
|
||||
self.model,
|
||||
)
|
||||
return text_tool_content, text_tool_calls
|
||||
logger.info(
|
||||
"[openrouter-tool-compat] %s returned non-JSON fallback content; "
|
||||
"treating it as plain text.",
|
||||
self.model,
|
||||
)
|
||||
return content.strip(), []
|
||||
|
||||
assistant_text = payload.get("assistant_response")
|
||||
if not isinstance(assistant_text, str):
|
||||
assistant_text = payload.get("content")
|
||||
if not isinstance(assistant_text, str):
|
||||
assistant_text = payload.get("response")
|
||||
if not isinstance(assistant_text, str):
|
||||
assistant_text = ""
|
||||
|
||||
tool_calls_raw = payload.get("tool_calls")
|
||||
if not tool_calls_raw and {"name", "arguments"} <= payload.keys():
|
||||
tool_calls_raw = [payload]
|
||||
elif isinstance(payload.get("tool_call"), dict):
|
||||
tool_calls_raw = [payload["tool_call"]]
|
||||
|
||||
if not isinstance(tool_calls_raw, list):
|
||||
tool_calls_raw = []
|
||||
|
||||
allowed_tool_names = {tool.name for tool in tools}
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
compat_prefix = f"openrouter_compat_{time.time_ns()}"
|
||||
|
||||
for idx, raw_call in enumerate(tool_calls_raw):
|
||||
if not isinstance(raw_call, dict):
|
||||
continue
|
||||
|
||||
function_block = raw_call.get("function")
|
||||
function_name = (
|
||||
raw_call.get("name")
|
||||
or raw_call.get("tool_name")
|
||||
or (function_block.get("name") if isinstance(function_block, dict) else None)
|
||||
)
|
||||
if not isinstance(function_name, str) or function_name not in allowed_tool_names:
|
||||
if function_name:
|
||||
logger.warning(
|
||||
"[openrouter-tool-compat] Ignoring unknown tool '%s' for model %s",
|
||||
function_name,
|
||||
self.model,
|
||||
)
|
||||
continue
|
||||
|
||||
arguments = raw_call.get("arguments")
|
||||
if arguments is None:
|
||||
arguments = raw_call.get("tool_input")
|
||||
if arguments is None:
|
||||
arguments = raw_call.get("input")
|
||||
if arguments is None and isinstance(function_block, dict):
|
||||
arguments = function_block.get("arguments")
|
||||
if arguments is None:
|
||||
arguments = {}
|
||||
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {"_raw": arguments}
|
||||
elif not isinstance(arguments, dict):
|
||||
arguments = {"value": arguments}
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"{compat_prefix}_{idx}",
|
||||
"name": function_name,
|
||||
"input": arguments,
|
||||
}
|
||||
)
|
||||
|
||||
return assistant_text.strip(), tool_calls
|
||||
|
||||
@staticmethod
|
||||
def _close_truncated_json_fragment(fragment: str) -> str:
|
||||
"""Close a truncated JSON fragment by balancing quotes/brackets."""
|
||||
stack: list[str] = []
|
||||
in_string = False
|
||||
escaped = False
|
||||
normalized = fragment.rstrip()
|
||||
|
||||
while normalized and normalized[-1] in ",:{[":
|
||||
normalized = normalized[:-1].rstrip()
|
||||
|
||||
for char in normalized:
|
||||
if in_string:
|
||||
if escaped:
|
||||
escaped = False
|
||||
elif char == "\\":
|
||||
escaped = True
|
||||
elif char == '"':
|
||||
in_string = False
|
||||
continue
|
||||
|
||||
if char == '"':
|
||||
in_string = True
|
||||
elif char in "{[":
|
||||
stack.append(char)
|
||||
elif char == "}" and stack and stack[-1] == "{":
|
||||
stack.pop()
|
||||
elif char == "]" and stack and stack[-1] == "[":
|
||||
stack.pop()
|
||||
|
||||
if in_string:
|
||||
if escaped:
|
||||
normalized = normalized[:-1]
|
||||
normalized += '"'
|
||||
|
||||
for opener in reversed(stack):
|
||||
normalized += "}" if opener == "{" else "]"
|
||||
|
||||
return normalized
|
||||
|
||||
def _repair_truncated_tool_arguments(self, raw_arguments: str) -> dict[str, Any] | None:
|
||||
"""Try to recover a truncated JSON object from tool-call arguments."""
|
||||
stripped = raw_arguments.strip()
|
||||
if not stripped or stripped[0] != "{":
|
||||
return None
|
||||
|
||||
max_trim = min(len(stripped), 256)
|
||||
for trim in range(max_trim + 1):
|
||||
candidate = stripped[: len(stripped) - trim].rstrip()
|
||||
if not candidate:
|
||||
break
|
||||
candidate = self._close_truncated_json_fragment(candidate)
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return None
|
||||
|
||||
def _parse_tool_call_arguments(self, raw_arguments: str, tool_name: str) -> dict[str, Any]:
|
||||
"""Parse streamed tool arguments, repairing truncation when possible."""
|
||||
try:
|
||||
parsed = json.loads(raw_arguments) if raw_arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
parsed = None
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
|
||||
repaired = self._repair_truncated_tool_arguments(raw_arguments)
|
||||
if repaired is not None:
|
||||
logger.warning(
|
||||
"[tool-args] Recovered truncated arguments for %s on %s",
|
||||
tool_name,
|
||||
self.model,
|
||||
)
|
||||
return repaired
|
||||
|
||||
raise ValueError(
|
||||
f"Failed to parse tool call arguments for '{tool_name}' (likely truncated JSON)."
|
||||
)
|
||||
|
||||
def _parse_openrouter_text_tool_calls(
|
||||
self,
|
||||
content: str,
|
||||
tools: list[Tool],
|
||||
) -> tuple[str, list[dict[str, Any]]]:
|
||||
"""Parse textual OpenRouter tool calls into synthetic tool calls.
|
||||
|
||||
Supports both:
|
||||
- Marker wrapped payloads: <|tool_call_start|>...<|tool_call_end|>
|
||||
- Plain one-line tool calls: ask_user("...", ["..."])
|
||||
"""
|
||||
tools_by_name = {tool.name: tool for tool in tools}
|
||||
compat_prefix = f"openrouter_compat_{time.time_ns()}"
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
segment_index = 0
|
||||
|
||||
for match in OPENROUTER_TOOL_CALL_RE.finditer(content):
|
||||
parsed_calls = self._parse_openrouter_text_tool_call_block(
|
||||
block=match.group(1),
|
||||
tools_by_name=tools_by_name,
|
||||
compat_prefix=f"{compat_prefix}_{segment_index}",
|
||||
)
|
||||
if parsed_calls:
|
||||
segment_index += 1
|
||||
tool_calls.extend(parsed_calls)
|
||||
|
||||
stripped_content = OPENROUTER_TOOL_CALL_RE.sub("", content)
|
||||
retained_lines: list[str] = []
|
||||
for line in stripped_content.splitlines():
|
||||
stripped_line = line.strip()
|
||||
if not stripped_line:
|
||||
retained_lines.append(line)
|
||||
continue
|
||||
|
||||
candidate = stripped_line
|
||||
if candidate.startswith("`") and candidate.endswith("`") and len(candidate) > 1:
|
||||
candidate = candidate[1:-1].strip()
|
||||
|
||||
parsed_calls = self._parse_openrouter_text_tool_call_block(
|
||||
block=candidate,
|
||||
tools_by_name=tools_by_name,
|
||||
compat_prefix=f"{compat_prefix}_{segment_index}",
|
||||
)
|
||||
if parsed_calls:
|
||||
segment_index += 1
|
||||
tool_calls.extend(parsed_calls)
|
||||
continue
|
||||
|
||||
retained_lines.append(line)
|
||||
|
||||
stripped_text = "\n".join(retained_lines).strip()
|
||||
return stripped_text, tool_calls
|
||||
|
||||
def _parse_openrouter_text_tool_call_block(
|
||||
self,
|
||||
block: str,
|
||||
tools_by_name: dict[str, Tool],
|
||||
compat_prefix: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Parse a single textual tool-call block like [tool(arg='x')]."""
|
||||
try:
|
||||
parsed = ast.parse(block.strip(), mode="eval").body
|
||||
except SyntaxError:
|
||||
return []
|
||||
|
||||
call_nodes = parsed.elts if isinstance(parsed, ast.List) else [parsed]
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
for call_index, call_node in enumerate(call_nodes):
|
||||
if not isinstance(call_node, ast.Call) or not isinstance(call_node.func, ast.Name):
|
||||
continue
|
||||
|
||||
tool_name = call_node.func.id
|
||||
tool = tools_by_name.get(tool_name)
|
||||
if tool is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
tool_input = self._parse_openrouter_text_tool_call_arguments(
|
||||
call_node=call_node,
|
||||
tool=tool,
|
||||
)
|
||||
except (ValueError, SyntaxError):
|
||||
continue
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"{compat_prefix}_{call_index}",
|
||||
"name": tool_name,
|
||||
"input": tool_input,
|
||||
}
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
@staticmethod
|
||||
def _parse_openrouter_text_tool_call_arguments(
|
||||
call_node: ast.Call,
|
||||
tool: Tool,
|
||||
) -> dict[str, Any]:
|
||||
"""Parse positional/keyword args from a textual tool call."""
|
||||
properties = tool.parameters.get("properties", {})
|
||||
positional_keys = list(properties.keys())
|
||||
tool_input: dict[str, Any] = {}
|
||||
|
||||
if len(call_node.args) > len(positional_keys):
|
||||
raise ValueError("Too many positional args for textual tool call")
|
||||
|
||||
for idx, arg_node in enumerate(call_node.args):
|
||||
tool_input[positional_keys[idx]] = ast.literal_eval(arg_node)
|
||||
|
||||
for kwarg in call_node.keywords:
|
||||
if kwarg.arg is None:
|
||||
raise ValueError("Star args are not supported in textual tool calls")
|
||||
tool_input[kwarg.arg] = ast.literal_eval(kwarg.value)
|
||||
|
||||
return tool_input
|
||||
|
||||
def _build_openrouter_tool_compat_messages(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build a JSON-only prompt for models without native tool support."""
|
||||
tool_specs = [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
compat_instruction = (
|
||||
"Tool compatibility mode is active because this OpenRouter model does not support "
|
||||
"native function calling on the routed provider.\n"
|
||||
"Return exactly one JSON object and nothing else.\n"
|
||||
'Schema: {"assistant_response": string, '
|
||||
'"tool_calls": [{"name": string, "arguments": object}]}\n'
|
||||
"Rules:\n"
|
||||
"- If a tool is required, put one or more entries in tool_calls "
|
||||
"and do not invent tool results.\n"
|
||||
"- If no tool is required, set tool_calls to [] and put the full "
|
||||
"answer in assistant_response.\n"
|
||||
"- Only use tool names from the allowed tool list.\n"
|
||||
"- arguments must always be valid JSON objects.\n"
|
||||
f"Allowed tools:\n{json.dumps(tool_specs, ensure_ascii=True)}"
|
||||
)
|
||||
compat_system = compat_instruction if not system else f"{system}\n\n{compat_instruction}"
|
||||
|
||||
full_messages: list[dict[str, Any]] = [{"role": "system", "content": compat_system}]
|
||||
full_messages.extend(messages)
|
||||
return [
|
||||
message
|
||||
for message in full_messages
|
||||
if not (
|
||||
message.get("role") == "assistant"
|
||||
and not message.get("content")
|
||||
and not message.get("tool_calls")
|
||||
)
|
||||
]
|
||||
|
||||
async def _acomplete_via_openrouter_tool_compat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
max_tokens: int,
|
||||
) -> LLMResponse:
|
||||
"""Emulate tool calling via JSON when OpenRouter rejects native tools."""
|
||||
full_messages = self._build_openrouter_tool_compat_messages(messages, system, tools)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": full_messages,
|
||||
"max_tokens": max_tokens,
|
||||
**self.extra_kwargs,
|
||||
}
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
|
||||
response = await self._acompletion_with_rate_limit_retry(**kwargs)
|
||||
raw_content = response.choices[0].message.content or ""
|
||||
assistant_text, tool_calls = self._parse_openrouter_tool_compat_response(
|
||||
raw_content,
|
||||
tools,
|
||||
)
|
||||
usage = response.usage
|
||||
input_tokens = usage.prompt_tokens if usage else 0
|
||||
output_tokens = usage.completion_tokens if usage else 0
|
||||
stop_reason = "tool_calls" if tool_calls else (response.choices[0].finish_reason or "stop")
|
||||
|
||||
return LLMResponse(
|
||||
content=assistant_text,
|
||||
model=response.model or self.model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
stop_reason=stop_reason,
|
||||
raw_response={
|
||||
"compat_mode": "openrouter_tool_emulation",
|
||||
"tool_calls": tool_calls,
|
||||
"response": response,
|
||||
},
|
||||
)
|
||||
|
||||
async def _stream_via_openrouter_tool_compat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
max_tokens: int,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Fallback stream for OpenRouter models without native tool support."""
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[openrouter-tool-compat] Using compatibility mode for %s",
|
||||
self.model,
|
||||
)
|
||||
try:
|
||||
response = await self._acomplete_via_openrouter_tool_compat(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamErrorEvent(error=str(e), recoverable=False)
|
||||
return
|
||||
|
||||
raw_response = response.raw_response if isinstance(response.raw_response, dict) else {}
|
||||
tool_calls = raw_response.get("tool_calls", [])
|
||||
|
||||
if response.content:
|
||||
yield TextDeltaEvent(content=response.content, snapshot=response.content)
|
||||
yield TextEndEvent(full_text=response.content)
|
||||
|
||||
for tool_call in tool_calls:
|
||||
yield ToolCallEvent(
|
||||
tool_use_id=tool_call["id"],
|
||||
tool_name=tool_call["name"],
|
||||
tool_input=tool_call["input"],
|
||||
)
|
||||
|
||||
yield FinishEvent(
|
||||
stop_reason=response.stop_reason,
|
||||
input_tokens=response.input_tokens,
|
||||
output_tokens=response.output_tokens,
|
||||
model=response.model,
|
||||
)
|
||||
|
||||
async def _stream_via_nonstream_completion(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
@@ -882,12 +1524,11 @@ class LiteLLMProvider(LLMProvider):
|
||||
tool_calls = msg.tool_calls or []
|
||||
|
||||
for tc in tool_calls:
|
||||
parsed_args: Any
|
||||
args = tc.function.arguments if tc.function else ""
|
||||
try:
|
||||
parsed_args = json.loads(args) if args else {}
|
||||
except json.JSONDecodeError:
|
||||
parsed_args = {"_raw": args}
|
||||
parsed_args = self._parse_tool_call_arguments(
|
||||
args,
|
||||
tc.function.name if tc.function else "",
|
||||
)
|
||||
yield ToolCallEvent(
|
||||
tool_use_id=getattr(tc, "id", ""),
|
||||
tool_name=tc.function.name if tc.function else "",
|
||||
@@ -946,7 +1587,20 @@ class LiteLLMProvider(LLMProvider):
|
||||
yield event
|
||||
return
|
||||
|
||||
if tools and self._is_openrouter_model() and _is_openrouter_tool_compat_cached(self.model):
|
||||
async for event in self._stream_via_openrouter_tool_compat(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
|
||||
full_messages: list[dict[str, Any]] = []
|
||||
if self._claude_code_oauth:
|
||||
billing = _claude_code_billing_header(messages)
|
||||
full_messages.append({"role": "system", "content": billing})
|
||||
if system:
|
||||
sys_msg: dict[str, Any] = {"role": "system", "content": system}
|
||||
if _model_supports_cache_control(self.model):
|
||||
@@ -984,15 +1638,22 @@ class LiteLLMProvider(LLMProvider):
|
||||
"messages": full_messages,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
**self.extra_kwargs,
|
||||
}
|
||||
# stream_options is OpenAI-specific; Anthropic rejects it with 400.
|
||||
# Only include it for providers that support it.
|
||||
if not self._is_anthropic_model():
|
||||
kwargs["stream_options"] = {"include_usage": True}
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
if tools:
|
||||
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
|
||||
if _is_ollama_model(self.model):
|
||||
# Ollama requires explicit tool_choice=auto for function calling
|
||||
# so future readers don't have to guess.
|
||||
kwargs.setdefault("tool_choice", "auto")
|
||||
if response_format:
|
||||
kwargs["response_format"] = response_format
|
||||
# The Codex ChatGPT backend (Responses API) rejects several params.
|
||||
@@ -1092,10 +1753,10 @@ class LiteLLMProvider(LLMProvider):
|
||||
if choice.finish_reason:
|
||||
stream_finish_reason = choice.finish_reason
|
||||
for _idx, tc_data in sorted(tool_calls_acc.items()):
|
||||
try:
|
||||
parsed_args = json.loads(tc_data["arguments"])
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
parsed_args = {"_raw": tc_data.get("arguments", "")}
|
||||
parsed_args = self._parse_tool_call_arguments(
|
||||
tc_data.get("arguments", ""),
|
||||
tc_data.get("name", ""),
|
||||
)
|
||||
tail_events.append(
|
||||
ToolCallEvent(
|
||||
tool_use_id=tc_data["id"],
|
||||
@@ -1276,6 +1937,16 @@ class LiteLLMProvider(LLMProvider):
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
if self._should_use_openrouter_tool_compat(e, tools):
|
||||
_remember_openrouter_tool_compat_model(self.model)
|
||||
async for event in self._stream_via_openrouter_tool_compat(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools or [],
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
if _is_stream_transient_error(e) and attempt < RATE_LIMIT_MAX_RETRIES:
|
||||
wait = _compute_retry_delay(attempt, exception=e)
|
||||
logger.warning(
|
||||
|
||||
@@ -45,6 +45,8 @@ class ToolResult:
|
||||
tool_use_id: str
|
||||
content: str
|
||||
is_error: bool = False
|
||||
image_content: list[dict[str, Any]] | None = None
|
||||
is_skill_content: bool = False # AS-10: marks activated skill body, protected from pruning
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
|
||||
@@ -30,6 +30,8 @@ from typing import Any
|
||||
# ContextVar is thread-safe and async-safe - perfect for concurrent agent execution
|
||||
trace_context: ContextVar[dict[str, Any] | None] = ContextVar("trace_context", default=None)
|
||||
|
||||
_STANDARD_LOG_RECORD_FIELDS = set(logging.makeLogRecord({}).__dict__)
|
||||
|
||||
# ANSI escape code pattern (matches \033[...m or \x1b[...m)
|
||||
ANSI_ESCAPE_PATTERN = re.compile(r"\x1b\[[0-9;]*m|\033\[[0-9;]*m")
|
||||
|
||||
@@ -92,6 +94,14 @@ class StructuredFormatter(logging.Formatter):
|
||||
if model is not None:
|
||||
log_entry["model"] = model
|
||||
|
||||
# Preserve arbitrary structured fields passed via ``extra=...``.
|
||||
for key, value in record.__dict__.items():
|
||||
if key in _STANDARD_LOG_RECORD_FIELDS or key.startswith("_"):
|
||||
continue
|
||||
if key in log_entry:
|
||||
continue
|
||||
log_entry[key] = value
|
||||
|
||||
# Add exception info if present (strip ANSI codes from exception text too)
|
||||
if record.exc_info:
|
||||
exception_text = self.formatException(record.exc_info)
|
||||
@@ -208,7 +218,12 @@ def configure_logging(
|
||||
|
||||
# Suppress noisy LiteLLM INFO logs (model/provider line + Provider List URL
|
||||
# printed on every single completion call). Warnings and errors still show.
|
||||
logging.getLogger("LiteLLM").setLevel(logging.WARNING)
|
||||
# Honour LITELLM_LOG env var so users can opt-in to debug output.
|
||||
_litellm_level = os.getenv("LITELLM_LOG", "").upper()
|
||||
if _litellm_level and hasattr(logging, _litellm_level):
|
||||
logging.getLogger("LiteLLM").setLevel(getattr(logging, _litellm_level))
|
||||
else:
|
||||
logging.getLogger("LiteLLM").setLevel(logging.WARNING)
|
||||
|
||||
# When in JSON mode, configure known third-party loggers to use JSON formatter
|
||||
# This ensures libraries like LiteLLM, httpcore also output clean JSON
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Agent Runner - load and run exported agents."""
|
||||
|
||||
from framework.runner.mcp_registry import MCPRegistry
|
||||
from framework.runner.orchestrator import AgentOrchestrator
|
||||
from framework.runner.protocol import (
|
||||
AgentMessage,
|
||||
@@ -17,6 +18,7 @@ __all__ = [
|
||||
"AgentInfo",
|
||||
"ValidationResult",
|
||||
"ToolRegistry",
|
||||
"MCPRegistry",
|
||||
"tool",
|
||||
# Multi-agent
|
||||
"AgentOrchestrator",
|
||||
|
||||
@@ -1561,6 +1561,22 @@ def _open_browser(url: str) -> None:
|
||||
pass # Best-effort — don't crash if browser can't open
|
||||
|
||||
|
||||
def _format_subprocess_output(output: str | bytes | None, limit: int = 2000) -> str:
|
||||
"""Return subprocess output as trimmed text safe for console logging."""
|
||||
if not output:
|
||||
return ""
|
||||
|
||||
if isinstance(output, bytes):
|
||||
text = output.decode(errors="replace")
|
||||
else:
|
||||
text = output
|
||||
|
||||
text = text.strip()
|
||||
if len(text) <= limit:
|
||||
return text
|
||||
return text[-limit:]
|
||||
|
||||
|
||||
def _build_frontend() -> bool:
|
||||
"""Build the frontend if source is newer than dist. Returns True if dist exists."""
|
||||
import subprocess
|
||||
@@ -1596,18 +1612,25 @@ def _build_frontend() -> bool:
|
||||
|
||||
# Need to build
|
||||
print("Building frontend...")
|
||||
npm_cmd = "npm.cmd" if sys.platform == "win32" else "npm"
|
||||
try:
|
||||
# Incremental tsc caches can drift across branch changes and block builds.
|
||||
for cache_file in frontend_dir.glob("tsconfig*.tsbuildinfo"):
|
||||
cache_file.unlink(missing_ok=True)
|
||||
|
||||
# Ensure deps are installed
|
||||
subprocess.run(
|
||||
["npm", "install", "--no-fund", "--no-audit"],
|
||||
[npm_cmd, "install", "--no-fund", "--no-audit"],
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
cwd=frontend_dir,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
subprocess.run(
|
||||
["npm", "run", "build"],
|
||||
[npm_cmd, "run", "build"],
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
cwd=frontend_dir,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
@@ -1618,8 +1641,14 @@ def _build_frontend() -> bool:
|
||||
print("Node.js not found — skipping frontend build.")
|
||||
return dist_dir.is_dir()
|
||||
except subprocess.CalledProcessError as exc:
|
||||
stderr = exc.stderr.decode(errors="replace") if exc.stderr else ""
|
||||
print(f"Frontend build failed: {stderr[:500]}")
|
||||
stdout = _format_subprocess_output(exc.stdout)
|
||||
stderr = _format_subprocess_output(exc.stderr)
|
||||
cmd = " ".join(exc.cmd) if isinstance(exc.cmd, (list, tuple)) else str(exc.cmd)
|
||||
details = "\n".join(part for part in [stdout, stderr] if part).strip()
|
||||
if details:
|
||||
print(f"Frontend build failed while running {cmd}:\n{details}")
|
||||
else:
|
||||
print(f"Frontend build failed while running {cmd} (exit {exc.returncode}).")
|
||||
return dist_dir.is_dir()
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""MCP Client for connecting to Model Context Protocol servers.
|
||||
|
||||
This module provides a client for connecting to MCP servers and invoking their tools.
|
||||
Supports both STDIO and HTTP transports using the official MCP Python SDK.
|
||||
Supports STDIO, HTTP, UNIX socket, and SSE transports using the official MCP Python SDK.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -14,6 +14,8 @@ from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
|
||||
from framework.runner.mcp_errors import MCPToolNotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -22,7 +24,7 @@ class MCPServerConfig:
|
||||
"""Configuration for an MCP server connection."""
|
||||
|
||||
name: str
|
||||
transport: Literal["stdio", "http"]
|
||||
transport: Literal["stdio", "http", "unix", "sse"]
|
||||
|
||||
# For STDIO transport
|
||||
command: str | None = None
|
||||
@@ -33,6 +35,7 @@ class MCPServerConfig:
|
||||
# For HTTP transport
|
||||
url: str | None = None
|
||||
headers: dict[str, str] = field(default_factory=dict)
|
||||
socket_path: str | None = None
|
||||
|
||||
# Optional metadata
|
||||
description: str = ""
|
||||
@@ -52,7 +55,7 @@ class MCPClient:
|
||||
"""
|
||||
Client for communicating with MCP servers.
|
||||
|
||||
Supports both STDIO and HTTP transports using the official MCP SDK.
|
||||
Supports STDIO, HTTP, UNIX socket, and SSE transports using the official MCP SDK.
|
||||
Manages the connection lifecycle and provides methods to list and invoke tools.
|
||||
"""
|
||||
|
||||
@@ -68,6 +71,7 @@ class MCPClient:
|
||||
self._read_stream = None
|
||||
self._write_stream = None
|
||||
self._stdio_context = None # Context manager for stdio_client
|
||||
self._sse_context = None # Context manager for sse_client
|
||||
self._errlog_handle = None # Track errlog file handle for cleanup
|
||||
self._http_client: httpx.Client | None = None
|
||||
self._tools: dict[str, MCPTool] = {}
|
||||
@@ -141,6 +145,10 @@ class MCPClient:
|
||||
self._connect_stdio()
|
||||
elif self.config.transport == "http":
|
||||
self._connect_http()
|
||||
elif self.config.transport == "unix":
|
||||
self._connect_unix()
|
||||
elif self.config.transport == "sse":
|
||||
self._connect_sse()
|
||||
else:
|
||||
raise ValueError(f"Unsupported transport: {self.config.transport}")
|
||||
|
||||
@@ -266,10 +274,94 @@ class MCPClient:
|
||||
logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}")
|
||||
# Continue anyway, server might not have health endpoint
|
||||
|
||||
def _connect_unix(self) -> None:
|
||||
"""Connect to MCP server via UNIX domain socket transport."""
|
||||
if not self.config.url:
|
||||
raise ValueError("url is required for UNIX transport")
|
||||
if not self.config.socket_path:
|
||||
raise ValueError("socket_path is required for UNIX transport")
|
||||
|
||||
self._http_client = httpx.Client(
|
||||
base_url=self.config.url,
|
||||
headers=self.config.headers,
|
||||
timeout=30.0,
|
||||
transport=httpx.HTTPTransport(uds=self.config.socket_path),
|
||||
)
|
||||
|
||||
try:
|
||||
response = self._http_client.get("/health")
|
||||
response.raise_for_status()
|
||||
logger.info(
|
||||
"Connected to MCP server '%s' via UNIX socket at %s",
|
||||
self.config.name,
|
||||
self.config.socket_path,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}")
|
||||
# Continue anyway, server might not have health endpoint
|
||||
|
||||
def _connect_sse(self) -> None:
|
||||
"""Connect to MCP server via SSE transport using MCP SDK with persistent session."""
|
||||
if not self.config.url:
|
||||
raise ValueError("url is required for SSE transport")
|
||||
|
||||
try:
|
||||
loop_started = threading.Event()
|
||||
connection_ready = threading.Event()
|
||||
connection_error = []
|
||||
|
||||
def run_event_loop():
|
||||
"""Run event loop in background thread."""
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
loop_started.set()
|
||||
|
||||
async def init_connection():
|
||||
try:
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
self._sse_context = sse_client(
|
||||
self.config.url,
|
||||
headers=self.config.headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
(
|
||||
self._read_stream,
|
||||
self._write_stream,
|
||||
) = await self._sse_context.__aenter__()
|
||||
|
||||
self._session = ClientSession(self._read_stream, self._write_stream)
|
||||
await self._session.__aenter__()
|
||||
await self._session.initialize()
|
||||
|
||||
connection_ready.set()
|
||||
except Exception as e:
|
||||
connection_error.append(e)
|
||||
connection_ready.set()
|
||||
|
||||
self._loop.create_task(init_connection())
|
||||
self._loop.run_forever()
|
||||
|
||||
self._loop_thread = threading.Thread(target=run_event_loop, daemon=True)
|
||||
self._loop_thread.start()
|
||||
|
||||
loop_started.wait(timeout=5)
|
||||
if not loop_started.is_set():
|
||||
raise RuntimeError("Event loop failed to start")
|
||||
|
||||
connection_ready.wait(timeout=10)
|
||||
if connection_error:
|
||||
raise connection_error[0]
|
||||
|
||||
logger.info(f"Connected to MCP server '{self.config.name}' via SSE")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to connect to MCP server: {e}") from e
|
||||
|
||||
def _discover_tools(self) -> None:
|
||||
"""Discover available tools from the MCP server."""
|
||||
try:
|
||||
if self.config.transport == "stdio":
|
||||
if self.config.transport in {"stdio", "sse"}:
|
||||
tools_list = self._run_async(self._list_tools_stdio_async())
|
||||
else:
|
||||
tools_list = self._list_tools_http()
|
||||
@@ -366,14 +458,45 @@ class MCPClient:
|
||||
self.connect()
|
||||
|
||||
if tool_name not in self._tools:
|
||||
raise ValueError(f"Unknown tool: {tool_name}")
|
||||
raise MCPToolNotFoundError(
|
||||
server=self.config.name,
|
||||
tool_name=tool_name,
|
||||
)
|
||||
|
||||
if self.config.transport == "stdio":
|
||||
with self._stdio_call_lock:
|
||||
return self._run_async(self._call_tool_stdio_async(tool_name, arguments))
|
||||
elif self.config.transport == "sse":
|
||||
return self._call_tool_with_retry(
|
||||
lambda: self._run_async(self._call_tool_stdio_async(tool_name, arguments))
|
||||
)
|
||||
elif self.config.transport == "unix":
|
||||
return self._call_tool_with_retry(lambda: self._call_tool_http(tool_name, arguments))
|
||||
else:
|
||||
return self._call_tool_http(tool_name, arguments)
|
||||
|
||||
def _call_tool_with_retry(self, call: Any) -> Any:
|
||||
"""Retry transient MCP transport failures once after reconnecting."""
|
||||
if self.config.transport == "stdio":
|
||||
return call()
|
||||
|
||||
if self.config.transport not in {"unix", "sse"}:
|
||||
return call()
|
||||
|
||||
try:
|
||||
return call()
|
||||
except (httpx.ConnectError, httpx.ReadTimeout) as original_error:
|
||||
logger.warning(
|
||||
"Retrying MCP tool call after transport error from '%s': %s",
|
||||
self.config.name,
|
||||
original_error,
|
||||
)
|
||||
self._reconnect()
|
||||
try:
|
||||
return call()
|
||||
except (httpx.ConnectError, httpx.ReadTimeout) as retry_error:
|
||||
raise original_error from retry_error
|
||||
|
||||
async def _call_tool_stdio_async(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
"""Call tool via STDIO protocol using persistent session."""
|
||||
if not self._session:
|
||||
@@ -389,19 +512,35 @@ class MCPClient:
|
||||
content_item = result.content[0]
|
||||
if hasattr(content_item, "text"):
|
||||
error_text = content_item.text
|
||||
raise RuntimeError(f"MCP tool '{tool_name}' failed: {error_text}")
|
||||
raise RuntimeError(
|
||||
f"[Server: {self.config.name}] [Transport: {self.config.transport}] "
|
||||
f"Tool '{tool_name}' failed: {error_text}"
|
||||
)
|
||||
|
||||
# Extract content
|
||||
# Extract content — preserve image blocks alongside text
|
||||
if result.content:
|
||||
# MCP returns content as a list of content items
|
||||
if len(result.content) > 0:
|
||||
content_item = result.content[0]
|
||||
# Check if it's a text content item
|
||||
if hasattr(content_item, "text"):
|
||||
return content_item.text
|
||||
elif hasattr(content_item, "data"):
|
||||
return content_item.data
|
||||
return result.content
|
||||
text_parts: list[str] = []
|
||||
image_parts: list[dict[str, Any]] = []
|
||||
for item in result.content:
|
||||
if hasattr(item, "text"):
|
||||
text_parts.append(item.text)
|
||||
elif hasattr(item, "data") and hasattr(item, "mimeType"):
|
||||
# MCP ImageContent — preserve as structured image block
|
||||
image_parts.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{item.mimeType};base64,{item.data}",
|
||||
},
|
||||
}
|
||||
)
|
||||
elif hasattr(item, "data"):
|
||||
text_parts.append(str(item.data))
|
||||
|
||||
text = "\n".join(text_parts) if text_parts else ""
|
||||
if image_parts:
|
||||
return {"_text": text, "_images": image_parts}
|
||||
return text if text else None
|
||||
|
||||
return None
|
||||
|
||||
@@ -427,24 +566,36 @@ class MCPClient:
|
||||
data = response.json()
|
||||
|
||||
if "error" in data:
|
||||
raise RuntimeError(f"Tool execution error: {data['error']}")
|
||||
raise RuntimeError(
|
||||
f"[Server: {self.config.name}] [Transport: {self.config.transport}] "
|
||||
f"Tool '{tool_name}' failed: {data['error']}"
|
||||
)
|
||||
|
||||
return data.get("result", {}).get("content", [])
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to call tool via HTTP: {e}") from e
|
||||
raise RuntimeError(
|
||||
f"[Server: {self.config.name}] [Transport: {self.config.transport}] "
|
||||
f"Failed to call tool via HTTP: Tool '{tool_name}' failed: {e}"
|
||||
) from e
|
||||
|
||||
def _reconnect(self) -> None:
|
||||
"""Reconnect to the configured MCP server."""
|
||||
logger.info(f"Reconnecting to MCP server '{self.config.name}'...")
|
||||
self.disconnect()
|
||||
self.connect()
|
||||
|
||||
_CLEANUP_TIMEOUT = 10
|
||||
_THREAD_JOIN_TIMEOUT = 12
|
||||
|
||||
async def _cleanup_stdio_async(self) -> None:
|
||||
"""Async cleanup for STDIO session and context managers.
|
||||
"""Async cleanup for persistent MCP session and context managers.
|
||||
|
||||
Cleanup order is critical:
|
||||
- The session must be closed BEFORE the stdio_context because the session
|
||||
depends on the streams provided by stdio_context.
|
||||
- This mirrors the initialization order in _connect_stdio(), where
|
||||
stdio_context is entered first (providing streams), then the session is
|
||||
created with those streams and entered.
|
||||
- The session must be closed BEFORE the transport context manager because the
|
||||
session depends on the streams provided by that context.
|
||||
- This mirrors the initialization order in _connect_stdio() / _connect_sse(),
|
||||
where the transport context is entered first (providing streams), then the
|
||||
session is created with those streams and entered.
|
||||
- Do not change this ordering without carefully considering these dependencies.
|
||||
"""
|
||||
# First: close session (depends on stdio_context streams)
|
||||
@@ -477,6 +628,16 @@ class MCPClient:
|
||||
finally:
|
||||
self._stdio_context = None
|
||||
|
||||
try:
|
||||
if self._sse_context:
|
||||
await self._sse_context.__aexit__(None, None, None)
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("SSE context cleanup was cancelled; proceeding with best-effort shutdown")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing SSE context: {e}")
|
||||
finally:
|
||||
self._sse_context = None
|
||||
|
||||
# Third: close errlog file handle if we opened one
|
||||
if self._errlog_handle is not None:
|
||||
try:
|
||||
@@ -552,6 +713,7 @@ class MCPClient:
|
||||
# Setting None to None is safe and ensures clean state.
|
||||
self._session = None
|
||||
self._stdio_context = None
|
||||
self._sse_context = None
|
||||
self._read_stream = None
|
||||
self._write_stream = None
|
||||
self._loop = None
|
||||
|
||||
@@ -0,0 +1,409 @@
|
||||
"""Shared MCP client connection management."""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
|
||||
import httpx
|
||||
|
||||
from framework.runner.mcp_client import MCPClient, MCPServerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TRANSITION_TIMEOUT = 30.0
|
||||
|
||||
|
||||
class MCPConnectionManager:
|
||||
"""Process-wide MCP client pool keyed by server name."""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pool: dict[str, MCPClient] = {}
|
||||
self._refcounts: dict[str, int] = {}
|
||||
self._configs: dict[str, MCPServerConfig] = {}
|
||||
self._pool_lock = threading.Lock()
|
||||
self._transitions: dict[str, threading.Event] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "MCPConnectionManager":
|
||||
"""Return the process-level singleton instance."""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
@staticmethod
|
||||
def _is_connected(client: MCPClient | None) -> bool:
|
||||
return bool(client and getattr(client, "_connected", False))
|
||||
|
||||
def has_connection(self, server_name: str) -> bool:
|
||||
"""Return True when a live pooled connection exists for ``server_name``."""
|
||||
with self._pool_lock:
|
||||
return self._is_connected(self._pool.get(server_name))
|
||||
|
||||
def acquire(self, config: MCPServerConfig) -> MCPClient:
|
||||
"""Get or create a shared connection and increment its refcount."""
|
||||
server_name = config.name
|
||||
|
||||
while True:
|
||||
should_connect = False
|
||||
transition_event: threading.Event | None = None
|
||||
|
||||
with self._pool_lock:
|
||||
client = self._pool.get(server_name)
|
||||
if self._is_connected(client) and server_name not in self._transitions:
|
||||
new_refcount = self._refcounts.get(server_name, 0) + 1
|
||||
self._refcounts[server_name] = new_refcount
|
||||
self._configs[server_name] = config
|
||||
logger.debug(
|
||||
"Reusing pooled connection for MCP server '%s' (refcount=%d)",
|
||||
server_name,
|
||||
new_refcount,
|
||||
)
|
||||
return client
|
||||
|
||||
transition_event = self._transitions.get(server_name)
|
||||
if transition_event is None:
|
||||
transition_event = threading.Event()
|
||||
self._transitions[server_name] = transition_event
|
||||
self._configs[server_name] = config
|
||||
should_connect = True
|
||||
|
||||
if not should_connect:
|
||||
if not transition_event.wait(timeout=_TRANSITION_TIMEOUT):
|
||||
logger.warning(
|
||||
"Timed out waiting for transition on MCP server '%s', "
|
||||
"forcing cleanup and retrying",
|
||||
server_name,
|
||||
)
|
||||
with self._pool_lock:
|
||||
stuck = self._transitions.get(server_name)
|
||||
if stuck is transition_event:
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
continue
|
||||
|
||||
logger.info("Connecting to MCP server '%s'", server_name)
|
||||
client = MCPClient(config)
|
||||
try:
|
||||
client.connect()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to connect to MCP server '%s'",
|
||||
server_name,
|
||||
exc_info=True,
|
||||
)
|
||||
with self._pool_lock:
|
||||
current = self._transitions.get(server_name)
|
||||
if current is transition_event:
|
||||
self._transitions.pop(server_name, None)
|
||||
if (
|
||||
server_name not in self._pool
|
||||
and self._refcounts.get(server_name, 0) <= 0
|
||||
):
|
||||
self._configs.pop(server_name, None)
|
||||
transition_event.set()
|
||||
raise
|
||||
|
||||
with self._pool_lock:
|
||||
current = self._transitions.get(server_name)
|
||||
if current is transition_event:
|
||||
self._pool[server_name] = client
|
||||
self._refcounts[server_name] = self._refcounts.get(server_name, 0) + 1
|
||||
self._configs[server_name] = config
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
logger.info(
|
||||
"Connected to MCP server '%s' (refcount=1)",
|
||||
server_name,
|
||||
)
|
||||
return client
|
||||
|
||||
# Lost the transition race, clean up and retry
|
||||
try:
|
||||
client.disconnect()
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Error disconnecting stale client for '%s'",
|
||||
server_name,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def release(self, server_name: str) -> None:
|
||||
"""Decrement refcount and disconnect when the last user releases."""
|
||||
while True:
|
||||
disconnect_client: MCPClient | None = None
|
||||
transition_event: threading.Event | None = None
|
||||
should_disconnect = False
|
||||
|
||||
with self._pool_lock:
|
||||
transition_event = self._transitions.get(server_name)
|
||||
if transition_event is None:
|
||||
refcount = self._refcounts.get(server_name, 0)
|
||||
if refcount <= 0:
|
||||
return
|
||||
if refcount > 1:
|
||||
self._refcounts[server_name] = refcount - 1
|
||||
logger.debug(
|
||||
"Released MCP server '%s' (refcount=%d)",
|
||||
server_name,
|
||||
refcount - 1,
|
||||
)
|
||||
return
|
||||
|
||||
disconnect_client = self._pool.pop(server_name, None)
|
||||
self._refcounts.pop(server_name, None)
|
||||
self._configs.pop(server_name, None)
|
||||
transition_event = threading.Event()
|
||||
self._transitions[server_name] = transition_event
|
||||
should_disconnect = True
|
||||
|
||||
if not should_disconnect:
|
||||
if not transition_event.wait(timeout=_TRANSITION_TIMEOUT):
|
||||
logger.warning(
|
||||
"Timed out waiting for transition on '%s' during release, forcing cleanup",
|
||||
server_name,
|
||||
)
|
||||
with self._pool_lock:
|
||||
stuck = self._transitions.get(server_name)
|
||||
if stuck is transition_event:
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
continue
|
||||
|
||||
try:
|
||||
if disconnect_client is not None:
|
||||
disconnect_client.disconnect()
|
||||
logger.info(
|
||||
"Disconnected MCP server '%s' (last reference released)",
|
||||
server_name,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Error disconnecting MCP server '%s' during release",
|
||||
server_name,
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
with self._pool_lock:
|
||||
current = self._transitions.get(server_name)
|
||||
if current is transition_event:
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
return
|
||||
|
||||
def health_check(self, server_name: str) -> bool:
|
||||
"""Return True when the pooled connection appears healthy."""
|
||||
while True:
|
||||
with self._pool_lock:
|
||||
transition_event = self._transitions.get(server_name)
|
||||
if transition_event is None:
|
||||
client = self._pool.get(server_name)
|
||||
config = self._configs.get(server_name)
|
||||
break
|
||||
|
||||
if not transition_event.wait(timeout=_TRANSITION_TIMEOUT):
|
||||
logger.warning(
|
||||
"Timed out waiting for transition on '%s' during health check",
|
||||
server_name,
|
||||
)
|
||||
return False
|
||||
|
||||
if client is None or config is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
match config.transport:
|
||||
case "stdio":
|
||||
client.list_tools()
|
||||
return True
|
||||
case "http":
|
||||
if not config.url:
|
||||
return False
|
||||
with httpx.Client(
|
||||
base_url=config.url,
|
||||
headers=config.headers,
|
||||
timeout=5.0,
|
||||
) as http_client:
|
||||
response = http_client.get("/health")
|
||||
response.raise_for_status()
|
||||
return True
|
||||
case "sse":
|
||||
client.list_tools()
|
||||
return True
|
||||
case "unix":
|
||||
if not config.socket_path:
|
||||
return False
|
||||
with httpx.Client(
|
||||
base_url=config.url or "http://localhost",
|
||||
headers=config.headers,
|
||||
timeout=5.0,
|
||||
transport=httpx.HTTPTransport(uds=config.socket_path),
|
||||
) as http_client:
|
||||
response = http_client.get("/health")
|
||||
response.raise_for_status()
|
||||
return True
|
||||
case _:
|
||||
logger.warning(
|
||||
"Unknown transport '%s' for health check on '%s'",
|
||||
config.transport,
|
||||
server_name,
|
||||
)
|
||||
return False
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Health check failed for MCP server '%s'",
|
||||
server_name,
|
||||
exc_info=True,
|
||||
)
|
||||
return False
|
||||
|
||||
def reconnect(self, server_name: str) -> MCPClient:
|
||||
"""Force a disconnect and replace the pooled client with a fresh one."""
|
||||
while True:
|
||||
transition_event: threading.Event | None = None
|
||||
old_client: MCPClient | None = None
|
||||
|
||||
with self._pool_lock:
|
||||
transition_event = self._transitions.get(server_name)
|
||||
if transition_event is None:
|
||||
config = self._configs.get(server_name)
|
||||
if config is None:
|
||||
raise KeyError(f"Unknown MCP server: {server_name}")
|
||||
old_client = self._pool.get(server_name)
|
||||
transition_event = threading.Event()
|
||||
self._transitions[server_name] = transition_event
|
||||
break
|
||||
|
||||
if not transition_event.wait(timeout=_TRANSITION_TIMEOUT):
|
||||
logger.warning(
|
||||
"Timed out waiting for transition on '%s' during reconnect, forcing cleanup",
|
||||
server_name,
|
||||
)
|
||||
with self._pool_lock:
|
||||
stuck = self._transitions.get(server_name)
|
||||
if stuck is transition_event:
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
|
||||
# Disconnect old client safely
|
||||
if old_client is not None:
|
||||
try:
|
||||
old_client.disconnect()
|
||||
logger.info("Disconnected old client for '%s'", server_name)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Error disconnecting old client for '%s' during reconnect",
|
||||
server_name,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
logger.info("Reconnecting MCP server '%s'", server_name)
|
||||
new_client = MCPClient(config)
|
||||
try:
|
||||
new_client.connect()
|
||||
except Exception:
|
||||
with self._pool_lock:
|
||||
current = self._transitions.get(server_name)
|
||||
if current is transition_event:
|
||||
self._pool.pop(server_name, None)
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
raise
|
||||
|
||||
with self._pool_lock:
|
||||
current = self._transitions.get(server_name)
|
||||
if current is transition_event:
|
||||
current_refcount = self._refcounts.get(server_name, 0)
|
||||
if current_refcount <= 0:
|
||||
# All holders released during reconnect. Discard the
|
||||
# new client instead of creating a phantom reference.
|
||||
# Caller should acquire() fresh if needed.
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
logger.info(
|
||||
"Reconnected MCP server '%s' but refcount dropped to 0, "
|
||||
"discarding new client",
|
||||
server_name,
|
||||
)
|
||||
try:
|
||||
new_client.disconnect()
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Error disconnecting discarded client for '%s'",
|
||||
server_name,
|
||||
exc_info=True,
|
||||
)
|
||||
raise KeyError(
|
||||
f"MCP server '{server_name}' was fully released during reconnect"
|
||||
)
|
||||
|
||||
self._pool[server_name] = new_client
|
||||
self._configs[server_name] = config
|
||||
self._refcounts[server_name] = current_refcount
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
logger.info(
|
||||
"Reconnected MCP server '%s' (refcount=%d)",
|
||||
server_name,
|
||||
current_refcount,
|
||||
)
|
||||
return new_client
|
||||
|
||||
try:
|
||||
new_client.disconnect()
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Error disconnecting stale client for '%s' after reconnect race",
|
||||
server_name,
|
||||
exc_info=True,
|
||||
)
|
||||
return self.acquire(config)
|
||||
|
||||
def cleanup_all(self) -> None:
|
||||
"""Disconnect all pooled clients and clear manager state."""
|
||||
while True:
|
||||
with self._pool_lock:
|
||||
if self._transitions:
|
||||
pending = list(self._transitions.values())
|
||||
else:
|
||||
cleanup_events = {name: threading.Event() for name in self._pool}
|
||||
clients = list(self._pool.items())
|
||||
self._transitions.update(cleanup_events)
|
||||
self._pool.clear()
|
||||
self._refcounts.clear()
|
||||
self._configs.clear()
|
||||
break
|
||||
|
||||
all_resolved = all(event.wait(timeout=_TRANSITION_TIMEOUT) for event in pending)
|
||||
if not all_resolved:
|
||||
logger.warning(
|
||||
"Timed out waiting for pending transitions during cleanup, "
|
||||
"forcing cleanup of stuck transitions",
|
||||
)
|
||||
with self._pool_lock:
|
||||
for sn, evt in list(self._transitions.items()):
|
||||
if not evt.is_set():
|
||||
self._transitions.pop(sn, None)
|
||||
evt.set()
|
||||
|
||||
logger.info("Cleaning up %d pooled MCP connections", len(clients))
|
||||
for server_name, client in clients:
|
||||
try:
|
||||
client.disconnect()
|
||||
logger.debug("Disconnected MCP server '%s' during cleanup", server_name)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Error disconnecting MCP server '%s' during cleanup",
|
||||
server_name,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
with self._pool_lock:
|
||||
for server_name, event in cleanup_events.items():
|
||||
current = self._transitions.get(server_name)
|
||||
if current is event:
|
||||
self._transitions.pop(server_name, None)
|
||||
event.set()
|
||||
@@ -0,0 +1,99 @@
|
||||
"""Structured error codes and exceptions for MCP server operations."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MCPErrorCode(Enum):
|
||||
"""Standardized error codes for MCP operations."""
|
||||
|
||||
MCP_INSTALL_FAILED = "MCP_INSTALL_FAILED"
|
||||
MCP_AUTH_MISSING = "MCP_AUTH_MISSING"
|
||||
MCP_CONNECT_TIMEOUT = "MCP_CONNECT_TIMEOUT"
|
||||
MCP_TOOL_NOT_FOUND = "MCP_TOOL_NOT_FOUND"
|
||||
MCP_PROTOCOL_MISMATCH = "MCP_PROTOCOL_MISMATCH"
|
||||
MCP_VERSION_CONFLICT = "MCP_VERSION_CONFLICT"
|
||||
MCP_HEALTH_FAILED = "MCP_HEALTH_FAILED"
|
||||
|
||||
|
||||
class MCPError(ValueError):
|
||||
"""Base exception for all structured MCP errors."""
|
||||
|
||||
def __init__(self, code: MCPErrorCode, what: str, why: str, fix: str):
|
||||
self.code = code
|
||||
self.what = what
|
||||
self.why = why
|
||||
self.fix = fix
|
||||
self.message = (
|
||||
f"[{self.code.value}]\nWhat failed: {self.what}\nWhy: {self.why}\nFix: {self.fix}"
|
||||
)
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class MCPToolNotFoundError(MCPError):
|
||||
def __init__(self, server: str, tool_name: str):
|
||||
super().__init__(
|
||||
code=MCPErrorCode.MCP_TOOL_NOT_FOUND,
|
||||
what=f"Tool '{tool_name}' not found on server '{server}'",
|
||||
why=f"The server '{server}' does not expose a tool named '{tool_name}'.",
|
||||
fix=f"Run 'hive mcp inspect {server}' to view available tools.",
|
||||
)
|
||||
|
||||
|
||||
class MCPConnectTimeoutError(MCPError):
|
||||
def __init__(self, server: str, transport: str, timeout_sec: int):
|
||||
super().__init__(
|
||||
code=MCPErrorCode.MCP_CONNECT_TIMEOUT,
|
||||
what=f"Connection timed out while starting server '{server}'",
|
||||
why=f"The {transport} transport did not respond within {timeout_sec} seconds.",
|
||||
fix=f"Check if the server is running. Run 'hive mcp doctor {server}' for diagnostics.",
|
||||
)
|
||||
|
||||
|
||||
class MCPAuthError(MCPError):
|
||||
def __init__(self, server: str, env_var: str):
|
||||
super().__init__(
|
||||
code=MCPErrorCode.MCP_AUTH_MISSING,
|
||||
what=f"Authentication failed for server '{server}'",
|
||||
why=f"The required environment variable '{env_var}' is missing or empty.",
|
||||
fix=f"Run: hive mcp config {server} --set {env_var}=<your-token>",
|
||||
)
|
||||
|
||||
|
||||
class MCPInstallError(MCPError):
|
||||
def __init__(self, server: str, why: str, fix: str):
|
||||
super().__init__(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Could not install MCP server '{server}'",
|
||||
why=why,
|
||||
fix=fix,
|
||||
)
|
||||
|
||||
|
||||
class MCPProtocolMismatchError(MCPError):
|
||||
def __init__(self, server: str, detail: str):
|
||||
super().__init__(
|
||||
code=MCPErrorCode.MCP_PROTOCOL_MISMATCH,
|
||||
what=f"Protocol mismatch with server '{server}'",
|
||||
why=detail,
|
||||
fix=f"Check the MCP SDK version required by '{server}' matches your installation.",
|
||||
)
|
||||
|
||||
|
||||
class MCPVersionConflictError(MCPError):
|
||||
def __init__(self, server: str, detail: str):
|
||||
super().__init__(
|
||||
code=MCPErrorCode.MCP_VERSION_CONFLICT,
|
||||
what=f"Version conflict with server '{server}'",
|
||||
why=detail,
|
||||
fix="Update or pin the MCP server package to a compatible version.",
|
||||
)
|
||||
|
||||
|
||||
class MCPHealthCheckError(MCPError):
|
||||
def __init__(self, server: str, detail: str):
|
||||
super().__init__(
|
||||
code=MCPErrorCode.MCP_HEALTH_FAILED,
|
||||
what=f"Health check failed for server '{server}'",
|
||||
why=detail,
|
||||
fix=f"Run 'hive mcp doctor {server}' to diagnose the issue.",
|
||||
)
|
||||
@@ -0,0 +1,904 @@
|
||||
"""MCP Server Registry: local state management for installed MCP servers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import tomllib
|
||||
from datetime import UTC, datetime
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
|
||||
from framework.runner.mcp_client import MCPClient, MCPServerConfig
|
||||
from framework.runner.mcp_connection_manager import MCPConnectionManager
|
||||
from framework.runner.mcp_errors import (
|
||||
MCPError,
|
||||
MCPErrorCode,
|
||||
MCPInstallError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_INDEX_URL = (
|
||||
"https://raw.githubusercontent.com/aden-hive/hive-mcp-registry/main/registry_index.json"
|
||||
)
|
||||
DEFAULT_REFRESH_INTERVAL_HOURS = 24
|
||||
_LAST_FETCHED_FILENAME = "last_fetched"
|
||||
_LEGACY_LAST_FETCHED_FILENAME = "last_fetched.json"
|
||||
|
||||
_DEFAULT_CONFIG = {
|
||||
"index_url": DEFAULT_INDEX_URL,
|
||||
"refresh_interval_hours": DEFAULT_REFRESH_INTERVAL_HOURS,
|
||||
}
|
||||
|
||||
|
||||
class MCPRegistry:
|
||||
"""Manages local MCP server state in ~/.hive/mcp_registry/."""
|
||||
|
||||
def __init__(self, base_path: Path | None = None):
|
||||
self._base = base_path or Path.home() / ".hive" / "mcp_registry"
|
||||
self._installed_path = self._base / "installed.json"
|
||||
self._config_path = self._base / "config.json"
|
||||
self._cache_dir = self._base / "cache"
|
||||
|
||||
# ── Initialization ──────────────────────────────────────────────
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""Create directory structure and default files if missing."""
|
||||
self._base.mkdir(parents=True, exist_ok=True)
|
||||
self._cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not self._config_path.exists():
|
||||
self._write_json(self._config_path, _DEFAULT_CONFIG)
|
||||
|
||||
if not self._installed_path.exists():
|
||||
self._write_json(self._installed_path, {"servers": {}})
|
||||
|
||||
# ── Internal I/O ────────────────────────────────────────────────
|
||||
|
||||
def _read_installed(self) -> dict:
|
||||
"""Read installed.json, initializing if needed."""
|
||||
if not self._installed_path.exists():
|
||||
self.initialize()
|
||||
return json.loads(self._installed_path.read_text(encoding="utf-8"))
|
||||
|
||||
def _write_installed(self, data: dict) -> None:
|
||||
"""Write installed.json."""
|
||||
self._write_json(self._installed_path, data)
|
||||
|
||||
def _read_config(self) -> dict:
|
||||
"""Read config.json."""
|
||||
if not self._config_path.exists():
|
||||
self.initialize()
|
||||
return json.loads(self._config_path.read_text(encoding="utf-8"))
|
||||
|
||||
def _read_cached_index(self) -> dict:
|
||||
"""Read cached registry_index.json."""
|
||||
index_path = self._cache_dir / "registry_index.json"
|
||||
if not index_path.exists():
|
||||
return {"servers": {}}
|
||||
return json.loads(index_path.read_text(encoding="utf-8"))
|
||||
|
||||
def _get_effective_manifest(
|
||||
self,
|
||||
name: str,
|
||||
entry: dict,
|
||||
cached_index: dict | None = None,
|
||||
) -> dict:
|
||||
"""Return the manifest currently in effect for an installed entry."""
|
||||
manifest = entry.get("manifest", {})
|
||||
if entry.get("source") != "registry":
|
||||
return manifest
|
||||
|
||||
index = cached_index or self._read_cached_index()
|
||||
cached_manifest = index.get("servers", {}).get(name)
|
||||
if cached_manifest is not None:
|
||||
return cached_manifest
|
||||
|
||||
# Fall back to persisted manifest data when the cache is unavailable.
|
||||
if isinstance(manifest, dict) and manifest:
|
||||
return manifest
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _write_json(path: Path, data: dict) -> None:
|
||||
"""Write JSON to file atomically (write to temp, fsync, rename)."""
|
||||
content = json.dumps(data, indent=2) + "\n"
|
||||
fd, tmp_path = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.replace(tmp_path, path)
|
||||
except BaseException:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
# ── add_local ───────────────────────────────────────────────────
|
||||
|
||||
def add_local(
|
||||
self,
|
||||
name: str,
|
||||
transport: str | None = None,
|
||||
manifest: dict | None = None,
|
||||
url: str | None = None,
|
||||
command: str | None = None,
|
||||
args: list[str] | None = None,
|
||||
env: dict[str, str] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
socket_path: str | None = None,
|
||||
description: str = "",
|
||||
) -> dict:
|
||||
"""Register a local/running MCP server.
|
||||
|
||||
Can be called with an inline manifest dict, or with individual
|
||||
transport/url/command params that build a manifest automatically.
|
||||
"""
|
||||
data = self._read_installed()
|
||||
if name in data["servers"]:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Server '{name}' already exists",
|
||||
why="A server with this name is already registered locally.",
|
||||
fix=f"Run: hive mcp remove {name} — then add it again.",
|
||||
)
|
||||
|
||||
if manifest is not None:
|
||||
# Inline manifest provided directly
|
||||
manifest = {**manifest, "name": name}
|
||||
transport_config = manifest.get("transport", {})
|
||||
transport = transport or transport_config.get("default", "stdio")
|
||||
if "transport" not in manifest:
|
||||
manifest["transport"] = {"supported": [transport], "default": transport}
|
||||
else:
|
||||
# Build manifest from individual params
|
||||
if not transport:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot register server '{name}'",
|
||||
why="transport is required when manifest is not provided.",
|
||||
fix="Pass --transport stdio|http|unix|sse when using hive mcp add.",
|
||||
)
|
||||
manifest = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"transport": {"supported": [transport], "default": transport},
|
||||
}
|
||||
match transport:
|
||||
case "http":
|
||||
if not url:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot register server '{name}' with http transport",
|
||||
why="url is required for http transport.",
|
||||
fix="Pass --url https://your-server to hive mcp add.",
|
||||
)
|
||||
manifest["http"] = {"url": url, "headers": headers or {}}
|
||||
case "stdio":
|
||||
if not command:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot register server '{name}' with stdio transport",
|
||||
why="command is required for stdio transport.",
|
||||
fix="Pass --command <executable> to hive mcp add.",
|
||||
)
|
||||
manifest["stdio"] = {
|
||||
"command": command,
|
||||
"args": args or [],
|
||||
"env": env or {},
|
||||
"cwd": cwd,
|
||||
}
|
||||
case "unix":
|
||||
if not socket_path:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot register server '{name}' with unix transport",
|
||||
why="socket_path is required for unix transport.",
|
||||
fix="Pass --socket-path /path/to/socket to hive mcp add.",
|
||||
)
|
||||
manifest["unix"] = {"socket_path": socket_path}
|
||||
manifest["http"] = {"url": url or "http://localhost"}
|
||||
case "sse":
|
||||
if not url:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot register server '{name}' with sse transport",
|
||||
why="url is required for sse transport.",
|
||||
fix="Pass --url https://your-server to hive mcp add.",
|
||||
)
|
||||
manifest["sse"] = {"url": url}
|
||||
case _:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot register server '{name}'",
|
||||
why=f"Unsupported transport: '{transport}'.",
|
||||
fix="Use one of: stdio, http, unix, sse.",
|
||||
)
|
||||
|
||||
entry = self._make_entry(
|
||||
source="local",
|
||||
manifest=manifest,
|
||||
transport=transport,
|
||||
installed_by="hive mcp add",
|
||||
)
|
||||
|
||||
data["servers"][name] = entry
|
||||
self._write_installed(data)
|
||||
logger.info("Registered local MCP server '%s' (%s)", name, transport)
|
||||
return entry
|
||||
|
||||
# ── install ─────────────────────────────────────────────────────
|
||||
|
||||
def install(self, name: str, transport: str | None = None, version: str | None = None) -> dict:
|
||||
"""Install a server from the cached remote registry index."""
|
||||
data = self._read_installed()
|
||||
if name in data["servers"]:
|
||||
raise MCPInstallError(
|
||||
server=name,
|
||||
why=f"Server '{name}' already exists in the registry.",
|
||||
fix=f"Run: hive mcp remove {name} — then install again.",
|
||||
)
|
||||
|
||||
index = self._read_cached_index()
|
||||
manifest = index.get("servers", {}).get(name)
|
||||
if manifest is None:
|
||||
raise MCPInstallError(
|
||||
server=name,
|
||||
why=f"Server '{name}' not found in registry index.",
|
||||
fix="Run: hive mcp update — then try again.",
|
||||
)
|
||||
|
||||
# Validate version if specified
|
||||
if version is not None:
|
||||
index_version = manifest.get("version")
|
||||
if index_version is None:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_VERSION_CONFLICT,
|
||||
what=f"Cannot pin version for '{name}'",
|
||||
why="The registry manifest has no version field.",
|
||||
fix="Run: hive mcp update — then omit --version to use latest.",
|
||||
)
|
||||
if index_version != version:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_VERSION_CONFLICT,
|
||||
what=f"Version mismatch for '{name}'",
|
||||
why=f"Requested {version} but index has {index_version}.",
|
||||
fix="Run: hive mcp update — or omit --version to use latest.",
|
||||
)
|
||||
|
||||
transport_config = manifest.get("transport", {})
|
||||
supported = transport_config.get("supported", [])
|
||||
if transport is not None:
|
||||
if supported and transport not in supported:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Transport '{transport}' not supported by '{name}'",
|
||||
why=f"Server supports: {supported}.",
|
||||
fix=f"Use one of the supported transports: {supported}.",
|
||||
)
|
||||
resolved_transport = transport
|
||||
else:
|
||||
resolved_transport = transport_config.get("default", "stdio")
|
||||
|
||||
entry = self._make_entry(
|
||||
source="registry",
|
||||
manifest=self._make_registry_manifest_snapshot(name, manifest),
|
||||
transport=resolved_transport,
|
||||
installed_by="hive mcp install",
|
||||
pinned=version is not None,
|
||||
auto_update=version is None,
|
||||
resolved_package_version=manifest.get("version"),
|
||||
)
|
||||
|
||||
data["servers"][name] = entry
|
||||
self._write_installed(data)
|
||||
logger.info(
|
||||
"Installed MCP server '%s' v%s from registry",
|
||||
name,
|
||||
entry["manifest_version"],
|
||||
)
|
||||
return entry
|
||||
|
||||
# ── remove / enable / disable ───────────────────────────────────
|
||||
|
||||
def remove(self, name: str) -> None:
|
||||
"""Remove a server from the registry."""
|
||||
data = self._read_installed()
|
||||
if name not in data["servers"]:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot remove server '{name}'",
|
||||
why="Server is not installed.",
|
||||
fix="Run: hive mcp list — to see installed servers.",
|
||||
)
|
||||
del data["servers"][name]
|
||||
self._write_installed(data)
|
||||
logger.info("Removed MCP server '%s'", name)
|
||||
|
||||
def enable(self, name: str) -> None:
|
||||
"""Enable a disabled server."""
|
||||
self._set_enabled(name, enabled=True)
|
||||
|
||||
def disable(self, name: str) -> None:
|
||||
"""Disable a server without removing it."""
|
||||
self._set_enabled(name, enabled=False)
|
||||
|
||||
def _set_enabled(self, name: str, *, enabled: bool) -> None:
|
||||
data = self._read_installed()
|
||||
if name not in data["servers"]:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot {'enable' if enabled else 'disable'} server '{name}'",
|
||||
why="Server is not installed.",
|
||||
fix="Run: hive mcp list — to see installed servers.",
|
||||
)
|
||||
data["servers"][name]["enabled"] = enabled
|
||||
self._write_installed(data)
|
||||
logger.info("%s MCP server '%s'", "Enabled" if enabled else "Disabled", name)
|
||||
|
||||
# ── list / get ──────────────────────────────────────────────────
|
||||
|
||||
def list_installed(self) -> list[dict]:
|
||||
"""Return all installed servers as a list of dicts with name included."""
|
||||
data = self._read_installed()
|
||||
return [{"name": name, **entry} for name, entry in data["servers"].items()]
|
||||
|
||||
def get_server(self, name: str) -> dict | None:
|
||||
"""Get a single installed server entry by name, or None if not found."""
|
||||
data = self._read_installed()
|
||||
entry = data["servers"].get(name)
|
||||
if entry is None:
|
||||
return None
|
||||
return {"name": name, **entry}
|
||||
|
||||
def list_available(self) -> list[dict]:
|
||||
"""List all servers from cached remote index."""
|
||||
index = self._read_cached_index()
|
||||
return [{"name": name, **m} for name, m in index.get("servers", {}).items()]
|
||||
|
||||
# ── set_override ────────────────────────────────────────────────
|
||||
|
||||
def set_override(
|
||||
self,
|
||||
name: str,
|
||||
key: str,
|
||||
value: str,
|
||||
override_type: Literal["env", "headers"] = "env",
|
||||
) -> None:
|
||||
"""Set an env or header override for a server."""
|
||||
data = self._read_installed()
|
||||
if name not in data["servers"]:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot set override for server '{name}'",
|
||||
why="Server is not installed.",
|
||||
fix="Run: hive mcp list — to see installed servers.",
|
||||
)
|
||||
if override_type not in ("env", "headers"):
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Invalid override type '{override_type}' for server '{name}'",
|
||||
why="Override type must be 'env' or 'headers'.",
|
||||
fix="Use --type env or --type headers.",
|
||||
)
|
||||
data["servers"][name]["overrides"][override_type][key] = value
|
||||
self._write_installed(data)
|
||||
logger.info("Set %s override %s for MCP server '%s'", override_type, key, name)
|
||||
|
||||
# ── search ──────────────────────────────────────────────────────
|
||||
|
||||
def search(self, query: str) -> list[dict]:
|
||||
"""Search registry index by name, tag, description, or tool name."""
|
||||
query_lower = query.lower()
|
||||
index = self._read_cached_index()
|
||||
matches = []
|
||||
|
||||
for name, manifest in index.get("servers", {}).items():
|
||||
if self._matches_query(name, manifest, query_lower):
|
||||
matches.append({"name": name, **manifest})
|
||||
|
||||
return matches
|
||||
|
||||
@staticmethod
|
||||
def _matches_query(name: str, manifest: dict, query: str) -> bool:
|
||||
"""Check if a manifest matches a search query."""
|
||||
if query in name.lower():
|
||||
return True
|
||||
|
||||
description = manifest.get("description", "")
|
||||
if query in description.lower():
|
||||
return True
|
||||
|
||||
for tag in manifest.get("tags", []):
|
||||
if query in tag.lower():
|
||||
return True
|
||||
|
||||
for tool in manifest.get("tools", []):
|
||||
tool_name = tool.get("name", "") if isinstance(tool, dict) else str(tool)
|
||||
if query in tool_name.lower():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# ── update_index ────────────────────────────────────────────────
|
||||
|
||||
def is_index_stale(self) -> bool:
|
||||
"""Check if the cached registry index needs refreshing."""
|
||||
last_fetched_path = self._cache_dir / _LAST_FETCHED_FILENAME
|
||||
legacy_path = self._cache_dir / _LEGACY_LAST_FETCHED_FILENAME
|
||||
if not last_fetched_path.exists() and not legacy_path.exists():
|
||||
return True
|
||||
|
||||
try:
|
||||
path = last_fetched_path if last_fetched_path.exists() else legacy_path
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
last_fetched = datetime.fromisoformat(data["timestamp"])
|
||||
config = self._read_config()
|
||||
interval_hours = config.get("refresh_interval_hours", DEFAULT_REFRESH_INTERVAL_HOURS)
|
||||
age_hours = (datetime.now(UTC) - last_fetched).total_seconds() / 3600
|
||||
return age_hours >= interval_hours
|
||||
except (KeyError, ValueError, OSError):
|
||||
return True
|
||||
|
||||
def update_index(self) -> int:
|
||||
"""Fetch the latest registry index from remote and cache it.
|
||||
|
||||
Returns the number of servers in the index.
|
||||
"""
|
||||
config = self._read_config()
|
||||
url = config.get("index_url", DEFAULT_INDEX_URL)
|
||||
|
||||
response = httpx.get(url, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
index = response.json()
|
||||
|
||||
self._write_json(self._cache_dir / "registry_index.json", index)
|
||||
# Write last_fetched atomically too
|
||||
self._write_json(
|
||||
self._cache_dir / _LAST_FETCHED_FILENAME,
|
||||
{"timestamp": datetime.now(UTC).isoformat()},
|
||||
)
|
||||
|
||||
server_count = len(index.get("servers", {}))
|
||||
logger.info("Updated registry index: %d servers available", server_count)
|
||||
return server_count
|
||||
|
||||
# ── load_agent_selection ────────────────────────────────────────
|
||||
|
||||
def load_agent_selection(self, agent_path: Path) -> tuple[list[dict[str, Any]], int | None]:
|
||||
"""Load mcp_registry.json from an agent directory and resolve servers.
|
||||
|
||||
Returns:
|
||||
(server_config_dicts, max_tools) for :meth:`ToolRegistry.load_registry_servers`.
|
||||
``max_tools`` is ``None`` when omitted or invalid in JSON.
|
||||
"""
|
||||
registry_json_path = agent_path / "mcp_registry.json"
|
||||
if not registry_json_path.exists():
|
||||
return [], None
|
||||
|
||||
selection = json.loads(registry_json_path.read_text(encoding="utf-8"))
|
||||
|
||||
# Validate types at the JSON boundary. Bad fields are dropped with a
|
||||
# warning so the agent still starts (graceful degradation).
|
||||
expected_types: dict[str, type] = {
|
||||
"include": list,
|
||||
"tags": list,
|
||||
"exclude": list,
|
||||
"profile": str,
|
||||
"max_tools": int,
|
||||
"versions": dict,
|
||||
}
|
||||
validated: dict[str, Any] = {}
|
||||
for field, expected in expected_types.items():
|
||||
value = selection.get(field)
|
||||
if value is None:
|
||||
continue
|
||||
if not isinstance(value, expected):
|
||||
logger.warning(
|
||||
"mcp_registry.json: '%s' must be %s, got %s; ignoring",
|
||||
field,
|
||||
expected.__name__,
|
||||
type(value).__name__,
|
||||
)
|
||||
continue
|
||||
validated[field] = value
|
||||
|
||||
max_tools = validated.get("max_tools")
|
||||
configs = self.resolve_for_agent(
|
||||
include=validated.get("include"),
|
||||
tags=validated.get("tags"),
|
||||
exclude=validated.get("exclude"),
|
||||
profile=validated.get("profile"),
|
||||
max_tools=max_tools,
|
||||
versions=validated.get("versions"),
|
||||
)
|
||||
return [self._server_config_to_dict(c) for c in configs], max_tools
|
||||
|
||||
# ── resolve_for_agent ───────────────────────────────────────────
|
||||
|
||||
def resolve_for_agent(
|
||||
self,
|
||||
include: list[str] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
exclude: list[str] | None = None,
|
||||
profile: str | None = None,
|
||||
max_tools: int | None = None,
|
||||
versions: dict[str, str] | None = None,
|
||||
) -> list[MCPServerConfig]:
|
||||
"""Resolve installed servers matching agent selection criteria.
|
||||
|
||||
Selection precedence per PRD section 7.2:
|
||||
1. profile expands to server names (union with include + tags)
|
||||
2. include adds explicit servers
|
||||
3. tags adds servers whose tags overlap
|
||||
4. exclude removes (always wins)
|
||||
5. Load order: include-order first, then alphabetical for tag/profile matches
|
||||
|
||||
Returns list of MCPServerConfig objects ready for ToolRegistry.
|
||||
"""
|
||||
data = self._read_installed()
|
||||
servers = data.get("servers", {})
|
||||
cached_index = self._read_cached_index()
|
||||
exclude_set = set(exclude or [])
|
||||
|
||||
# Phase 1: collect profile-matched servers (alphabetical)
|
||||
profile_matched: list[str] = []
|
||||
if profile:
|
||||
for name, entry in sorted(servers.items()):
|
||||
if name in exclude_set:
|
||||
continue
|
||||
if profile == "all":
|
||||
profile_matched.append(name)
|
||||
else:
|
||||
manifest = self._get_effective_manifest(name, entry, cached_index)
|
||||
profiles = manifest.get("hive", {}).get("profiles", [])
|
||||
if profile in profiles:
|
||||
profile_matched.append(name)
|
||||
|
||||
# Phase 2: collect tag-matched servers (alphabetical)
|
||||
tag_matched: list[str] = []
|
||||
if tags:
|
||||
tag_set = set(tags)
|
||||
for name, entry in sorted(servers.items()):
|
||||
if name in exclude_set:
|
||||
continue
|
||||
manifest = self._get_effective_manifest(name, entry, cached_index)
|
||||
server_tags = set(manifest.get("tags", []))
|
||||
if tag_set & server_tags:
|
||||
tag_matched.append(name)
|
||||
|
||||
# Phase 3: build final ordered list
|
||||
# include-order first, then alphabetical for profile/tag matches
|
||||
selected: list[str] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
for name in include or []:
|
||||
if name not in seen and name not in exclude_set:
|
||||
selected.append(name)
|
||||
seen.add(name)
|
||||
|
||||
for name in profile_matched:
|
||||
if name not in seen:
|
||||
selected.append(name)
|
||||
seen.add(name)
|
||||
|
||||
for name in tag_matched:
|
||||
if name not in seen:
|
||||
selected.append(name)
|
||||
seen.add(name)
|
||||
|
||||
# Build configs, tracking aggregate tool count for max_tools cap (FR-56)
|
||||
configs: list[MCPServerConfig] = []
|
||||
total_tools = 0
|
||||
for name in selected:
|
||||
entry = servers.get(name)
|
||||
if entry is None:
|
||||
logger.warning(
|
||||
"Server '%s' requested but not installed. Run: hive mcp install %s",
|
||||
name,
|
||||
name,
|
||||
)
|
||||
continue
|
||||
if not entry.get("enabled", True):
|
||||
continue
|
||||
|
||||
manifest = self._get_effective_manifest(name, entry, cached_index)
|
||||
|
||||
# Check version pin (VC-6)
|
||||
if versions and name in versions:
|
||||
installed_version = entry.get("manifest_version", "0.0.0")
|
||||
pinned_version = versions[name]
|
||||
if installed_version != pinned_version:
|
||||
logger.warning(
|
||||
"Server '%s' version mismatch: installed=%s, pinned=%s. "
|
||||
"Run: hive mcp update %s",
|
||||
name,
|
||||
installed_version,
|
||||
pinned_version,
|
||||
name,
|
||||
)
|
||||
continue
|
||||
|
||||
# Check tool count cap before adding (FR-56), using manifest tool list when present.
|
||||
# When ``tools`` is empty (e.g. ``add_local``), counts are unknown here—callers should
|
||||
# pass the same ``max_tools`` to ToolRegistry.load_registry_servers to cap registration.
|
||||
manifest_tools = manifest.get("tools", [])
|
||||
server_tool_count = len(manifest_tools)
|
||||
if max_tools is not None and server_tool_count == 0:
|
||||
logger.debug(
|
||||
"Server '%s' has no tools list in manifest; max_tools enforced at registration",
|
||||
name,
|
||||
)
|
||||
elif max_tools is not None and total_tools + server_tool_count > max_tools:
|
||||
logger.info(
|
||||
"Skipping server '%s' (%d tools): would exceed max_tools=%d",
|
||||
name,
|
||||
server_tool_count,
|
||||
max_tools,
|
||||
)
|
||||
continue
|
||||
|
||||
config = self._manifest_to_server_config(
|
||||
name,
|
||||
manifest,
|
||||
entry.get("overrides", {}),
|
||||
transport_override=entry.get("transport"),
|
||||
)
|
||||
if config is not None:
|
||||
configs.append(config)
|
||||
total_tools += server_tool_count
|
||||
|
||||
return configs
|
||||
|
||||
def _manifest_to_server_config(
|
||||
self,
|
||||
name: str,
|
||||
manifest: dict,
|
||||
overrides: dict | None = None,
|
||||
transport_override: str | None = None,
|
||||
) -> MCPServerConfig | None:
|
||||
"""Convert a manifest and overrides to MCPServerConfig."""
|
||||
overrides = overrides or {}
|
||||
transport_config = manifest.get("transport", {})
|
||||
transport = transport_override or transport_config.get("default", "stdio")
|
||||
description = manifest.get("description", "")
|
||||
|
||||
match transport:
|
||||
case "stdio":
|
||||
stdio_config = manifest.get("stdio", {})
|
||||
merged_env = {
|
||||
**stdio_config.get("env", {}),
|
||||
**overrides.get("env", {}),
|
||||
}
|
||||
return MCPServerConfig(
|
||||
name=name,
|
||||
transport="stdio",
|
||||
command=stdio_config.get("command"),
|
||||
args=stdio_config.get("args", []),
|
||||
env=merged_env,
|
||||
cwd=stdio_config.get("cwd"),
|
||||
description=description,
|
||||
)
|
||||
case "http":
|
||||
http_config = manifest.get("http", {})
|
||||
url = http_config.get("url", "")
|
||||
merged_headers = {
|
||||
**http_config.get("headers", {}),
|
||||
**overrides.get("headers", {}),
|
||||
}
|
||||
return MCPServerConfig(
|
||||
name=name,
|
||||
transport="http",
|
||||
url=url,
|
||||
headers=merged_headers,
|
||||
description=description,
|
||||
)
|
||||
case "unix":
|
||||
unix_config = manifest.get("unix", {})
|
||||
http_config = manifest.get("http", {})
|
||||
merged_headers = {
|
||||
**http_config.get("headers", {}),
|
||||
**overrides.get("headers", {}),
|
||||
}
|
||||
return MCPServerConfig(
|
||||
name=name,
|
||||
transport="unix",
|
||||
socket_path=unix_config.get("socket_path"),
|
||||
url=http_config.get("url") or "http://localhost",
|
||||
headers=merged_headers,
|
||||
description=description,
|
||||
)
|
||||
case "sse":
|
||||
sse_config = manifest.get("sse", {})
|
||||
merged_headers = {
|
||||
**sse_config.get("headers", {}),
|
||||
**overrides.get("headers", {}),
|
||||
}
|
||||
return MCPServerConfig(
|
||||
name=name,
|
||||
transport="sse",
|
||||
url=sse_config.get("url", ""),
|
||||
headers=merged_headers,
|
||||
description=description,
|
||||
)
|
||||
case _:
|
||||
logger.warning(
|
||||
"Unsupported transport '%s' for server '%s'",
|
||||
transport,
|
||||
name,
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _server_config_to_dict(config: MCPServerConfig) -> dict[str, Any]:
|
||||
"""Convert MCPServerConfig to plain dict for ToolRegistry.register_mcp_server()."""
|
||||
return {
|
||||
"name": config.name,
|
||||
"transport": config.transport,
|
||||
"command": config.command,
|
||||
"args": config.args,
|
||||
"env": config.env,
|
||||
"cwd": config.cwd,
|
||||
"url": config.url,
|
||||
"headers": config.headers,
|
||||
"socket_path": config.socket_path,
|
||||
"description": config.description,
|
||||
}
|
||||
|
||||
# ── run_health_check ────────────────────────────────────────────
|
||||
|
||||
def health_check(self, name: str | None = None) -> dict | dict[str, dict]:
|
||||
"""Check health of installed server(s). Updates telemetry fields.
|
||||
|
||||
If name is None, checks all installed servers and returns
|
||||
a dict mapping server names to their health results.
|
||||
|
||||
"""
|
||||
if name is None:
|
||||
results = {}
|
||||
for server in self.list_installed():
|
||||
results[server["name"]] = self.health_check(server["name"])
|
||||
return results
|
||||
|
||||
data = self._read_installed()
|
||||
if name not in data["servers"]:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_HEALTH_FAILED,
|
||||
what=f"Cannot health-check server '{name}'",
|
||||
why="Server is not installed.",
|
||||
fix="Run: hive mcp list — to see installed servers.",
|
||||
)
|
||||
|
||||
entry = data["servers"][name]
|
||||
manifest = self._get_effective_manifest(name, entry)
|
||||
config = self._manifest_to_server_config(
|
||||
name,
|
||||
manifest,
|
||||
entry.get("overrides", {}),
|
||||
transport_override=entry.get("transport"),
|
||||
)
|
||||
now = datetime.now(UTC).isoformat()
|
||||
|
||||
result: dict[str, Any] = {
|
||||
"name": name,
|
||||
"status": "unknown",
|
||||
"tools": 0,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
if config is None:
|
||||
transport = entry.get("transport", "unknown")
|
||||
result["status"] = "unhealthy"
|
||||
result["error"] = f"Unsupported transport '{transport}'"
|
||||
entry["last_health_status"] = "unhealthy"
|
||||
entry["last_error"] = result["error"]
|
||||
entry["last_health_check_at"] = now
|
||||
self._write_installed(data)
|
||||
return result
|
||||
|
||||
manager = MCPConnectionManager.get_instance()
|
||||
|
||||
try:
|
||||
if manager.has_connection(name):
|
||||
is_healthy = manager.health_check(name)
|
||||
if not is_healthy:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_HEALTH_FAILED,
|
||||
what=f"Health check failed for server '{name}'",
|
||||
why="Shared MCP connection reported unhealthy.",
|
||||
fix=f"Run: hive mcp doctor {name} — for diagnostics.",
|
||||
)
|
||||
pooled_client = manager.acquire(config)
|
||||
try:
|
||||
tools = pooled_client.list_tools()
|
||||
finally:
|
||||
manager.release(name)
|
||||
else:
|
||||
with MCPClient(config) as client:
|
||||
tools = client.list_tools()
|
||||
|
||||
result["status"] = "healthy"
|
||||
result["tools"] = len(tools)
|
||||
entry["last_health_status"] = "healthy"
|
||||
entry["last_error"] = None
|
||||
entry["last_validated_with_hive_version"] = self._get_hive_version()
|
||||
except Exception as exc:
|
||||
result["status"] = "unhealthy"
|
||||
result["error"] = str(exc)
|
||||
entry["last_health_status"] = "unhealthy"
|
||||
entry["last_error"] = str(exc)
|
||||
|
||||
entry["last_health_check_at"] = now
|
||||
self._write_installed(data)
|
||||
return result
|
||||
|
||||
def run_health_check(self, name: str | None = None) -> dict | dict[str, dict]:
|
||||
"""Backward-compatible wrapper for the public health_check API."""
|
||||
return self.health_check(name)
|
||||
|
||||
@staticmethod
|
||||
def _get_hive_version() -> str:
|
||||
"""Get the current Hive version."""
|
||||
try:
|
||||
return version("framework")
|
||||
except PackageNotFoundError:
|
||||
project_toml = Path(__file__).resolve().parents[2] / "pyproject.toml"
|
||||
if not project_toml.exists():
|
||||
return "unknown"
|
||||
try:
|
||||
with project_toml.open("rb") as f:
|
||||
data = tomllib.load(f)
|
||||
return data.get("project", {}).get("version", "unknown")
|
||||
except (tomllib.TOMLDecodeError, OSError):
|
||||
return "unknown"
|
||||
|
||||
# ── helpers ──────────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _make_entry(
|
||||
*,
|
||||
source: str,
|
||||
manifest: dict,
|
||||
transport: str,
|
||||
installed_by: str,
|
||||
pinned: bool = False,
|
||||
auto_update: bool = False,
|
||||
resolved_package_version: str | None = None,
|
||||
) -> dict:
|
||||
"""Build a standard installed server entry."""
|
||||
now = datetime.now(UTC).isoformat()
|
||||
return {
|
||||
"source": source,
|
||||
"manifest_version": manifest.get("version", "0.0.0"),
|
||||
"manifest": manifest,
|
||||
"installed_at": now,
|
||||
"installed_by": installed_by,
|
||||
"transport": transport,
|
||||
"enabled": True,
|
||||
"pinned": pinned,
|
||||
"auto_update": auto_update,
|
||||
"resolved_package_version": resolved_package_version,
|
||||
"overrides": {"env": {}, "headers": {}},
|
||||
"last_health_check_at": None,
|
||||
"last_health_status": None,
|
||||
"last_error": None,
|
||||
"last_used_at": None,
|
||||
"last_validated_with_hive_version": None,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _make_registry_manifest_snapshot(name: str, manifest: dict) -> dict[str, Any]:
|
||||
"""Persist a full manifest snapshot for registry-installed servers."""
|
||||
manifest_snapshot = dict(manifest)
|
||||
manifest_snapshot["name"] = name
|
||||
return manifest_snapshot
|
||||
@@ -0,0 +1,906 @@
|
||||
"""CLI commands for MCP server registry management.
|
||||
|
||||
Commands:
|
||||
hive mcp install <name> Install a server from the registry
|
||||
hive mcp add Register a local/running MCP server
|
||||
hive mcp remove <name> Remove an installed server
|
||||
hive mcp enable <name> Enable a server
|
||||
hive mcp disable <name> Disable a server
|
||||
hive mcp list List installed servers
|
||||
hive mcp info <name> Show server details
|
||||
hive mcp config <name> Set env/header overrides
|
||||
hive mcp search <query> Search the registry index
|
||||
hive mcp health [name] Check server health
|
||||
hive mcp update Refresh index and update installed servers
|
||||
hive mcp update <name> Update a single installed server
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# ── Shared helpers ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _get_registry(base_path: Path | None = None):
|
||||
"""Initialize and return an MCPRegistry instance."""
|
||||
from framework.runner.mcp_registry import MCPRegistry
|
||||
|
||||
registry = MCPRegistry(base_path=base_path)
|
||||
registry.initialize()
|
||||
return registry
|
||||
|
||||
|
||||
def _ensure_index_available(registry) -> bool:
|
||||
"""Ensure the registry index is cached locally.
|
||||
|
||||
If no index exists or the cache is stale, fetches a fresh copy.
|
||||
Returns True if a usable index exists, False otherwise.
|
||||
|
||||
Semantics:
|
||||
- Stale cache + refresh fails -> warn and continue with stale cache (True)
|
||||
- No cache + refresh fails -> hard fail (False)
|
||||
"""
|
||||
import httpx
|
||||
|
||||
cache_exists = (registry._cache_dir / "registry_index.json").exists()
|
||||
|
||||
if registry.is_index_stale():
|
||||
print("Updating registry index...", file=sys.stderr)
|
||||
try:
|
||||
count = registry.update_index()
|
||||
print(f"Registry index updated ({count} servers available).", file=sys.stderr)
|
||||
return True
|
||||
except (httpx.HTTPError, OSError) as exc:
|
||||
if cache_exists:
|
||||
print(
|
||||
f"Warning: failed to update registry index: {exc}\nUsing cached index.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return True
|
||||
print(
|
||||
f"Error: no registry index available and refresh failed: {exc}\n"
|
||||
"Check your network connection and try: hive mcp update",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return False
|
||||
|
||||
return cache_exists
|
||||
|
||||
|
||||
_SECURITY_NOTICE = (
|
||||
"Registry servers run code on your machine. Only install servers you trust.\n"
|
||||
"Learn more: https://github.com/aden-hive/hive-mcp-registry"
|
||||
)
|
||||
_NOTICE_SENTINEL = ".security_notice_shown"
|
||||
|
||||
|
||||
def _print_security_notice_if_first_use(registry_base: Path) -> None:
|
||||
"""Print a one-time security notice on first registry install.
|
||||
|
||||
Only prints the notice. Call _mark_security_notice_shown() after
|
||||
a successful install to persist the sentinel.
|
||||
"""
|
||||
sentinel = registry_base / _NOTICE_SENTINEL
|
||||
if sentinel.exists():
|
||||
return
|
||||
print(f"\n {_SECURITY_NOTICE}\n", file=sys.stderr)
|
||||
|
||||
|
||||
def _mark_security_notice_shown(registry_base: Path) -> None:
|
||||
"""Persist the security notice sentinel after a successful install."""
|
||||
sentinel = registry_base / _NOTICE_SENTINEL
|
||||
try:
|
||||
sentinel.touch()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _prompt_for_missing_credentials(
|
||||
registry,
|
||||
name: str,
|
||||
manifest: dict,
|
||||
) -> None:
|
||||
"""Prompt for required credentials not already set in env or overrides."""
|
||||
credentials = manifest.get("credentials", [])
|
||||
if not credentials:
|
||||
return
|
||||
|
||||
server = registry.get_server(name)
|
||||
existing_overrides = server.get("overrides", {}).get("env", {}) if server else {}
|
||||
|
||||
prompted = False
|
||||
for cred in credentials:
|
||||
if not isinstance(cred, dict):
|
||||
continue
|
||||
env_var = cred.get("env_var", "")
|
||||
if not env_var:
|
||||
continue
|
||||
required = cred.get("required", False)
|
||||
if not required:
|
||||
continue
|
||||
|
||||
# Skip if already in environment or overrides
|
||||
if os.environ.get(env_var) or existing_overrides.get(env_var):
|
||||
continue
|
||||
|
||||
if not prompted:
|
||||
print(f"\n{name} requires credentials:", file=sys.stderr)
|
||||
prompted = True
|
||||
|
||||
description = cred.get("description", env_var)
|
||||
help_url = cred.get("help_url", "")
|
||||
help_hint = f" (get one at {help_url})" if help_url else ""
|
||||
|
||||
try:
|
||||
value = input(f" {description}{help_hint}\n {env_var}: ").strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\nSkipped credential prompting.", file=sys.stderr)
|
||||
return
|
||||
|
||||
if value:
|
||||
registry.set_override(name, env_var, value, override_type="env")
|
||||
|
||||
|
||||
def _parse_key_value_pairs(values: list[str]) -> dict[str, str]:
|
||||
"""Parse KEY=VAL pairs from CLI args. Raises ValueError on bad format."""
|
||||
result = {}
|
||||
for item in values:
|
||||
if "=" not in item:
|
||||
raise ValueError(
|
||||
f"Invalid format: '{item}'. Expected KEY=VALUE.\n"
|
||||
f"Example: --set JIRA_API_TOKEN=abc123"
|
||||
)
|
||||
key, _, value = item.partition("=")
|
||||
if not key:
|
||||
raise ValueError(f"Invalid format: '{item}'. Key cannot be empty.")
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
def _find_agents_using_server(registry, name: str) -> list[str]:
|
||||
"""Scan agent directories for mcp_registry.json files that would load a server.
|
||||
|
||||
Uses MCPRegistry.load_agent_selection() to resolve actual selection logic
|
||||
so results stay consistent with runtime behavior.
|
||||
"""
|
||||
agent_dirs: list[Path] = []
|
||||
# parents: [0]=runner, [1]=framework, [2]=core, [3]=hive (project root)
|
||||
# NOTE: This path arithmetic assumes running from the source tree layout.
|
||||
# It will not resolve correctly if installed via pip into site-packages.
|
||||
project_root = Path(__file__).resolve().parents[3]
|
||||
core_dir = Path(__file__).resolve().parents[2]
|
||||
|
||||
candidates = [
|
||||
project_root / "exports",
|
||||
core_dir / "exports",
|
||||
core_dir / "framework" / "agents",
|
||||
]
|
||||
for candidate in candidates:
|
||||
if candidate.is_dir():
|
||||
for child in candidate.iterdir():
|
||||
if child.is_dir():
|
||||
agent_dirs.append(child)
|
||||
|
||||
matches = []
|
||||
for agent_dir in agent_dirs:
|
||||
registry_json = agent_dir / "mcp_registry.json"
|
||||
if not registry_json.exists():
|
||||
continue
|
||||
try:
|
||||
configs = registry.load_agent_selection(agent_dir)
|
||||
resolved_names = {c["name"] for c in configs}
|
||||
if name in resolved_names:
|
||||
matches.append(str(agent_dir))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def _render_installed_table(entries: list[dict]) -> None:
|
||||
"""Render installed servers as a formatted table."""
|
||||
if not entries:
|
||||
print("No servers installed.")
|
||||
print("Run 'hive mcp install <name>' or 'hive mcp add' to get started.")
|
||||
return
|
||||
|
||||
# Column widths
|
||||
name_w = max(len(e["name"]) for e in entries)
|
||||
name_w = max(name_w, 4)
|
||||
transport_w = max(len(e.get("transport", "")) for e in entries)
|
||||
transport_w = max(transport_w, 9)
|
||||
|
||||
header = (
|
||||
f" {'NAME':<{name_w}} "
|
||||
f"{'TRANSPORT':<{transport_w}} "
|
||||
f"{'ENABLED':<7} "
|
||||
f"{'HEALTH':<9} "
|
||||
f"{'TOOLS':<5} "
|
||||
f"{'TRUST':<10} "
|
||||
f"{'SOURCE'}"
|
||||
)
|
||||
print(header)
|
||||
print(" " + "─" * (len(header) - 2))
|
||||
|
||||
for entry in entries:
|
||||
enabled = "yes" if entry.get("enabled", True) else "no"
|
||||
health = entry.get("last_health_status") or "unknown"
|
||||
health_sym = {"healthy": "✓", "unhealthy": "✗"}.get(health, "●")
|
||||
source = entry.get("source", "")
|
||||
manifest = entry.get("manifest", {})
|
||||
tools_count = str(len(manifest.get("tools", [])))
|
||||
trust_tier = manifest.get("status", "")
|
||||
print(
|
||||
f" {entry['name']:<{name_w}} "
|
||||
f"{entry.get('transport', ''):<{transport_w}} "
|
||||
f"{enabled:<7} "
|
||||
f"{health_sym} {health:<7} "
|
||||
f"{tools_count:<5} "
|
||||
f"{trust_tier:<10} "
|
||||
f"{source}"
|
||||
)
|
||||
|
||||
|
||||
def _render_available_table(entries: list[dict]) -> None:
|
||||
"""Render available registry servers as a formatted table."""
|
||||
if not entries:
|
||||
print("No servers in registry index.")
|
||||
print("Run 'hive mcp update' to refresh the index.")
|
||||
return
|
||||
|
||||
name_w = max(len(e["name"]) for e in entries)
|
||||
name_w = max(name_w, 4)
|
||||
|
||||
header = f" {'NAME':<{name_w}} {'VERSION':<9} {'STATUS':<10} DESCRIPTION"
|
||||
print(header)
|
||||
print(" " + "─" * (len(header) - 2))
|
||||
|
||||
for entry in entries:
|
||||
version = entry.get("version", "")
|
||||
status = entry.get("status", "community")
|
||||
desc = entry.get("description", "")
|
||||
# Truncate long descriptions
|
||||
if len(desc) > 60:
|
||||
desc = desc[:57] + "..."
|
||||
print(f" {entry['name']:<{name_w}} {version:<9} {status:<10} {desc}")
|
||||
|
||||
|
||||
def _mask_overrides(overrides: dict) -> dict:
|
||||
"""Replace override values with '<set>' markers. Shared by all output paths."""
|
||||
masked: dict[str, dict[str, str]] = {}
|
||||
if overrides.get("env"):
|
||||
masked["env"] = dict.fromkeys(overrides["env"], "<set>")
|
||||
else:
|
||||
masked["env"] = {}
|
||||
if overrides.get("headers"):
|
||||
masked["headers"] = dict.fromkeys(overrides["headers"], "<set>")
|
||||
else:
|
||||
masked["headers"] = {}
|
||||
return masked
|
||||
|
||||
|
||||
def _emit_json(data: Any) -> None:
|
||||
"""Print data as formatted JSON."""
|
||||
print(json.dumps(data, indent=2, default=str))
|
||||
|
||||
|
||||
# ── Command registration ───────────────────────────────────────────
|
||||
|
||||
|
||||
def register_mcp_commands(subparsers) -> None:
|
||||
"""Register the ``hive mcp`` subcommand group."""
|
||||
mcp_parser = subparsers.add_parser("mcp", help="Manage MCP servers")
|
||||
mcp_sub = mcp_parser.add_subparsers(dest="mcp_command", required=True)
|
||||
|
||||
# ── install ──
|
||||
install_p = mcp_sub.add_parser("install", help="Install a server from the registry")
|
||||
install_p.add_argument("name", help="Server name in the registry")
|
||||
install_p.add_argument(
|
||||
"--version", dest="version", default=None, help="Pin to a specific version"
|
||||
)
|
||||
install_p.add_argument(
|
||||
"--transport", default=None, help="Override default transport (stdio, http, unix, sse)"
|
||||
)
|
||||
install_p.set_defaults(func=cmd_mcp_install)
|
||||
|
||||
# ── add ──
|
||||
add_p = mcp_sub.add_parser("add", help="Register a local/running MCP server")
|
||||
add_p.add_argument("--name", required=False, help="Server name")
|
||||
add_p.add_argument(
|
||||
"--transport",
|
||||
choices=["stdio", "http", "unix", "sse"],
|
||||
default=None,
|
||||
help="Transport type",
|
||||
)
|
||||
add_p.add_argument("--url", default=None, help="Server URL (http, unix, sse)")
|
||||
add_p.add_argument("--command", default=None, help="Command to run (stdio)")
|
||||
add_p.add_argument("--args", nargs="*", default=None, help="Command arguments (stdio)")
|
||||
add_p.add_argument("--socket-path", default=None, help="Unix socket path")
|
||||
add_p.add_argument("--description", default="", help="Server description")
|
||||
add_p.add_argument("--from", dest="from_manifest", default=None, help="Path to manifest.json")
|
||||
add_p.set_defaults(func=cmd_mcp_add)
|
||||
|
||||
# ── remove ──
|
||||
remove_p = mcp_sub.add_parser("remove", help="Remove an installed server")
|
||||
remove_p.add_argument("name", help="Server name")
|
||||
remove_p.set_defaults(func=cmd_mcp_remove)
|
||||
|
||||
# ── enable ──
|
||||
enable_p = mcp_sub.add_parser("enable", help="Enable a disabled server")
|
||||
enable_p.add_argument("name", help="Server name")
|
||||
enable_p.set_defaults(func=cmd_mcp_enable)
|
||||
|
||||
# ── disable ──
|
||||
disable_p = mcp_sub.add_parser("disable", help="Disable a server without removing it")
|
||||
disable_p.add_argument("name", help="Server name")
|
||||
disable_p.set_defaults(func=cmd_mcp_disable)
|
||||
|
||||
# ── list ──
|
||||
list_p = mcp_sub.add_parser("list", help="List servers")
|
||||
list_p.add_argument(
|
||||
"--available", action="store_true", help="Show available servers from registry"
|
||||
)
|
||||
list_p.add_argument("--json", dest="output_json", action="store_true", help="Output as JSON")
|
||||
list_p.set_defaults(func=cmd_mcp_list)
|
||||
|
||||
# ── info ──
|
||||
info_p = mcp_sub.add_parser("info", help="Show server details")
|
||||
info_p.add_argument("name", help="Server name")
|
||||
info_p.add_argument("--json", dest="output_json", action="store_true", help="Output as JSON")
|
||||
info_p.set_defaults(func=cmd_mcp_info)
|
||||
|
||||
# ── config ──
|
||||
config_p = mcp_sub.add_parser("config", help="Set server configuration overrides")
|
||||
config_p.add_argument("name", help="Server name")
|
||||
config_p.add_argument(
|
||||
"--set",
|
||||
dest="set_env",
|
||||
nargs="+",
|
||||
metavar="KEY=VAL",
|
||||
help="Set environment variable overrides",
|
||||
)
|
||||
config_p.add_argument(
|
||||
"--set-header", dest="set_header", nargs="+", metavar="KEY=VAL", help="Set header overrides"
|
||||
)
|
||||
config_p.set_defaults(func=cmd_mcp_config)
|
||||
|
||||
# ── search ──
|
||||
search_p = mcp_sub.add_parser("search", help="Search the registry")
|
||||
search_p.add_argument("query", help="Search term (name, tag, description, tool name)")
|
||||
search_p.add_argument("--json", dest="output_json", action="store_true", help="Output as JSON")
|
||||
search_p.set_defaults(func=cmd_mcp_search)
|
||||
|
||||
# ── health ──
|
||||
health_p = mcp_sub.add_parser("health", help="Check server health")
|
||||
health_p.add_argument("name", nargs="?", default=None, help="Server name (all if omitted)")
|
||||
health_p.add_argument("--json", dest="output_json", action="store_true", help="Output as JSON")
|
||||
health_p.set_defaults(func=cmd_mcp_health)
|
||||
|
||||
# ── update ──
|
||||
update_p = mcp_sub.add_parser(
|
||||
"update", help="Update installed servers or refresh the registry index"
|
||||
)
|
||||
update_p.add_argument(
|
||||
"name",
|
||||
nargs="?",
|
||||
default=None,
|
||||
help="Server name to update (omit to update all registry servers)",
|
||||
)
|
||||
update_p.set_defaults(func=cmd_mcp_update)
|
||||
|
||||
|
||||
# ── P0 command handlers ────────────────────────────────────────────
|
||||
|
||||
|
||||
def cmd_mcp_install(args) -> int:
|
||||
"""Install a server from the registry index."""
|
||||
registry = _get_registry()
|
||||
_print_security_notice_if_first_use(registry._base)
|
||||
if not _ensure_index_available(registry):
|
||||
return 1
|
||||
|
||||
try:
|
||||
entry = registry.install(
|
||||
args.name,
|
||||
transport=args.transport,
|
||||
version=args.version,
|
||||
)
|
||||
except ValueError as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
_mark_security_notice_shown(registry._base)
|
||||
|
||||
version_str = entry.get("manifest_version", "")
|
||||
transport = entry.get("transport", "")
|
||||
print(f"✓ Installed {args.name} v{version_str} ({transport})")
|
||||
|
||||
# Prompt for credentials defined in the manifest
|
||||
manifest = entry.get("manifest", {})
|
||||
_prompt_for_missing_credentials(registry, args.name, manifest)
|
||||
|
||||
print("\nNext steps:")
|
||||
print(f" hive mcp health {args.name} Check that the server is reachable")
|
||||
print(f" hive mcp info {args.name} View server details")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_add(args) -> int:
|
||||
"""Register a local/running MCP server."""
|
||||
registry = _get_registry()
|
||||
|
||||
# Handle --from manifest.json
|
||||
if args.from_manifest:
|
||||
return _cmd_mcp_add_from_manifest(registry, args.from_manifest)
|
||||
|
||||
if not args.name:
|
||||
print(
|
||||
"Error: --name is required.\n"
|
||||
"Usage: hive mcp add --name my-server --transport http --url http://localhost:8080\n"
|
||||
" or: hive mcp add --from manifest.json",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
if not args.transport:
|
||||
print(
|
||||
f"Error: --transport is required.\n"
|
||||
f"Supported transports: stdio, http, unix, sse\n"
|
||||
f"Example: hive mcp add --name {args.name} --transport http --url http://localhost:8080",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
try:
|
||||
entry = registry.add_local(
|
||||
name=args.name,
|
||||
transport=args.transport,
|
||||
url=args.url,
|
||||
command=args.command,
|
||||
args=args.args,
|
||||
socket_path=args.socket_path,
|
||||
description=args.description,
|
||||
)
|
||||
except ValueError as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
print(f"✓ Registered {args.name} ({entry['transport']})")
|
||||
return 0
|
||||
|
||||
|
||||
def _cmd_mcp_add_from_manifest(registry, manifest_path: str) -> int:
|
||||
"""Register a server from a manifest.json file."""
|
||||
path = Path(manifest_path)
|
||||
if not path.exists():
|
||||
print(
|
||||
f"Error: manifest file not found: {manifest_path}\nCheck the path and try again.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
try:
|
||||
manifest = json.loads(path.read_text(encoding="utf-8"))
|
||||
except json.JSONDecodeError as exc:
|
||||
print(
|
||||
f"Error: invalid JSON in {manifest_path}: {exc}\n"
|
||||
f"Validate with: python -m json.tool {manifest_path}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
name = manifest.get("name")
|
||||
if not name:
|
||||
print(
|
||||
f"Error: manifest missing 'name' field.\nAdd a 'name' field to {manifest_path}.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
try:
|
||||
entry = registry.add_local(name=name, manifest=manifest)
|
||||
except ValueError as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
print(f"✓ Registered {name} from {manifest_path} ({entry['transport']})")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_remove(args) -> int:
|
||||
"""Remove an installed server."""
|
||||
registry = _get_registry()
|
||||
try:
|
||||
registry.remove(args.name)
|
||||
except ValueError as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
print(f"✓ Removed {args.name}")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_enable(args) -> int:
|
||||
"""Enable a disabled server."""
|
||||
registry = _get_registry()
|
||||
try:
|
||||
registry.enable(args.name)
|
||||
except ValueError as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
print(f"✓ Enabled {args.name}")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_disable(args) -> int:
|
||||
"""Disable a server without removing it."""
|
||||
registry = _get_registry()
|
||||
try:
|
||||
registry.disable(args.name)
|
||||
except ValueError as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
print(f"✓ Disabled {args.name}")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_list(args) -> int:
|
||||
"""List installed or available servers."""
|
||||
registry = _get_registry()
|
||||
|
||||
if args.available:
|
||||
if not _ensure_index_available(registry):
|
||||
return 1
|
||||
entries = registry.list_available()
|
||||
if args.output_json:
|
||||
_emit_json(entries)
|
||||
else:
|
||||
_render_available_table(entries)
|
||||
else:
|
||||
entries = registry.list_installed()
|
||||
if args.output_json:
|
||||
safe_entries = []
|
||||
for entry in entries:
|
||||
safe = dict(entry)
|
||||
safe["overrides"] = _mask_overrides(safe.get("overrides", {}))
|
||||
safe_entries.append(safe)
|
||||
_emit_json(safe_entries)
|
||||
else:
|
||||
_render_installed_table(entries)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_info(args) -> int:
|
||||
"""Show full details for a server."""
|
||||
registry = _get_registry()
|
||||
server = registry.get_server(args.name)
|
||||
|
||||
if server is None:
|
||||
print(
|
||||
f"Error: server '{args.name}' is not installed.\n"
|
||||
f"Run 'hive mcp list' to see installed servers.\n"
|
||||
f"Run 'hive mcp install {args.name}' to install from registry.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
# Enrich with agent usage for both JSON and human output
|
||||
agents = _find_agents_using_server(registry, args.name)
|
||||
if agents:
|
||||
server["used_by_agents"] = agents
|
||||
|
||||
if args.output_json:
|
||||
safe = dict(server)
|
||||
safe["overrides"] = _mask_overrides(safe.get("overrides", {}))
|
||||
_emit_json(safe)
|
||||
return 0
|
||||
|
||||
manifest = server.get("manifest", {})
|
||||
overrides = _mask_overrides(server.get("overrides", {}))
|
||||
tools = manifest.get("tools", [])
|
||||
status = manifest.get("status", "community")
|
||||
hive_block = manifest.get("hive", {})
|
||||
|
||||
print(f"{server['name']}")
|
||||
print("=" * 50)
|
||||
|
||||
# Core info
|
||||
print(f" Source: {server.get('source', '')}")
|
||||
print(f" Transport: {server.get('transport', '')}")
|
||||
print(f" Version: {server.get('manifest_version', 'unknown')}")
|
||||
print(f" Trust tier: {status}")
|
||||
print(f" Enabled: {'yes' if server.get('enabled', True) else 'no'}")
|
||||
|
||||
# Description
|
||||
desc = manifest.get("description", "")
|
||||
if desc:
|
||||
print(f" Description: {desc}")
|
||||
|
||||
# Health
|
||||
health = server.get("last_health_status")
|
||||
if health:
|
||||
health_sym = {"healthy": "✓", "unhealthy": "✗"}.get(health, "●")
|
||||
print(f" Health: {health_sym} {health}")
|
||||
last_check = server.get("last_health_check_at")
|
||||
if last_check:
|
||||
print(f" Last check: {last_check}")
|
||||
last_error = server.get("last_error")
|
||||
if last_error:
|
||||
print(f" Last error: {last_error}")
|
||||
|
||||
# Tools
|
||||
if tools:
|
||||
print(f"\n Tools ({len(tools)}):")
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict):
|
||||
tool_name = tool.get("name", "")
|
||||
tool_desc = tool.get("description", "")
|
||||
print(f" • {tool_name}: {tool_desc}" if tool_desc else f" • {tool_name}")
|
||||
else:
|
||||
print(f" • {tool}")
|
||||
|
||||
# Overrides
|
||||
env_overrides = overrides.get("env", {})
|
||||
header_overrides = overrides.get("headers", {})
|
||||
if env_overrides or header_overrides:
|
||||
print("\n Overrides:")
|
||||
for key in env_overrides:
|
||||
print(f" env.{key} = <set>")
|
||||
for key in header_overrides:
|
||||
print(f" header.{key} = <set>")
|
||||
|
||||
# Hive block
|
||||
if hive_block:
|
||||
profiles = hive_block.get("profiles", [])
|
||||
if profiles:
|
||||
print(f"\n Profiles: {', '.join(profiles)}")
|
||||
min_ver = hive_block.get("min_version")
|
||||
if min_ver:
|
||||
print(f" Min Hive version: {min_ver}")
|
||||
|
||||
# Agent usage
|
||||
if agents:
|
||||
print("\n Used by agents:")
|
||||
for agent in agents:
|
||||
print(f" • {agent}")
|
||||
|
||||
# Timestamps
|
||||
print(f"\n Installed: {server.get('installed_at', 'unknown')}")
|
||||
print(f" Installed by: {server.get('installed_by', 'unknown')}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_config(args) -> int:
|
||||
"""Set env or header overrides for a server."""
|
||||
registry = _get_registry()
|
||||
|
||||
if not args.set_env and not args.set_header:
|
||||
# Show current config
|
||||
server = registry.get_server(args.name)
|
||||
if server is None:
|
||||
print(
|
||||
f"Error: server '{args.name}' is not installed.\n"
|
||||
f"Run 'hive mcp list' to see installed servers.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
masked = _mask_overrides(server.get("overrides", {}))
|
||||
env_o = masked.get("env", {})
|
||||
header_o = masked.get("headers", {})
|
||||
if not env_o and not header_o:
|
||||
print(f"No overrides set for {args.name}.")
|
||||
print(f"Set one with: hive mcp config {args.name} --set KEY=VALUE")
|
||||
else:
|
||||
print(f"Overrides for {args.name}:")
|
||||
for key in env_o:
|
||||
print(f" env.{key} = <set>")
|
||||
for key in header_o:
|
||||
print(f" header.{key} = <set>")
|
||||
return 0
|
||||
|
||||
try:
|
||||
if args.set_env:
|
||||
pairs = _parse_key_value_pairs(args.set_env)
|
||||
for key, value in pairs.items():
|
||||
registry.set_override(args.name, key, value, override_type="env")
|
||||
print(f"✓ Set {len(pairs)} env override(s) for {args.name}")
|
||||
|
||||
if args.set_header:
|
||||
pairs = _parse_key_value_pairs(args.set_header)
|
||||
for key, value in pairs.items():
|
||||
registry.set_override(args.name, key, value, override_type="headers")
|
||||
print(f"✓ Set {len(pairs)} header override(s) for {args.name}")
|
||||
|
||||
except ValueError as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
# ── P1 command handlers ────────────────────────────────────────────
|
||||
|
||||
|
||||
def cmd_mcp_search(args) -> int:
|
||||
"""Search the registry index."""
|
||||
registry = _get_registry()
|
||||
if not _ensure_index_available(registry):
|
||||
return 1
|
||||
|
||||
results = registry.search(args.query)
|
||||
|
||||
if args.output_json:
|
||||
_emit_json(results)
|
||||
return 0
|
||||
|
||||
if not results:
|
||||
print(f"No servers matching '{args.query}'.")
|
||||
return 0
|
||||
|
||||
print(f"Found {len(results)} server(s) matching '{args.query}':\n")
|
||||
_render_available_table(results)
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_health(args) -> int:
|
||||
"""Check server health."""
|
||||
registry = _get_registry()
|
||||
|
||||
try:
|
||||
results = registry.health_check(name=args.name)
|
||||
except ValueError as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
# Single server returns a flat dict, all-servers returns name->dict
|
||||
if args.name:
|
||||
results = {args.name: results}
|
||||
|
||||
if args.output_json:
|
||||
_emit_json(results)
|
||||
return 0
|
||||
|
||||
for name, result in results.items():
|
||||
status = result.get("status", "unknown")
|
||||
tools = result.get("tools", 0)
|
||||
error = result.get("error")
|
||||
sym = {"healthy": "✓", "unhealthy": "✗"}.get(status, "●")
|
||||
|
||||
print(f" {sym} {name}: {status}", end="")
|
||||
if status == "healthy" and tools:
|
||||
print(f" ({tools} tools)")
|
||||
elif error:
|
||||
print(f"\n Error: {error}")
|
||||
else:
|
||||
print()
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_update(args) -> int:
|
||||
"""Update a single server, or refresh the index and update all registry servers."""
|
||||
registry = _get_registry()
|
||||
|
||||
if args.name:
|
||||
return _cmd_mcp_update_server(args.name, registry)
|
||||
|
||||
# Step 1: refresh the registry index
|
||||
try:
|
||||
count = registry.update_index()
|
||||
except Exception as exc:
|
||||
print(
|
||||
f"Error: failed to update registry index: {exc}\n"
|
||||
f"Check your network connection and try again.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
print(f"✓ Registry index updated ({count} servers available)")
|
||||
|
||||
# Step 2: update all installed registry servers (skip local/pinned)
|
||||
installed = registry.list_installed()
|
||||
registry_servers = [
|
||||
s for s in installed if s.get("source") == "registry" and not s.get("pinned")
|
||||
]
|
||||
|
||||
if not registry_servers:
|
||||
return 0
|
||||
|
||||
print(f"\nUpdating {len(registry_servers)} installed server(s)...")
|
||||
errors = 0
|
||||
for server in registry_servers:
|
||||
name = server["name"]
|
||||
rc = _cmd_mcp_update_server(name, registry)
|
||||
if rc != 0:
|
||||
errors += 1
|
||||
|
||||
return 1 if errors else 0
|
||||
|
||||
|
||||
def _cmd_mcp_update_server(name: str, registry=None) -> int:
|
||||
"""Bridge: reinstall a server from the latest index.
|
||||
|
||||
This is a temporary bridge until #6355 adds proper version diffing,
|
||||
tool-signature change detection, and --dry-run support.
|
||||
"""
|
||||
if registry is None:
|
||||
registry = _get_registry()
|
||||
|
||||
server = registry.get_server(name)
|
||||
if server is None:
|
||||
print(
|
||||
f"Error: server '{name}' is not installed.\n"
|
||||
f"Run 'hive mcp install {name}' to install it.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
if server.get("source") != "registry":
|
||||
print(
|
||||
f"Error: '{name}' is a local server and cannot be updated from the registry.\n"
|
||||
f"Use 'hive mcp remove {name}' and 'hive mcp add' to re-register it.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
if server.get("pinned"):
|
||||
print(
|
||||
f"Error: '{name}' is pinned to v{server.get('manifest_version', '?')}.\n"
|
||||
f"To update a pinned server, remove and reinstall:\n"
|
||||
f" hive mcp remove {name} && hive mcp install {name}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
# Refresh index, then reinstall
|
||||
if not _ensure_index_available(registry):
|
||||
return 1
|
||||
|
||||
old_version = server.get("manifest_version", "unknown")
|
||||
transport = server.get("transport")
|
||||
overrides = server.get("overrides", {})
|
||||
was_enabled = server.get("enabled", True)
|
||||
|
||||
# Save the full entry before removing so we can restore on failure
|
||||
saved_entry = dict(server)
|
||||
saved_entry.pop("name", None)
|
||||
|
||||
try:
|
||||
registry.remove(name)
|
||||
entry = registry.install(name, transport=transport)
|
||||
except ValueError as exc:
|
||||
# Restore the original entry so update doesn't become an uninstall
|
||||
data = registry._read_installed()
|
||||
data["servers"][name] = saved_entry
|
||||
registry._write_installed(data)
|
||||
print(
|
||||
f"Error: {exc}\nServer '{name}' has been restored to its previous state.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
new_version = entry.get("manifest_version", "unknown")
|
||||
|
||||
# Restore prior state from the previous installation
|
||||
for key, value in overrides.get("env", {}).items():
|
||||
registry.set_override(name, key, value, override_type="env")
|
||||
for key, value in overrides.get("headers", {}).items():
|
||||
registry.set_override(name, key, value, override_type="headers")
|
||||
if not was_enabled:
|
||||
registry.disable(name)
|
||||
|
||||
if old_version == new_version:
|
||||
print(f"✓ {name} is already at v{new_version}")
|
||||
else:
|
||||
print(f"✓ Updated {name}: v{old_version} → v{new_version}")
|
||||
|
||||
return 0
|
||||
@@ -0,0 +1,252 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CACHE_INDEX_PATH = Path.home() / ".hive" / "mcp_registry" / "cache" / "registry_index.json"
|
||||
_FIXTURE_INDEX_PATH = Path(__file__).resolve().parent / "fixtures" / "registry_index.json"
|
||||
|
||||
|
||||
def resolve_registry_servers(
|
||||
*,
|
||||
include: list[str] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
exclude: list[str] | None = None,
|
||||
profile: str | None = None,
|
||||
max_tools: int | None = None,
|
||||
versions: dict[str, str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Resolve registry-sourced MCP servers for `mcp_registry.json` selection.
|
||||
|
||||
This function is written to be mock-friendly during early development:
|
||||
- If the real `MCPRegistry` core module is present, delegate to it.
|
||||
- Otherwise, fall back to a cached local index (`~/.hive/.../registry_index.json`)
|
||||
and then to the repo fixture index.
|
||||
"""
|
||||
|
||||
# `max_tools` is enforced by ToolRegistry. We keep it in the resolver
|
||||
# signature to match the PRD and future MCPRegistry interfaces.
|
||||
_ = max_tools
|
||||
|
||||
try:
|
||||
from framework.runner.mcp_registry import MCPRegistry # type: ignore
|
||||
|
||||
registry = MCPRegistry()
|
||||
resolved = registry.resolve_for_agent(
|
||||
include=include or [],
|
||||
tags=tags or [],
|
||||
exclude=exclude or [],
|
||||
profile=profile,
|
||||
max_tools=max_tools,
|
||||
versions=versions or {},
|
||||
)
|
||||
# Future-proof: normalize both dicts and typed objects to dicts.
|
||||
return [_normalize_server_config(x) for x in resolved]
|
||||
except ImportError:
|
||||
# Expected while #6349/#6574 is not merged locally.
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("MCPRegistry resolution failed; falling back to cache/fixtures: %s", e)
|
||||
|
||||
return _resolve_from_local_index(
|
||||
include=include,
|
||||
tags=tags,
|
||||
exclude=exclude,
|
||||
profile=profile,
|
||||
versions=versions or {},
|
||||
)
|
||||
|
||||
|
||||
def _resolve_from_local_index(
|
||||
*,
|
||||
include: list[str] | None,
|
||||
tags: list[str] | None,
|
||||
exclude: list[str] | None,
|
||||
profile: str | None,
|
||||
versions: dict[str, str],
|
||||
) -> list[dict[str, Any]]:
|
||||
index = _load_index_json()
|
||||
servers = _coerce_index_servers(index)
|
||||
servers_by_name: dict[str, dict[str, Any]] = {
|
||||
s["name"]: s for s in servers if isinstance(s, dict) and "name" in s
|
||||
}
|
||||
|
||||
include_list = include or []
|
||||
tags_list = tags or []
|
||||
exclude_set = set(exclude or [])
|
||||
|
||||
def _profiles_of(entry: dict[str, Any]) -> set[str]:
|
||||
if isinstance(entry.get("profiles"), list):
|
||||
return set(entry["profiles"])
|
||||
hive = entry.get("hive")
|
||||
if isinstance(hive, dict) and isinstance(hive.get("profiles"), list):
|
||||
return set(hive["profiles"])
|
||||
return set()
|
||||
|
||||
def _tags_of(entry: dict[str, Any]) -> set[str]:
|
||||
if isinstance(entry.get("tags"), list):
|
||||
return set(entry["tags"])
|
||||
return set()
|
||||
|
||||
def _entry_version(entry: dict[str, Any]) -> str | None:
|
||||
# Prefer flat `version`, but support a few common shapes.
|
||||
v = entry.get("version")
|
||||
if isinstance(v, str):
|
||||
return v
|
||||
v2 = entry.get("manifest_version")
|
||||
if isinstance(v2, str):
|
||||
return v2
|
||||
hive = entry.get("manifest")
|
||||
if isinstance(hive, dict) and isinstance(hive.get("version"), str):
|
||||
return hive["version"]
|
||||
return None
|
||||
|
||||
def _version_allows(server_name: str) -> bool:
|
||||
if server_name not in versions:
|
||||
return True
|
||||
pinned = versions[server_name]
|
||||
entry = servers_by_name.get(server_name)
|
||||
if not entry:
|
||||
return False
|
||||
return _entry_version(entry) == pinned
|
||||
|
||||
resolved_names: list[str] = []
|
||||
resolved_set: set[str] = set()
|
||||
|
||||
# 1) Include-order first
|
||||
for name in include_list:
|
||||
if name in exclude_set:
|
||||
continue
|
||||
if name in servers_by_name and _version_allows(name) and name not in resolved_set:
|
||||
resolved_names.append(name)
|
||||
resolved_set.add(name)
|
||||
|
||||
# 2) Then tag/profile matches, alphabetical
|
||||
profile_candidates = set()
|
||||
if profile:
|
||||
for name, entry in servers_by_name.items():
|
||||
if name in exclude_set or not _version_allows(name):
|
||||
continue
|
||||
if profile in _profiles_of(entry):
|
||||
profile_candidates.add(name)
|
||||
|
||||
tag_candidates = set()
|
||||
if tags_list:
|
||||
tags_set = set(tags_list)
|
||||
for name, entry in servers_by_name.items():
|
||||
if name in exclude_set or not _version_allows(name):
|
||||
continue
|
||||
if _tags_of(entry).intersection(tags_set):
|
||||
tag_candidates.add(name)
|
||||
|
||||
tag_profile_names = sorted((profile_candidates | tag_candidates) - resolved_set)
|
||||
resolved_names.extend(tag_profile_names)
|
||||
|
||||
# Missing requested servers should warn (FR-54).
|
||||
for name in include_list:
|
||||
if name in exclude_set:
|
||||
continue
|
||||
if name not in resolved_set:
|
||||
if name not in servers_by_name:
|
||||
logger.warning(
|
||||
"Server '%s' requested by mcp_registry.json but not found in index. "
|
||||
"Run: hive mcp install %s",
|
||||
name,
|
||||
name,
|
||||
)
|
||||
elif name in versions:
|
||||
logger.warning(
|
||||
"Server '%s' was requested but pinned version '%s' was not found in index. "
|
||||
"Run: hive mcp update %s or change the pin in mcp_registry.json",
|
||||
name,
|
||||
versions[name],
|
||||
name,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Server '%s' requested by mcp_registry.json was not selected. "
|
||||
"Check selection filters/exclude lists.",
|
||||
name,
|
||||
)
|
||||
|
||||
resolved_configs: list[dict[str, Any]] = []
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
for name in resolved_names:
|
||||
entry = servers_by_name.get(name)
|
||||
if not entry:
|
||||
continue
|
||||
config = entry.get("mcp_config")
|
||||
if not isinstance(config, dict):
|
||||
# Best-effort: allow a direct MCP config shape at top-level.
|
||||
config = {
|
||||
k: v
|
||||
for k, v in entry.items()
|
||||
if k
|
||||
in {
|
||||
"name",
|
||||
"transport",
|
||||
"command",
|
||||
"args",
|
||||
"env",
|
||||
"cwd",
|
||||
"url",
|
||||
"headers",
|
||||
"description",
|
||||
}
|
||||
}
|
||||
mcp_config = dict(config)
|
||||
mcp_config["name"] = name
|
||||
if mcp_config.get("transport") == "stdio":
|
||||
_absolutize_stdio_config_in_place(repo_root, mcp_config)
|
||||
resolved_configs.append(mcp_config)
|
||||
|
||||
return resolved_configs
|
||||
|
||||
|
||||
def _load_index_json() -> Any:
|
||||
if _CACHE_INDEX_PATH.exists():
|
||||
return json.loads(_CACHE_INDEX_PATH.read_text(encoding="utf-8"))
|
||||
if _FIXTURE_INDEX_PATH.exists():
|
||||
logger.info("Using local fixture index because registry cache is missing")
|
||||
return json.loads(_FIXTURE_INDEX_PATH.read_text(encoding="utf-8"))
|
||||
logger.warning("No local MCP registry index found (cache and fixture missing)")
|
||||
return {"servers": []}
|
||||
|
||||
|
||||
def _coerce_index_servers(index: Any) -> list[dict[str, Any]]:
|
||||
if isinstance(index, list):
|
||||
return [x for x in index if isinstance(x, dict)]
|
||||
if isinstance(index, dict):
|
||||
servers = index.get("servers", [])
|
||||
if isinstance(servers, list):
|
||||
return [x for x in servers if isinstance(x, dict)]
|
||||
return []
|
||||
|
||||
|
||||
def _normalize_server_config(raw: Any) -> dict[str, Any]:
|
||||
if isinstance(raw, dict):
|
||||
return dict(raw)
|
||||
|
||||
# Future-proof object-to-dict normalization.
|
||||
for attr in ("to_dict", "model_dump"):
|
||||
maybe = getattr(raw, attr, None)
|
||||
if callable(maybe):
|
||||
return dict(maybe())
|
||||
|
||||
return dict(getattr(raw, "__dict__", {}))
|
||||
|
||||
|
||||
def _absolutize_stdio_config_in_place(repo_root: Path, config: dict[str, Any]) -> None:
|
||||
cwd = config.get("cwd")
|
||||
if isinstance(cwd, str) and not Path(cwd).is_absolute():
|
||||
config["cwd"] = str((repo_root / cwd).resolve())
|
||||
|
||||
# We intentionally do not absolutize `args` here.
|
||||
# For stdio servers, arguments may include the script name relative to
|
||||
# `cwd` (e.g. "coder_tools_server.py" with cwd="tools"). ToolRegistry's
|
||||
# stdio resolution logic handles script path checks and platform quirks.
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Pre-load validation for agent graphs.
|
||||
|
||||
Runs structural and credential checks before MCP servers are spawned.
|
||||
Runs structural, credential, and skill-trust checks before MCP servers are spawned.
|
||||
Fails fast with actionable error messages.
|
||||
"""
|
||||
|
||||
@@ -169,6 +169,9 @@ def run_preload_validation(
|
||||
1. Graph structure (includes GCU subagent-only checks) — non-recoverable
|
||||
2. Credentials — potentially recoverable via interactive setup
|
||||
|
||||
Skill discovery and trust gating (AS-13) happen later in runner._setup()
|
||||
so they have access to agent-level skill configuration.
|
||||
|
||||
Raises PreloadValidationError for structural issues.
|
||||
Raises CredentialError for credential issues.
|
||||
"""
|
||||
|
||||
+411
-10
@@ -552,6 +552,319 @@ def get_kimi_code_token() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Antigravity subscription token helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Antigravity IDE (native macOS/Linux app) stores OAuth tokens in its
|
||||
# VSCode-style SQLite state database under the key
|
||||
# "antigravityUnifiedStateSync.oauthToken" as a base64-encoded protobuf blob.
|
||||
ANTIGRAVITY_IDE_STATE_DB = (
|
||||
Path.home()
|
||||
/ "Library"
|
||||
/ "Application Support"
|
||||
/ "Antigravity"
|
||||
/ "User"
|
||||
/ "globalStorage"
|
||||
/ "state.vscdb"
|
||||
)
|
||||
# Linux fallback for the IDE state DB
|
||||
ANTIGRAVITY_IDE_STATE_DB_LINUX = (
|
||||
Path.home() / ".config" / "Antigravity" / "User" / "globalStorage" / "state.vscdb"
|
||||
)
|
||||
# Antigravity credentials stored by native OAuth implementation
|
||||
ANTIGRAVITY_AUTH_FILE = Path.home() / ".hive" / "antigravity-accounts.json"
|
||||
|
||||
ANTIGRAVITY_OAUTH_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
_ANTIGRAVITY_TOKEN_LIFETIME_SECS = 3600 # Google access tokens expire in 1 hour
|
||||
_ANTIGRAVITY_IDE_STATE_DB_KEY = "antigravityUnifiedStateSync.oauthToken"
|
||||
|
||||
|
||||
def _read_antigravity_ide_credentials() -> dict | None:
|
||||
"""Read credentials from the Antigravity IDE's SQLite state database.
|
||||
|
||||
The Antigravity desktop IDE (VSCode-based) stores its OAuth token as a
|
||||
base64-encoded protobuf blob in a SQLite database. The access token is
|
||||
a standard Google OAuth ``ya29.*`` bearer token.
|
||||
|
||||
Returns:
|
||||
Dict with ``accessToken`` and optionally ``refreshToken`` keys,
|
||||
plus ``_source: "ide"`` to skip file-based save on refresh.
|
||||
Returns None if the database is absent or the key is not found.
|
||||
"""
|
||||
import re
|
||||
import sqlite3
|
||||
|
||||
for db_path in (ANTIGRAVITY_IDE_STATE_DB, ANTIGRAVITY_IDE_STATE_DB_LINUX):
|
||||
if not db_path.exists():
|
||||
continue
|
||||
try:
|
||||
con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
|
||||
try:
|
||||
row = con.execute(
|
||||
"SELECT value FROM ItemTable WHERE key = ?",
|
||||
(_ANTIGRAVITY_IDE_STATE_DB_KEY,),
|
||||
).fetchone()
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
if not row:
|
||||
continue
|
||||
|
||||
import base64
|
||||
|
||||
blob = base64.b64decode(row[0])
|
||||
|
||||
# The protobuf blob contains the access token (ya29.*) and
|
||||
# refresh token (1//*) as length-prefixed UTF-8 strings.
|
||||
# Decode the inner base64 layer and extract with regex.
|
||||
inner_b64_candidates = re.findall(rb"[A-Za-z0-9+/=_\-]{40,}", blob)
|
||||
access_token: str | None = None
|
||||
refresh_token: str | None = None
|
||||
for candidate in inner_b64_candidates:
|
||||
try:
|
||||
padded = candidate + b"=" * (-len(candidate) % 4)
|
||||
inner = base64.urlsafe_b64decode(padded)
|
||||
except Exception:
|
||||
continue
|
||||
if not access_token:
|
||||
m = re.search(rb"ya29\.[A-Za-z0-9_\-\.]+", inner)
|
||||
if m:
|
||||
access_token = m.group(0).decode("ascii")
|
||||
if not refresh_token:
|
||||
m = re.search(rb"1//[A-Za-z0-9_\-\.]+", inner)
|
||||
if m:
|
||||
refresh_token = m.group(0).decode("ascii")
|
||||
if access_token and refresh_token:
|
||||
break
|
||||
|
||||
if access_token:
|
||||
return {
|
||||
"accounts": [
|
||||
{
|
||||
"accessToken": access_token,
|
||||
"refreshToken": refresh_token or "",
|
||||
}
|
||||
],
|
||||
"_source": "ide",
|
||||
"_db_path": str(db_path),
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to read Antigravity IDE state DB: %s", exc)
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _read_antigravity_credentials() -> dict | None:
|
||||
"""Read Antigravity auth data from all supported credential sources.
|
||||
|
||||
Checks in order:
|
||||
1. Antigravity IDE SQLite state database (native macOS/Linux app)
|
||||
2. Native OAuth credentials file (~/.hive/antigravity-accounts.json)
|
||||
|
||||
Returns:
|
||||
Auth data dict with an ``accounts`` list on success, None otherwise.
|
||||
"""
|
||||
# 1. Native Antigravity IDE (primary on macOS)
|
||||
ide_creds = _read_antigravity_ide_credentials()
|
||||
if ide_creds:
|
||||
return ide_creds
|
||||
|
||||
# 2. Native OAuth credentials file
|
||||
if ANTIGRAVITY_AUTH_FILE.exists():
|
||||
try:
|
||||
with open(ANTIGRAVITY_AUTH_FILE, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
accounts = data.get("accounts", [])
|
||||
if accounts and isinstance(accounts[0], dict):
|
||||
return data
|
||||
except (json.JSONDecodeError, OSError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _is_antigravity_token_expired(auth_data: dict) -> bool:
|
||||
"""Check whether the Antigravity access token is expired or near expiry.
|
||||
|
||||
For IDE-sourced credentials: uses the state DB's mtime as last_refresh
|
||||
since the IDE keeps the DB fresh while it's running.
|
||||
For JSON-sourced credentials: uses the ``last_refresh`` field or file mtime.
|
||||
"""
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
now = time.time()
|
||||
|
||||
if auth_data.get("_source") == "ide":
|
||||
# The IDE refreshes tokens automatically while running.
|
||||
# Use the DB file's mtime as a proxy for when the token was last updated.
|
||||
try:
|
||||
db_path = Path(auth_data.get("_db_path", str(ANTIGRAVITY_IDE_STATE_DB)))
|
||||
last_refresh: float = db_path.stat().st_mtime
|
||||
except OSError:
|
||||
return True
|
||||
expires_at = last_refresh + _ANTIGRAVITY_TOKEN_LIFETIME_SECS
|
||||
return now >= (expires_at - _TOKEN_REFRESH_BUFFER_SECS)
|
||||
|
||||
last_refresh_val: float | str | None = auth_data.get("last_refresh")
|
||||
if last_refresh_val is None:
|
||||
try:
|
||||
last_refresh_val = ANTIGRAVITY_AUTH_FILE.stat().st_mtime
|
||||
except OSError:
|
||||
return True
|
||||
elif isinstance(last_refresh_val, str):
|
||||
try:
|
||||
last_refresh_val = datetime.fromisoformat(
|
||||
last_refresh_val.replace("Z", "+00:00")
|
||||
).timestamp()
|
||||
except (ValueError, TypeError):
|
||||
return True
|
||||
|
||||
expires_at = float(last_refresh_val) + _ANTIGRAVITY_TOKEN_LIFETIME_SECS
|
||||
return now >= (expires_at - _TOKEN_REFRESH_BUFFER_SECS)
|
||||
|
||||
|
||||
def _refresh_antigravity_token(refresh_token: str) -> dict | None:
|
||||
"""Refresh the Antigravity access token via Google OAuth.
|
||||
|
||||
POSTs form-encoded ``grant_type=refresh_token`` to the Google token
|
||||
endpoint using Antigravity's public OAuth client ID.
|
||||
|
||||
Returns:
|
||||
Parsed response dict (containing ``access_token``) on success,
|
||||
None on any error.
|
||||
"""
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
|
||||
from framework.config import get_antigravity_client_id, get_antigravity_client_secret
|
||||
|
||||
client_id = get_antigravity_client_id()
|
||||
client_secret = get_antigravity_client_secret()
|
||||
params: dict = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": client_id,
|
||||
}
|
||||
if client_secret:
|
||||
params["client_secret"] = client_secret
|
||||
|
||||
data = urllib.parse.urlencode(params).encode("utf-8")
|
||||
|
||||
req = urllib.request.Request(
|
||||
ANTIGRAVITY_OAUTH_TOKEN_URL,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
method="POST",
|
||||
)
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=15) as resp: # noqa: S310
|
||||
return json.loads(resp.read())
|
||||
except (urllib.error.URLError, json.JSONDecodeError, TimeoutError, OSError) as exc:
|
||||
logger.debug("Antigravity token refresh failed: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _save_refreshed_antigravity_credentials(auth_data: dict, token_data: dict) -> None:
|
||||
"""Write refreshed tokens back to the Antigravity JSON credentials file.
|
||||
|
||||
Skipped for IDE-sourced credentials (the IDE manages its own DB).
|
||||
Updates ``accounts[0].accessToken`` (and ``refreshToken`` if present),
|
||||
then persists ``last_refresh`` as an ISO-8601 UTC string.
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
# IDE manages its own state — we do not write back to its SQLite DB
|
||||
if auth_data.get("_source") == "ide":
|
||||
return
|
||||
|
||||
try:
|
||||
accounts = auth_data.get("accounts", [])
|
||||
if not accounts:
|
||||
return
|
||||
account = accounts[0]
|
||||
account["accessToken"] = token_data["access_token"]
|
||||
if "refresh_token" in token_data:
|
||||
account["refreshToken"] = token_data["refresh_token"]
|
||||
auth_data["accounts"] = accounts
|
||||
auth_data["last_refresh"] = datetime.now(UTC).isoformat()
|
||||
|
||||
ANTIGRAVITY_AUTH_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd = os.open(ANTIGRAVITY_AUTH_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
||||
json.dump(auth_data, f, indent=2)
|
||||
logger.debug("Antigravity credentials refreshed and saved")
|
||||
except (OSError, KeyError) as exc:
|
||||
logger.debug("Failed to save refreshed Antigravity credentials: %s", exc)
|
||||
|
||||
|
||||
def get_antigravity_token() -> str | None:
|
||||
"""Get the OAuth access token from an Antigravity subscription.
|
||||
|
||||
Credential sources checked in order:
|
||||
1. Antigravity IDE SQLite state DB (native app, macOS/Linux)
|
||||
2. antigravity-auth CLI JSON file
|
||||
|
||||
For IDE credentials the token is read directly (the IDE refreshes it
|
||||
automatically while running). For JSON credentials an automatic OAuth
|
||||
refresh is attempted when the token is near expiry.
|
||||
|
||||
Returns:
|
||||
The ``ya29.*`` Google OAuth access token, or None if unavailable.
|
||||
"""
|
||||
auth_data = _read_antigravity_credentials()
|
||||
if not auth_data:
|
||||
return None
|
||||
|
||||
accounts = auth_data.get("accounts", [])
|
||||
if not accounts:
|
||||
return None
|
||||
account = accounts[0]
|
||||
|
||||
access_token = account.get("accessToken")
|
||||
if not access_token:
|
||||
return None
|
||||
|
||||
if not _is_antigravity_token_expired(auth_data):
|
||||
return access_token
|
||||
|
||||
# Token is expired or near expiry — attempt a refresh
|
||||
refresh_token = account.get("refreshToken")
|
||||
if not refresh_token:
|
||||
logger.warning(
|
||||
"Antigravity token expired and no refresh token available. "
|
||||
"Re-open the Antigravity IDE to refresh, or run 'antigravity-auth accounts add'."
|
||||
)
|
||||
return access_token # return stale token; proxy may still accept it briefly
|
||||
|
||||
logger.info("Antigravity token expired or near expiry, refreshing...")
|
||||
token_data = _refresh_antigravity_token(refresh_token)
|
||||
|
||||
if token_data and "access_token" in token_data:
|
||||
_save_refreshed_antigravity_credentials(auth_data, token_data)
|
||||
return token_data["access_token"]
|
||||
|
||||
logger.warning(
|
||||
"Antigravity token refresh failed. "
|
||||
"Re-open the Antigravity IDE or run 'antigravity-auth accounts add'."
|
||||
)
|
||||
return access_token
|
||||
|
||||
|
||||
def _is_antigravity_proxy_available() -> bool:
|
||||
"""Return True if antigravity-auth serve is running on localhost:8069."""
|
||||
import socket
|
||||
|
||||
try:
|
||||
with socket.create_connection(("localhost", 8069), timeout=0.5):
|
||||
return True
|
||||
except (OSError, TimeoutError):
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentInfo:
|
||||
"""Information about an exported agent."""
|
||||
@@ -808,6 +1121,9 @@ class AgentRunner:
|
||||
if mcp_config_path.exists():
|
||||
self._load_mcp_servers_from_config(mcp_config_path)
|
||||
|
||||
# Auto-discover registry-selected MCP servers from mcp_registry.json
|
||||
self._load_registry_mcp_servers(agent_path)
|
||||
|
||||
@staticmethod
|
||||
def _import_agent_module(agent_path: Path):
|
||||
"""Import an agent package from its directory path.
|
||||
@@ -1111,6 +1427,56 @@ class AgentRunner:
|
||||
"""Load and register MCP servers from a configuration file."""
|
||||
self._tool_registry.load_mcp_config(config_path)
|
||||
|
||||
def _load_registry_mcp_servers(self, agent_path: Path) -> None:
|
||||
"""Load and register MCP servers selected via ``mcp_registry.json``."""
|
||||
registry_json = agent_path / "mcp_registry.json"
|
||||
if registry_json.is_file():
|
||||
self._tool_registry.set_mcp_registry_agent_path(agent_path)
|
||||
else:
|
||||
self._tool_registry.set_mcp_registry_agent_path(None)
|
||||
|
||||
from framework.runner.mcp_registry import MCPRegistry
|
||||
|
||||
try:
|
||||
registry = MCPRegistry()
|
||||
registry.initialize()
|
||||
server_configs, selection_max_tools = registry.load_agent_selection(agent_path)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to load MCP registry servers for '%s': %s",
|
||||
agent_path.name,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
if not server_configs:
|
||||
return
|
||||
|
||||
results = self._tool_registry.load_registry_servers(
|
||||
server_configs,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=True,
|
||||
max_tools=selection_max_tools,
|
||||
)
|
||||
loaded = [result for result in results if result["status"] == "loaded"]
|
||||
skipped = [result for result in results if result["status"] != "loaded"]
|
||||
|
||||
logger.info(
|
||||
"Loaded %d/%d MCP registry server(s) for agent '%s'",
|
||||
len(loaded),
|
||||
len(results),
|
||||
agent_path.name,
|
||||
)
|
||||
if skipped:
|
||||
logger.info(
|
||||
"Skipped MCP registry servers for agent '%s': %s",
|
||||
agent_path.name,
|
||||
[
|
||||
{"server": result["server"], "reason": result["skipped_reason"]}
|
||||
for result in skipped
|
||||
],
|
||||
)
|
||||
|
||||
def set_approval_callback(self, callback: Callable) -> None:
|
||||
"""
|
||||
Set a callback for human-in-the-loop approval during execution.
|
||||
@@ -1141,7 +1507,10 @@ class AgentRunner:
|
||||
|
||||
# Create LLM provider
|
||||
# Uses LiteLLM which auto-detects the provider from model name
|
||||
if self.mock_mode:
|
||||
# Skip if already injected (e.g. worker agents with a pre-built LLM)
|
||||
if self._llm is not None:
|
||||
pass # LLM already configured externally
|
||||
elif self.mock_mode:
|
||||
# Use mock LLM for testing without real API calls
|
||||
from framework.llm.mock import MockLLMProvider
|
||||
|
||||
@@ -1155,6 +1524,7 @@ class AgentRunner:
|
||||
use_claude_code = llm_config.get("use_claude_code_subscription", False)
|
||||
use_codex = llm_config.get("use_codex_subscription", False)
|
||||
use_kimi_code = llm_config.get("use_kimi_code_subscription", False)
|
||||
use_antigravity = llm_config.get("use_antigravity_subscription", False)
|
||||
api_base = llm_config.get("api_base")
|
||||
|
||||
api_key = None
|
||||
@@ -1162,20 +1532,28 @@ class AgentRunner:
|
||||
# Get OAuth token from Claude Code subscription
|
||||
api_key = get_claude_code_token()
|
||||
if not api_key:
|
||||
print("Warning: Claude Code subscription configured but no token found.")
|
||||
print("Run 'claude' to authenticate, then try again.")
|
||||
logger.warning(
|
||||
"Claude Code subscription configured but no token found. "
|
||||
"Run 'claude' to authenticate, then try again."
|
||||
)
|
||||
elif use_codex:
|
||||
# Get OAuth token from Codex subscription
|
||||
api_key = get_codex_token()
|
||||
if not api_key:
|
||||
print("Warning: Codex subscription configured but no token found.")
|
||||
print("Run 'codex' to authenticate, then try again.")
|
||||
logger.warning(
|
||||
"Codex subscription configured but no token found. "
|
||||
"Run 'codex' to authenticate, then try again."
|
||||
)
|
||||
elif use_kimi_code:
|
||||
# Get API key from Kimi Code CLI config (~/.kimi/config.toml)
|
||||
api_key = get_kimi_code_token()
|
||||
if not api_key:
|
||||
print("Warning: Kimi Code subscription configured but no key found.")
|
||||
print("Run 'kimi /login' to authenticate, then try again.")
|
||||
logger.warning(
|
||||
"Kimi Code subscription configured but no key found. "
|
||||
"Run 'kimi /login' to authenticate, then try again."
|
||||
)
|
||||
elif use_antigravity:
|
||||
pass # AntigravityProvider handles credentials internally
|
||||
|
||||
if api_key and use_claude_code:
|
||||
# Use litellm's built-in Anthropic OAuth support.
|
||||
@@ -1214,6 +1592,19 @@ class AgentRunner:
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
)
|
||||
elif use_antigravity:
|
||||
# Direct OAuth to Google's internal Cloud Code Assist gateway.
|
||||
# No local proxy required — AntigravityProvider handles token
|
||||
# refresh and Gemini-format request/response conversion natively.
|
||||
from framework.llm.antigravity import AntigravityProvider # noqa: PLC0415
|
||||
|
||||
provider = AntigravityProvider(model=self.model)
|
||||
if not provider.has_credentials():
|
||||
print(
|
||||
"Warning: Antigravity credentials not found. "
|
||||
"Run: uv run python core/antigravity_auth.py auth account add"
|
||||
)
|
||||
self._llm = provider
|
||||
else:
|
||||
# Local models (e.g. Ollama) don't need an API key
|
||||
if self._is_local_model(self.model):
|
||||
@@ -1245,8 +1636,12 @@ class AgentRunner:
|
||||
if api_key_env:
|
||||
os.environ[api_key_env] = api_key
|
||||
elif api_key_env:
|
||||
print(f"Warning: {api_key_env} not set. LLM calls will fail.")
|
||||
print(f"Set it with: export {api_key_env}=your-api-key")
|
||||
logger.warning(
|
||||
"%s not set. LLM calls will fail. "
|
||||
"Set it with: export %s=your-api-key",
|
||||
api_key_env,
|
||||
api_key_env,
|
||||
)
|
||||
|
||||
# Fail fast if the agent needs an LLM but none was configured
|
||||
if self._llm is None:
|
||||
@@ -1340,7 +1735,7 @@ class AgentRunner:
|
||||
except Exception:
|
||||
pass # Best-effort — agent works without account info
|
||||
|
||||
# Skill configuration — the runtime handles discovery, loading, and
|
||||
# Skill configuration — the runtime handles discovery, loading, trust-gating and
|
||||
# prompt rasterization. The runner just builds the config.
|
||||
from framework.skills.config import SkillsConfig
|
||||
from framework.skills.manager import SkillsManagerConfig
|
||||
@@ -1351,6 +1746,7 @@ class AgentRunner:
|
||||
skills=getattr(self, "_agent_skills", None),
|
||||
),
|
||||
project_root=self.agent_path,
|
||||
interactive=self._interactive,
|
||||
)
|
||||
|
||||
self._setup_agent_runtime(
|
||||
@@ -1381,6 +1777,8 @@ class AgentRunner:
|
||||
return "MISTRAL_API_KEY"
|
||||
elif model_lower.startswith("groq/"):
|
||||
return "GROQ_API_KEY"
|
||||
elif model_lower.startswith("openrouter/"):
|
||||
return "OPENROUTER_API_KEY"
|
||||
elif self._is_local_model(model_lower):
|
||||
return None # Local models don't need an API key
|
||||
elif model_lower.startswith("azure/"):
|
||||
@@ -1460,6 +1858,9 @@ class AgentRunner:
|
||||
accounts_data: list[dict] | None = None,
|
||||
tool_provider_map: dict[str, str] | None = None,
|
||||
event_bus=None,
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
skills_manager_config=None,
|
||||
) -> None:
|
||||
"""Set up multi-entry-point execution using AgentRuntime."""
|
||||
|
||||
@@ -16,6 +16,8 @@ from framework.llm.provider import Tool, ToolResult, ToolUse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_INPUT_LOG_MAX_LEN = 500
|
||||
|
||||
# Per-execution context overrides. Each asyncio task (and thus each
|
||||
# concurrent graph execution) gets its own copy, so there are no races
|
||||
# when multiple ExecutionStreams run in parallel.
|
||||
@@ -54,6 +56,8 @@ class ToolRegistry:
|
||||
def __init__(self):
|
||||
self._tools: dict[str, RegisteredTool] = {}
|
||||
self._mcp_clients: list[Any] = [] # List of MCPClient instances
|
||||
self._mcp_client_servers: dict[int, str] = {} # client id -> server name
|
||||
self._mcp_managed_clients: set[int] = set() # client ids acquired from the manager
|
||||
self._session_context: dict[str, Any] = {} # Auto-injected context for tools
|
||||
self._provider_index: dict[str, set[str]] = {} # provider -> tool names
|
||||
# MCP resync tracking
|
||||
@@ -62,6 +66,8 @@ class ToolRegistry:
|
||||
self._mcp_cred_snapshot: set[str] = set() # Credential filenames at MCP load time
|
||||
self._mcp_aden_key_snapshot: str | None = None # ADEN_API_KEY value at MCP load time
|
||||
self._mcp_server_tools: dict[str, set[str]] = {} # server name -> tool names
|
||||
# Agent dir for re-loading registry MCP after credential resync.
|
||||
self._mcp_registry_agent_path: Path | None = None
|
||||
|
||||
def register(
|
||||
self,
|
||||
@@ -243,6 +249,13 @@ class ToolRegistry:
|
||||
def _wrap_result(tool_use_id: str, result: Any) -> ToolResult:
|
||||
if isinstance(result, ToolResult):
|
||||
return result
|
||||
# MCP client returns dict with _images when image content is present
|
||||
if isinstance(result, dict) and "_images" in result:
|
||||
return ToolResult(
|
||||
tool_use_id=tool_use_id,
|
||||
content=result.get("_text", ""),
|
||||
image_content=result["_images"],
|
||||
)
|
||||
return ToolResult(
|
||||
tool_use_id=tool_use_id,
|
||||
content=json.dumps(result) if not isinstance(result, str) else result,
|
||||
@@ -269,6 +282,17 @@ class ToolRegistry:
|
||||
r = await result
|
||||
return _wrap_result(tool_use.id, r)
|
||||
except Exception as exc:
|
||||
inputs_str = json.dumps(tool_use.input, default=str)
|
||||
if len(inputs_str) > _INPUT_LOG_MAX_LEN:
|
||||
inputs_str = inputs_str[:_INPUT_LOG_MAX_LEN] + "...(truncated)"
|
||||
logger.error(
|
||||
"Async tool '%s' failed (tool_use_id=%s): %s\nInputs: %s",
|
||||
tool_use.name,
|
||||
tool_use.id,
|
||||
exc,
|
||||
inputs_str,
|
||||
exc_info=True,
|
||||
)
|
||||
return ToolResult(
|
||||
tool_use_id=tool_use.id,
|
||||
content=json.dumps({"error": str(exc)}),
|
||||
@@ -279,6 +303,17 @@ class ToolRegistry:
|
||||
|
||||
return _wrap_result(tool_use.id, result)
|
||||
except Exception as e:
|
||||
inputs_str = json.dumps(tool_use.input, default=str)
|
||||
if len(inputs_str) > _INPUT_LOG_MAX_LEN:
|
||||
inputs_str = inputs_str[:_INPUT_LOG_MAX_LEN] + "...(truncated)"
|
||||
logger.error(
|
||||
"Tool '%s' execution failed for tool_use_id=%s: %s\nInputs: %s",
|
||||
tool_use.name,
|
||||
tool_use.id,
|
||||
e,
|
||||
inputs_str,
|
||||
exc_info=True,
|
||||
)
|
||||
return ToolResult(
|
||||
tool_use_id=tool_use.id,
|
||||
content=json.dumps({"error": str(e)}),
|
||||
@@ -453,33 +488,129 @@ class ToolRegistry:
|
||||
# Treat top-level keys as server names
|
||||
server_list = [{"name": name, **cfg} for name, cfg in config.items()]
|
||||
|
||||
for server_config in server_list:
|
||||
server_config = self._resolve_mcp_server_config(server_config, base_dir)
|
||||
for _attempt in range(2):
|
||||
try:
|
||||
self.register_mcp_server(server_config)
|
||||
break
|
||||
except Exception as e:
|
||||
name = server_config.get("name", "unknown")
|
||||
if _attempt == 0:
|
||||
logger.warning(
|
||||
"MCP server '%s' failed to register, retrying in 2s: %s",
|
||||
name,
|
||||
e,
|
||||
)
|
||||
import time
|
||||
|
||||
time.sleep(2)
|
||||
else:
|
||||
logger.warning("MCP server '%s' failed after retry: %s", name, e)
|
||||
resolved_server_list = [
|
||||
self._resolve_mcp_server_config(server_config, base_dir)
|
||||
for server_config in server_list
|
||||
]
|
||||
# Ordered first-wins for duplicate tool names across servers; keep tools.py tools.
|
||||
self.load_registry_servers(
|
||||
resolved_server_list,
|
||||
log_summary=False,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=False,
|
||||
)
|
||||
|
||||
# Snapshot credential files and ADEN_API_KEY so we can detect mid-session changes
|
||||
self._mcp_cred_snapshot = self._snapshot_credentials()
|
||||
self._mcp_aden_key_snapshot = os.environ.get("ADEN_API_KEY")
|
||||
|
||||
def _register_mcp_server_with_retry(
|
||||
self,
|
||||
server_config: dict[str, Any],
|
||||
*,
|
||||
preserve_existing_tools: bool = True,
|
||||
tool_cap: int | None = None,
|
||||
log_collisions: bool = False,
|
||||
) -> tuple[bool, int, str | None]:
|
||||
"""Register a single MCP server with one retry for transient failures."""
|
||||
name = server_config.get("name", "unknown")
|
||||
last_error: str | None = None
|
||||
|
||||
for attempt in range(2):
|
||||
try:
|
||||
count = self.register_mcp_server(
|
||||
server_config,
|
||||
preserve_existing_tools=preserve_existing_tools,
|
||||
tool_cap=tool_cap,
|
||||
log_collisions=log_collisions,
|
||||
)
|
||||
if count > 0:
|
||||
return True, count, None
|
||||
last_error = "registered 0 tools"
|
||||
except Exception as exc:
|
||||
last_error = str(exc)
|
||||
|
||||
if attempt == 0:
|
||||
logger.warning(
|
||||
"MCP server '%s' failed to register, retrying in 2s: %s",
|
||||
name,
|
||||
last_error,
|
||||
)
|
||||
import time
|
||||
|
||||
time.sleep(2)
|
||||
else:
|
||||
logger.warning("MCP server '%s' failed after retry: %s", name, last_error)
|
||||
|
||||
return False, 0, last_error
|
||||
|
||||
def load_registry_servers(
|
||||
self,
|
||||
server_list: list[dict[str, Any]],
|
||||
*,
|
||||
log_summary: bool = True,
|
||||
preserve_existing_tools: bool = True,
|
||||
max_tools: int | None = None,
|
||||
log_collisions: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Register MCP servers from a resolved config list (registry and/or static).
|
||||
|
||||
``preserve_existing_tools`` enforces first-wins tool names (FR-100): later
|
||||
servers skip names already taken— including tools from ``mcp_servers.json``
|
||||
or ``tools.py`` when those were loaded first.
|
||||
|
||||
``max_tools`` caps how many *new* tool names are registered across this batch
|
||||
(collisions do not consume the cap). When ``log_collisions`` is True, skipped
|
||||
duplicate names emit a warning (FR-101).
|
||||
"""
|
||||
results: list[dict[str, Any]] = []
|
||||
tools_added_batch = 0
|
||||
|
||||
for server_config in server_list:
|
||||
remaining: int | None = None
|
||||
if max_tools is not None:
|
||||
remaining = max_tools - tools_added_batch
|
||||
if remaining <= 0:
|
||||
break
|
||||
|
||||
name = server_config.get("name", "unknown")
|
||||
success, tools_loaded, error = self._register_mcp_server_with_retry(
|
||||
server_config,
|
||||
preserve_existing_tools=preserve_existing_tools,
|
||||
tool_cap=remaining,
|
||||
log_collisions=log_collisions,
|
||||
)
|
||||
tools_added_batch += tools_loaded
|
||||
result = {
|
||||
"server": name,
|
||||
"status": "loaded" if success else "skipped",
|
||||
"tools_loaded": tools_loaded,
|
||||
"skipped_reason": None if success else (error or "unknown error"),
|
||||
}
|
||||
results.append(result)
|
||||
|
||||
if log_summary:
|
||||
logger.info(
|
||||
"MCP registry server resolution",
|
||||
extra={
|
||||
"event": "mcp_registry_server_resolution",
|
||||
"server": result["server"],
|
||||
"status": result["status"],
|
||||
"tools_loaded": result["tools_loaded"],
|
||||
"skipped_reason": result["skipped_reason"],
|
||||
},
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def register_mcp_server(
|
||||
self,
|
||||
server_config: dict[str, Any],
|
||||
use_connection_manager: bool = True,
|
||||
*,
|
||||
preserve_existing_tools: bool = True,
|
||||
tool_cap: int | None = None,
|
||||
log_collisions: bool = False,
|
||||
) -> int:
|
||||
"""
|
||||
Register an MCP server and discover its tools.
|
||||
@@ -495,12 +626,17 @@ class ToolRegistry:
|
||||
- url: Server URL (for http)
|
||||
- headers: HTTP headers (for http)
|
||||
- description: Server description (optional)
|
||||
use_connection_manager: When True, reuse a shared client keyed by server name
|
||||
preserve_existing_tools: If True, do not replace tools already in the registry.
|
||||
tool_cap: Max tools to newly register from this server (None = unlimited).
|
||||
log_collisions: If True, log when this server skips a tool name already taken.
|
||||
|
||||
Returns:
|
||||
Number of tools registered from this server
|
||||
"""
|
||||
try:
|
||||
from framework.runner.mcp_client import MCPClient, MCPServerConfig
|
||||
from framework.runner.mcp_connection_manager import MCPConnectionManager
|
||||
|
||||
# Build config object
|
||||
config = MCPServerConfig(
|
||||
@@ -512,15 +648,23 @@ class ToolRegistry:
|
||||
cwd=server_config.get("cwd"),
|
||||
url=server_config.get("url"),
|
||||
headers=server_config.get("headers", {}),
|
||||
socket_path=server_config.get("socket_path"),
|
||||
description=server_config.get("description", ""),
|
||||
)
|
||||
|
||||
# Create and connect client
|
||||
client = MCPClient(config)
|
||||
client.connect()
|
||||
if use_connection_manager:
|
||||
client = MCPConnectionManager.get_instance().acquire(config)
|
||||
else:
|
||||
client = MCPClient(config)
|
||||
client.connect()
|
||||
|
||||
# Store client for cleanup
|
||||
self._mcp_clients.append(client)
|
||||
client_id = id(client)
|
||||
self._mcp_client_servers[client_id] = config.name
|
||||
if use_connection_manager:
|
||||
self._mcp_managed_clients.add(client_id)
|
||||
|
||||
# Register each tool
|
||||
server_name = server_config["name"]
|
||||
@@ -528,6 +672,23 @@ class ToolRegistry:
|
||||
self._mcp_server_tools[server_name] = set()
|
||||
count = 0
|
||||
for mcp_tool in client.list_tools():
|
||||
if tool_cap is not None and count >= tool_cap:
|
||||
break
|
||||
|
||||
if preserve_existing_tools and mcp_tool.name in self._tools:
|
||||
if log_collisions:
|
||||
origin_server = (
|
||||
self._find_mcp_origin_server_for_tool(mcp_tool.name) or "<existing>"
|
||||
)
|
||||
logger.warning(
|
||||
"MCP tool '%s' from '%s' shadowed by '%s' (loaded first)",
|
||||
mcp_tool.name,
|
||||
server_name,
|
||||
origin_server,
|
||||
)
|
||||
# Skip registration; do not update MCP tool bookkeeping for this server.
|
||||
continue
|
||||
|
||||
# Convert MCP tool to framework Tool (strips context params from LLM schema)
|
||||
tool = self._convert_mcp_tool_to_framework_tool(mcp_tool)
|
||||
|
||||
@@ -560,14 +721,25 @@ class ToolRegistry:
|
||||
}
|
||||
merged_inputs = {**clean_inputs, **filtered_context}
|
||||
result = client_ref.call_tool(tool_name, merged_inputs)
|
||||
# MCP tools return content array, extract the result
|
||||
# MCP client already extracts content (returns str
|
||||
# or {"_text": ..., "_images": ...} for image results).
|
||||
# Handle legacy list format from HTTP transport.
|
||||
if isinstance(result, list) and len(result) > 0:
|
||||
if isinstance(result[0], dict) and "text" in result[0]:
|
||||
return result[0]["text"]
|
||||
return result[0]
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool '{tool_name}' execution failed: {e}")
|
||||
inputs_str = json.dumps(inputs, default=str)
|
||||
if len(inputs_str) > _INPUT_LOG_MAX_LEN:
|
||||
inputs_str = inputs_str[:_INPUT_LOG_MAX_LEN] + "...(truncated)"
|
||||
logger.error(
|
||||
"MCP tool '%s' execution failed: %s\nInputs: %s",
|
||||
tool_name,
|
||||
e,
|
||||
inputs_str,
|
||||
exc_info=True,
|
||||
)
|
||||
return {"error": str(e)}
|
||||
|
||||
return executor
|
||||
@@ -582,11 +754,27 @@ class ToolRegistry:
|
||||
self._mcp_server_tools[server_name].add(mcp_tool.name)
|
||||
count += 1
|
||||
|
||||
logger.info(f"Registered {count} tools from MCP server '{config.name}'")
|
||||
logger.info(
|
||||
"MCP Registry Load",
|
||||
extra={
|
||||
"server": config.name,
|
||||
"status": "success",
|
||||
"tools_loaded": count,
|
||||
"skipped_reason": None,
|
||||
},
|
||||
)
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register MCP server: {e}")
|
||||
logger.error(
|
||||
"MCP Registry Load",
|
||||
extra={
|
||||
"server": server_config.get("name", "unknown"),
|
||||
"status": "failed",
|
||||
"tools_loaded": 0,
|
||||
"skipped_reason": str(e),
|
||||
},
|
||||
)
|
||||
if "Connection closed" in str(e) and os.name == "nt":
|
||||
logger.debug(
|
||||
"On Windows, check that the MCP subprocess starts (e.g. uv in PATH, "
|
||||
@@ -594,6 +782,12 @@ class ToolRegistry:
|
||||
)
|
||||
return 0
|
||||
|
||||
def _find_mcp_origin_server_for_tool(self, tool_name: str) -> str | None:
|
||||
for server_name, tool_names in self._mcp_server_tools.items():
|
||||
if tool_name in tool_names:
|
||||
return server_name
|
||||
return None
|
||||
|
||||
def _convert_mcp_tool_to_framework_tool(self, mcp_tool: Any) -> Tool:
|
||||
"""
|
||||
Convert an MCP tool to a framework Tool.
|
||||
@@ -681,6 +875,37 @@ class ToolRegistry:
|
||||
# MCP credential resync
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def set_mcp_registry_agent_path(self, agent_path: Path | None) -> None:
|
||||
"""Remember agent dir so registry MCP servers reload after credential resync."""
|
||||
self._mcp_registry_agent_path = None if agent_path is None else Path(agent_path)
|
||||
|
||||
def reload_registry_mcp_servers_after_resync(self) -> None:
|
||||
"""Re-run ``mcp_registry.json`` resolution and register servers (post-resync)."""
|
||||
if self._mcp_registry_agent_path is None:
|
||||
return
|
||||
from framework.runner.mcp_registry import MCPRegistry
|
||||
|
||||
try:
|
||||
reg = MCPRegistry()
|
||||
reg.initialize()
|
||||
configs, selection_max_tools = reg.load_agent_selection(self._mcp_registry_agent_path)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to reload MCP registry servers after resync for '%s': %s",
|
||||
self._mcp_registry_agent_path.name,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
if not configs:
|
||||
return
|
||||
self.load_registry_servers(
|
||||
configs,
|
||||
log_summary=True,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=True,
|
||||
max_tools=selection_max_tools,
|
||||
)
|
||||
|
||||
def _snapshot_credentials(self) -> set[str]:
|
||||
"""Return the set of credential filenames currently on disk."""
|
||||
try:
|
||||
@@ -720,32 +945,46 @@ class ToolRegistry:
|
||||
logger.info("%s — resyncing MCP servers", reason)
|
||||
|
||||
# 1. Disconnect existing MCP clients
|
||||
for client in self._mcp_clients:
|
||||
try:
|
||||
client.disconnect()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error disconnecting MCP client during resync: {e}")
|
||||
self._mcp_clients.clear()
|
||||
self._cleanup_mcp_clients("during resync")
|
||||
|
||||
# 2. Remove MCP-registered tools
|
||||
for name in self._mcp_tool_names:
|
||||
self._tools.pop(name, None)
|
||||
self._mcp_tool_names.clear()
|
||||
self._mcp_server_tools.clear()
|
||||
|
||||
# 3. Re-load MCP servers (spawns fresh subprocesses with new credentials)
|
||||
self.load_mcp_config(self._mcp_config_path)
|
||||
if self._mcp_registry_agent_path is not None:
|
||||
self.reload_registry_mcp_servers_after_resync()
|
||||
|
||||
logger.info("MCP server resync complete")
|
||||
return True
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Clean up all MCP client connections."""
|
||||
self._cleanup_mcp_clients()
|
||||
|
||||
def _cleanup_mcp_clients(self, context: str = "") -> None:
|
||||
"""Disconnect or release all tracked MCP clients for this registry."""
|
||||
if context:
|
||||
context = f" {context}"
|
||||
|
||||
for client in self._mcp_clients:
|
||||
client_id = id(client)
|
||||
server_name = self._mcp_client_servers.get(client_id, client.config.name)
|
||||
try:
|
||||
client.disconnect()
|
||||
if client_id in self._mcp_managed_clients:
|
||||
from framework.runner.mcp_connection_manager import MCPConnectionManager
|
||||
|
||||
MCPConnectionManager.get_instance().release(server_name)
|
||||
else:
|
||||
client.disconnect()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error disconnecting MCP client: {e}")
|
||||
logger.warning(f"Error disconnecting MCP client{context}: {e}")
|
||||
self._mcp_clients.clear()
|
||||
self._mcp_client_servers.clear()
|
||||
self._mcp_managed_clients.clear()
|
||||
|
||||
def __del__(self):
|
||||
"""Destructor to ensure cleanup."""
|
||||
|
||||
@@ -137,6 +137,7 @@ class AgentRuntime:
|
||||
# Deprecated — pass skills_manager_config instead.
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize agent runtime.
|
||||
@@ -158,6 +159,9 @@ class AgentRuntime:
|
||||
event_bus: Optional external EventBus. If provided, the runtime shares
|
||||
this bus instead of creating its own. Used by SessionManager to
|
||||
share a single bus between queen, worker, and judge.
|
||||
skills_catalog_prompt: Available skills catalog for system prompt
|
||||
protocols_prompt: Default skill operational protocols for system prompt
|
||||
skill_dirs: Skill base directories for Tier 3 resource access
|
||||
skills_manager_config: Skill configuration — the runtime owns
|
||||
discovery, loading, and prompt renderation internally.
|
||||
skills_catalog_prompt: Deprecated. Pre-rendered skills catalog.
|
||||
@@ -195,6 +199,10 @@ class AgentRuntime:
|
||||
self._skills_manager = SkillsManager()
|
||||
self._skills_manager.load()
|
||||
|
||||
self.skill_dirs: list[str] = self._skills_manager.allowlisted_dirs
|
||||
self.context_warn_ratio: float | None = self._skills_manager.context_warn_ratio
|
||||
self.batch_init_nudge: str | None = self._skills_manager.batch_init_nudge
|
||||
|
||||
# Primary graph identity
|
||||
self._graph_id: str = graph_id or "primary"
|
||||
|
||||
@@ -341,6 +349,9 @@ class AgentRuntime:
|
||||
tool_provider_map=self._tool_provider_map,
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
skill_dirs=self.skill_dirs,
|
||||
context_warn_ratio=self.context_warn_ratio,
|
||||
batch_init_nudge=self.batch_init_nudge,
|
||||
)
|
||||
await stream.start()
|
||||
self._streams[ep_id] = stream
|
||||
@@ -977,6 +988,7 @@ class AgentRuntime:
|
||||
tool_provider_map=self._tool_provider_map,
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
skill_dirs=self.skill_dirs,
|
||||
)
|
||||
if self._running:
|
||||
await stream.start()
|
||||
@@ -1466,6 +1478,7 @@ class AgentRuntime:
|
||||
graph_id: str | None = None,
|
||||
*,
|
||||
is_client_input: bool = False,
|
||||
image_content: list[dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""Inject user input into a running client-facing node.
|
||||
|
||||
@@ -1478,6 +1491,8 @@ class AgentRuntime:
|
||||
graph_id: Optional graph to search first (defaults to active graph)
|
||||
is_client_input: True when the message originates from a real
|
||||
human user (e.g. /chat endpoint), False for external events.
|
||||
image_content: Optional list of image content blocks (OpenAI
|
||||
image_url format) to include alongside the text.
|
||||
|
||||
Returns:
|
||||
True if input was delivered, False if no matching node found
|
||||
@@ -1489,7 +1504,9 @@ class AgentRuntime:
|
||||
target = graph_id or self._active_graph_id
|
||||
if target in self._graphs:
|
||||
for stream in self._graphs[target].streams.values():
|
||||
if await stream.inject_input(node_id, content, is_client_input=is_client_input):
|
||||
if await stream.inject_input(
|
||||
node_id, content, is_client_input=is_client_input, image_content=image_content
|
||||
):
|
||||
return True
|
||||
|
||||
# Then search all other graphs
|
||||
@@ -1497,7 +1514,9 @@ class AgentRuntime:
|
||||
if gid == target:
|
||||
continue
|
||||
for stream in reg.streams.values():
|
||||
if await stream.inject_input(node_id, content, is_client_input=is_client_input):
|
||||
if await stream.inject_input(
|
||||
node_id, content, is_client_input=is_client_input, image_content=image_content
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -1760,6 +1779,7 @@ def create_agent_runtime(
|
||||
# Deprecated — pass skills_manager_config instead.
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
) -> AgentRuntime:
|
||||
"""
|
||||
Create and configure an AgentRuntime with entry points.
|
||||
@@ -1786,6 +1806,9 @@ def create_agent_runtime(
|
||||
accounts_data: Raw account data for per-node prompt generation.
|
||||
tool_provider_map: Tool name to provider name mapping for account routing.
|
||||
event_bus: Optional external EventBus to share with other components.
|
||||
skills_catalog_prompt: Available skills catalog for system prompt.
|
||||
protocols_prompt: Default skill operational protocols for system prompt.
|
||||
skill_dirs: Skill base directories for Tier 3 resource access.
|
||||
skills_manager_config: Skill configuration — the runtime owns
|
||||
discovery, loading, and prompt renderation internally.
|
||||
skills_catalog_prompt: Deprecated. Pre-rendered skills catalog.
|
||||
@@ -1819,6 +1842,7 @@ def create_agent_runtime(
|
||||
skills_manager_config=skills_manager_config,
|
||||
skills_catalog_prompt=skills_catalog_prompt,
|
||||
protocols_prompt=protocols_prompt,
|
||||
skill_dirs=skill_dirs,
|
||||
)
|
||||
|
||||
for spec in entry_points:
|
||||
|
||||
@@ -16,7 +16,7 @@ from typing import Any
|
||||
from framework.observability import set_trace_context
|
||||
from framework.schemas.decision import Decision, DecisionType, Option, Outcome
|
||||
from framework.schemas.run import Run, RunStatus
|
||||
from framework.storage.backend import FileStorage
|
||||
from framework.storage.concurrent import ConcurrentStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -62,7 +62,7 @@ class Runtime:
|
||||
logger.warning(f"Storage path does not exist, creating: {path}")
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.storage = FileStorage(storage_path)
|
||||
self.storage = ConcurrentStorage(storage_path)
|
||||
self._current_run: Run | None = None
|
||||
self._current_node: str = "unknown"
|
||||
|
||||
@@ -132,8 +132,8 @@ class Runtime:
|
||||
self._current_run.output_data = output_data or {}
|
||||
self._current_run.complete(status, narrative)
|
||||
|
||||
# Save to storage
|
||||
self.storage.save_run(self._current_run)
|
||||
# Save to storage (sync — Runtime methods are not async)
|
||||
self.storage.save_run_sync(self._current_run)
|
||||
self._current_run = None
|
||||
|
||||
def set_node(self, node_id: str) -> None:
|
||||
|
||||
@@ -117,6 +117,7 @@ class EventType(StrEnum):
|
||||
|
||||
# Context management
|
||||
CONTEXT_COMPACTED = "context_compacted"
|
||||
CONTEXT_USAGE_UPDATED = "context_usage_updated"
|
||||
|
||||
# External triggers
|
||||
WEBHOOK_RECEIVED = "webhook_received"
|
||||
@@ -534,8 +535,8 @@ class EventBus:
|
||||
async with self._semaphore:
|
||||
try:
|
||||
await handler(event)
|
||||
except Exception as e:
|
||||
logger.error(f"Handler error for {event.type}: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"Handler error for {event.type}")
|
||||
|
||||
# Run all handlers concurrently
|
||||
await asyncio.gather(*[run_handler(h) for h in handlers], return_exceptions=True)
|
||||
|
||||
@@ -188,6 +188,9 @@ class ExecutionStream:
|
||||
tool_provider_map: dict[str, str] | None = None,
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
context_warn_ratio: float | None = None,
|
||||
batch_init_nudge: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize execution stream.
|
||||
@@ -213,6 +216,9 @@ class ExecutionStream:
|
||||
tool_provider_map: Tool name to provider name mapping for account routing
|
||||
skills_catalog_prompt: Available skills catalog for system prompt
|
||||
protocols_prompt: Default skill operational protocols for system prompt
|
||||
skill_dirs: Skill base directories for Tier 3 resource access
|
||||
context_warn_ratio: Token usage ratio to trigger DS-13 preservation warning
|
||||
batch_init_nudge: System prompt nudge for DS-12 batch auto-detection
|
||||
"""
|
||||
self.stream_id = stream_id
|
||||
self.entry_spec = entry_spec
|
||||
@@ -236,6 +242,9 @@ class ExecutionStream:
|
||||
self._tool_provider_map = tool_provider_map
|
||||
self._skills_catalog_prompt = skills_catalog_prompt
|
||||
self._protocols_prompt = protocols_prompt
|
||||
self._skill_dirs: list[str] = skill_dirs or []
|
||||
self._context_warn_ratio: float | None = context_warn_ratio
|
||||
self._batch_init_nudge: str | None = batch_init_nudge
|
||||
|
||||
_es_logger = logging.getLogger(__name__)
|
||||
if protocols_prompt:
|
||||
@@ -430,6 +439,7 @@ class ExecutionStream:
|
||||
content: str,
|
||||
*,
|
||||
is_client_input: bool = False,
|
||||
image_content: list[dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""Inject user input into a running client-facing EventLoopNode.
|
||||
|
||||
@@ -441,7 +451,9 @@ class ExecutionStream:
|
||||
for executor in self._active_executors.values():
|
||||
node = executor.node_registry.get(node_id)
|
||||
if node is not None and hasattr(node, "inject_event"):
|
||||
await node.inject_event(content, is_client_input=is_client_input)
|
||||
await node.inject_event(
|
||||
content, is_client_input=is_client_input, image_content=image_content
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -696,6 +708,9 @@ class ExecutionStream:
|
||||
tool_provider_map=self._tool_provider_map,
|
||||
skills_catalog_prompt=self._skills_catalog_prompt,
|
||||
protocols_prompt=self._protocols_prompt,
|
||||
skill_dirs=self._skill_dirs,
|
||||
context_warn_ratio=self._context_warn_ratio,
|
||||
batch_init_nudge=self._batch_init_nudge,
|
||||
)
|
||||
# Track executor so inject_input() can reach EventLoopNode instances
|
||||
self._active_executors[execution_id] = executor
|
||||
@@ -954,7 +969,10 @@ class ExecutionStream:
|
||||
return
|
||||
import json as _json
|
||||
|
||||
session_dir = self._session_store.get_session_path(execution_id)
|
||||
try:
|
||||
session_dir = self._session_store.get_session_path(execution_id)
|
||||
except ValueError:
|
||||
return
|
||||
runs_file = session_dir / "runs.jsonl"
|
||||
now = datetime.now()
|
||||
record = {
|
||||
|
||||
@@ -8,6 +8,7 @@ write. Errors are silently swallowed — this must never break the agent.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import IO, Any
|
||||
@@ -47,6 +48,9 @@ def log_llm_turn(
|
||||
Never raises.
|
||||
"""
|
||||
try:
|
||||
# Skip logging during test runs to avoid polluting real logs.
|
||||
if os.environ.get("PYTEST_CURRENT_TEST") or os.environ.get("HIVE_DISABLE_LLM_LOGS"):
|
||||
return
|
||||
global _log_file, _log_ready # noqa: PLW0603
|
||||
if not _log_ready:
|
||||
_log_file = _open_log()
|
||||
|
||||
@@ -62,6 +62,7 @@ async def create_queen(
|
||||
from framework.agents.queen.nodes.thinking_hook import select_expert_persona
|
||||
from framework.graph.event_loop_node import HookContext, HookResult
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.runner.mcp_registry import MCPRegistry
|
||||
from framework.runner.tool_registry import ToolRegistry
|
||||
from framework.runtime.core import Runtime
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
@@ -86,6 +87,23 @@ async def create_queen(
|
||||
except Exception:
|
||||
logger.warning("Queen: MCP config failed to load", exc_info=True)
|
||||
|
||||
try:
|
||||
registry = MCPRegistry()
|
||||
registry.initialize()
|
||||
if (queen_pkg_dir / "mcp_registry.json").is_file():
|
||||
queen_registry.set_mcp_registry_agent_path(queen_pkg_dir)
|
||||
registry_configs, selection_max_tools = registry.load_agent_selection(queen_pkg_dir)
|
||||
if registry_configs:
|
||||
results = queen_registry.load_registry_servers(
|
||||
registry_configs,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=True,
|
||||
max_tools=selection_max_tools,
|
||||
)
|
||||
logger.info("Queen: loaded MCP registry servers: %s", results)
|
||||
except Exception:
|
||||
logger.warning("Queen: MCP registry config failed to load", exc_info=True)
|
||||
|
||||
# ---- Phase state --------------------------------------------------
|
||||
initial_phase = "staging" if worker_identity else "planning"
|
||||
phase_state = QueenPhaseState(phase=initial_phase, event_bus=session.event_bus)
|
||||
@@ -221,12 +239,18 @@ async def create_queen(
|
||||
)
|
||||
|
||||
# ---- Default skill protocols -------------------------------------
|
||||
_queen_skill_dirs: list[str] = []
|
||||
try:
|
||||
from framework.skills.manager import SkillsManager
|
||||
from framework.skills.manager import SkillsManager, SkillsManagerConfig
|
||||
|
||||
_queen_skills_mgr = SkillsManager()
|
||||
# Pass project_root so user-scope skills (~/.hive/skills/, ~/.agents/skills/)
|
||||
# are discovered. Queen has no agent-specific project root, so we use its
|
||||
# own directory — the value just needs to be non-None to enable user-scope scanning.
|
||||
_queen_skills_mgr = SkillsManager(SkillsManagerConfig(project_root=Path(__file__).parent))
|
||||
_queen_skills_mgr.load()
|
||||
phase_state.protocols_prompt = _queen_skills_mgr.protocols_prompt
|
||||
phase_state.skills_catalog_prompt = _queen_skills_mgr.skills_catalog_prompt
|
||||
_queen_skill_dirs = _queen_skills_mgr.allowlisted_dirs
|
||||
except Exception:
|
||||
logger.debug("Queen skill loading failed (non-fatal)", exc_info=True)
|
||||
|
||||
@@ -291,6 +315,7 @@ async def create_queen(
|
||||
dynamic_tools_provider=phase_state.get_current_tools,
|
||||
dynamic_prompt_provider=phase_state.get_current_prompt,
|
||||
iteration_metadata_provider=lambda: {"phase": phase_state.phase},
|
||||
skill_dirs=_queen_skill_dirs,
|
||||
)
|
||||
session.queen_executor = executor
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ DEFAULT_EVENT_TYPES = [
|
||||
EventType.NODE_RETRY,
|
||||
EventType.NODE_TOOL_DOOM_LOOP,
|
||||
EventType.CONTEXT_COMPACTED,
|
||||
EventType.CONTEXT_USAGE_UPDATED,
|
||||
EventType.WORKER_LOADED,
|
||||
EventType.CREDENTIALS_REQUIRED,
|
||||
EventType.SUBAGENT_REPORT,
|
||||
|
||||
@@ -108,7 +108,10 @@ async def handle_chat(request: web.Request) -> web.Response:
|
||||
The input box is permanently connected to the queen agent.
|
||||
Worker input is handled separately via /worker-input.
|
||||
|
||||
Body: {"message": "hello"}
|
||||
Body: {"message": "hello", "images": [{"type": "image_url", "image_url": {"url": "data:..."}}]}
|
||||
|
||||
The optional ``images`` field accepts a list of OpenAI-format image_url
|
||||
content blocks. The frontend encodes images as base64 data URIs.
|
||||
"""
|
||||
session, err = resolve_session(request)
|
||||
if err:
|
||||
@@ -116,15 +119,16 @@ async def handle_chat(request: web.Request) -> web.Response:
|
||||
|
||||
body = await request.json()
|
||||
message = body.get("message", "")
|
||||
image_content = body.get("images") or None # list[dict] | None
|
||||
|
||||
if not message:
|
||||
if not message and not image_content:
|
||||
return web.json_response({"error": "message is required"}, status=400)
|
||||
|
||||
queen_executor = session.queen_executor
|
||||
if queen_executor is not None:
|
||||
node = queen_executor.node_registry.get("queen")
|
||||
if node is not None and hasattr(node, "inject_event"):
|
||||
await node.inject_event(message, is_client_input=True)
|
||||
await node.inject_event(message, is_client_input=True, image_content=image_content)
|
||||
# Publish to EventBus so the session event log captures user messages
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
|
||||
@@ -134,7 +138,10 @@ async def handle_chat(request: web.Request) -> web.Response:
|
||||
stream_id="queen",
|
||||
node_id="queen",
|
||||
execution_id=session.id,
|
||||
data={"content": message},
|
||||
data={
|
||||
"content": message,
|
||||
"image_count": len(image_content) if image_content else 0,
|
||||
},
|
||||
)
|
||||
)
|
||||
return web.json_response(
|
||||
|
||||
@@ -11,7 +11,6 @@ Session-primary routes:
|
||||
- GET /api/sessions/{session_id}/entry-points — list entry points
|
||||
- PATCH /api/sessions/{session_id}/triggers/{id} — update trigger task
|
||||
- GET /api/sessions/{session_id}/graphs — list graph IDs
|
||||
- GET /api/sessions/{session_id}/queen-messages — queen conversation history
|
||||
- GET /api/sessions/{session_id}/events/history — persisted eventbus log (for replay)
|
||||
|
||||
Worker session browsing (persisted execution runs on disk):
|
||||
@@ -29,6 +28,8 @@ import contextlib
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
@@ -52,8 +53,11 @@ def _get_manager(request: web.Request) -> SessionManager:
|
||||
|
||||
def _session_to_live_dict(session) -> dict:
|
||||
"""Serialize a live Session to the session-primary JSON shape."""
|
||||
from framework.llm.capabilities import supports_image_tool_results
|
||||
|
||||
info = session.worker_info
|
||||
phase_state = getattr(session, "phase_state", None)
|
||||
queen_model: str = getattr(getattr(session, "runner", None), "model", "") or ""
|
||||
return {
|
||||
"session_id": session.id,
|
||||
"worker_id": session.worker_id,
|
||||
@@ -69,6 +73,7 @@ def _session_to_live_dict(session) -> dict:
|
||||
"queen_phase": phase_state.phase
|
||||
if phase_state
|
||||
else ("staging" if session.worker_runtime else "planning"),
|
||||
"queen_supports_images": supports_image_tool_results(queen_model) if queen_model else True,
|
||||
}
|
||||
|
||||
|
||||
@@ -862,60 +867,6 @@ async def handle_messages(request: web.Request) -> web.Response:
|
||||
return web.json_response({"messages": all_messages})
|
||||
|
||||
|
||||
async def handle_queen_messages(request: web.Request) -> web.Response:
|
||||
"""GET /api/sessions/{session_id}/queen-messages — get queen conversation.
|
||||
|
||||
Reads directly from disk so it works for both live sessions and cold
|
||||
(post-server-restart) sessions — no live session required.
|
||||
"""
|
||||
session_id = request.match_info["session_id"]
|
||||
|
||||
queen_dir = Path.home() / ".hive" / "queen" / "session" / session_id
|
||||
convs_dir = queen_dir / "conversations"
|
||||
if not convs_dir.exists():
|
||||
return web.json_response({"messages": [], "session_id": session_id})
|
||||
|
||||
all_messages: list[dict] = []
|
||||
|
||||
def _read_parts(parts_dir: Path, node_id: str) -> None:
|
||||
if not parts_dir.exists():
|
||||
return
|
||||
for part_file in sorted(parts_dir.iterdir()):
|
||||
if part_file.suffix != ".json":
|
||||
continue
|
||||
try:
|
||||
part = json.loads(part_file.read_text(encoding="utf-8"))
|
||||
part["_node_id"] = node_id
|
||||
# Use file mtime as created_at so frontend can order
|
||||
# queen and worker messages chronologically.
|
||||
part.setdefault("created_at", part_file.stat().st_mtime)
|
||||
all_messages.append(part)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
continue
|
||||
|
||||
# Flat layout: conversations/parts/*.json
|
||||
_read_parts(convs_dir / "parts", "queen")
|
||||
|
||||
# Node-based layout: conversations/<node_id>/parts/*.json
|
||||
for node_dir in convs_dir.iterdir():
|
||||
if not node_dir.is_dir() or node_dir.name == "parts":
|
||||
continue
|
||||
_read_parts(node_dir / "parts", node_dir.name)
|
||||
|
||||
all_messages.sort(key=lambda m: m.get("created_at", m.get("seq", 0)))
|
||||
|
||||
# Filter to client-facing messages only
|
||||
all_messages = [
|
||||
m
|
||||
for m in all_messages
|
||||
if not m.get("is_transition_marker")
|
||||
and m["role"] != "tool"
|
||||
and not (m["role"] == "assistant" and m.get("tool_calls"))
|
||||
]
|
||||
|
||||
return web.json_response({"messages": all_messages, "session_id": session_id})
|
||||
|
||||
|
||||
async def handle_session_events_history(request: web.Request) -> web.Response:
|
||||
"""GET /api/sessions/{session_id}/events/history — persisted eventbus log.
|
||||
|
||||
@@ -1033,6 +984,29 @@ async def handle_discover(request: web.Request) -> web.Response:
|
||||
return web.json_response(result)
|
||||
|
||||
|
||||
async def handle_reveal_session_folder(request: web.Request) -> web.Response:
|
||||
"""POST /api/sessions/{session_id}/reveal — open session data folder in the OS file manager."""
|
||||
manager: SessionManager = request.app["manager"]
|
||||
session_id = request.match_info["session_id"]
|
||||
|
||||
session = manager.get_session(session_id)
|
||||
storage_session_id = (session.queen_resume_from or session.id) if session else session_id
|
||||
folder = Path.home() / ".hive" / "queen" / "session" / storage_session_id
|
||||
folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
if sys.platform == "darwin":
|
||||
subprocess.Popen(["open", str(folder)])
|
||||
elif sys.platform == "win32":
|
||||
subprocess.Popen(["explorer", str(folder)])
|
||||
else:
|
||||
subprocess.Popen(["xdg-open", str(folder)])
|
||||
except Exception as exc:
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
return web.json_response({"path": str(folder)})
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Route registration
|
||||
# ------------------------------------------------------------------
|
||||
@@ -1057,13 +1031,14 @@ def register_routes(app: web.Application) -> None:
|
||||
app.router.add_delete("/api/sessions/{session_id}/worker", handle_unload_worker)
|
||||
|
||||
# Session info
|
||||
app.router.add_post("/api/sessions/{session_id}/reveal", handle_reveal_session_folder)
|
||||
app.router.add_get("/api/sessions/{session_id}/stats", handle_session_stats)
|
||||
app.router.add_get("/api/sessions/{session_id}/entry-points", handle_session_entry_points)
|
||||
app.router.add_patch(
|
||||
"/api/sessions/{session_id}/triggers/{trigger_id}", handle_update_trigger_task
|
||||
)
|
||||
app.router.add_get("/api/sessions/{session_id}/graphs", handle_session_graphs)
|
||||
app.router.add_get("/api/sessions/{session_id}/queen-messages", handle_queen_messages)
|
||||
|
||||
app.router.add_get("/api/sessions/{session_id}/events/history", handle_session_events_history)
|
||||
|
||||
# Worker session browsing (session-primary)
|
||||
|
||||
@@ -96,8 +96,7 @@ class SessionManager:
|
||||
|
||||
Internal helper — use create_session() or create_session_with_worker().
|
||||
"""
|
||||
from framework.config import RuntimeConfig
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
from framework.config import RuntimeConfig, get_hive_config
|
||||
from framework.runtime.event_bus import EventBus
|
||||
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
@@ -111,12 +110,20 @@ class SessionManager:
|
||||
rc = RuntimeConfig(model=model or self._model or RuntimeConfig().model)
|
||||
|
||||
# Session owns these — shared with queen and worker
|
||||
llm = LiteLLMProvider(
|
||||
model=rc.model,
|
||||
api_key=rc.api_key,
|
||||
api_base=rc.api_base,
|
||||
**rc.extra_kwargs,
|
||||
)
|
||||
llm_config = get_hive_config().get("llm", {})
|
||||
if llm_config.get("use_antigravity_subscription"):
|
||||
from framework.llm.antigravity import AntigravityProvider
|
||||
|
||||
llm = AntigravityProvider(model=rc.model)
|
||||
else:
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
|
||||
llm = LiteLLMProvider(
|
||||
model=rc.model,
|
||||
api_key=rc.api_key,
|
||||
api_base=rc.api_base,
|
||||
**rc.extra_kwargs,
|
||||
)
|
||||
event_bus = EventBus()
|
||||
|
||||
session = Session(
|
||||
@@ -287,7 +294,17 @@ class SessionManager:
|
||||
try:
|
||||
# Blocking I/O — load in executor
|
||||
loop = asyncio.get_running_loop()
|
||||
resolved_model = model or self._model
|
||||
|
||||
# Prioritize: explicit model arg > worker-specific model > session default
|
||||
from framework.config import (
|
||||
get_preferred_worker_model,
|
||||
get_worker_api_base,
|
||||
get_worker_api_key,
|
||||
get_worker_llm_extra_kwargs,
|
||||
)
|
||||
|
||||
worker_model = get_preferred_worker_model()
|
||||
resolved_model = model or worker_model or self._model
|
||||
runner = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: AgentRunner.load(
|
||||
@@ -299,6 +316,30 @@ class SessionManager:
|
||||
),
|
||||
)
|
||||
|
||||
# If a worker-specific model is configured, build an LLM provider
|
||||
# with the correct worker credentials so _setup() doesn't fall back
|
||||
# to the queen's llm config (which may be a different provider).
|
||||
if worker_model and not model:
|
||||
from framework.config import get_hive_config
|
||||
|
||||
worker_llm_cfg = get_hive_config().get("worker_llm", {})
|
||||
if worker_llm_cfg.get("use_antigravity_subscription"):
|
||||
from framework.llm.antigravity import AntigravityProvider
|
||||
|
||||
runner._llm = AntigravityProvider(model=resolved_model)
|
||||
else:
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
|
||||
worker_api_key = get_worker_api_key()
|
||||
worker_api_base = get_worker_api_base()
|
||||
worker_extra = get_worker_llm_extra_kwargs()
|
||||
runner._llm = LiteLLMProvider(
|
||||
model=resolved_model,
|
||||
api_key=worker_api_key,
|
||||
api_base=worker_api_base,
|
||||
**worker_extra,
|
||||
)
|
||||
|
||||
# Setup with session's event bus
|
||||
if runner._agent_runtime is None:
|
||||
await loop.run_in_executor(
|
||||
@@ -793,10 +834,11 @@ class SessionManager:
|
||||
exec_id = event.execution_id
|
||||
|
||||
if event.type == _ET.EXECUTION_STARTED:
|
||||
# New run on this execution_id — reset cooldown so the first
|
||||
# iteration always produces a mid-run snapshot.
|
||||
# New run on this execution_id — start the cooldown timer so
|
||||
# mid-run snapshots don't fire immediately at session start.
|
||||
# The first snapshot will happen after _DIGEST_COOLDOWN seconds.
|
||||
if exec_id:
|
||||
_last_digest.pop(exec_id, None)
|
||||
_last_digest[exec_id] = _time.monotonic()
|
||||
|
||||
elif event.type in (
|
||||
_ET.EXECUTION_COMPLETED,
|
||||
@@ -923,6 +965,7 @@ class SessionManager:
|
||||
# then use max+1 as offset so resumed sessions produce monotonically
|
||||
# increasing iteration values — preventing frontend message ID collisions.
|
||||
iteration_offset = 0
|
||||
last_phase = ""
|
||||
events_path = queen_dir / "events.jsonl"
|
||||
try:
|
||||
if events_path.exists():
|
||||
@@ -934,17 +977,25 @@ class SessionManager:
|
||||
continue
|
||||
try:
|
||||
evt = json.loads(line)
|
||||
it = evt.get("data", {}).get("iteration")
|
||||
data = evt.get("data", {})
|
||||
it = data.get("iteration")
|
||||
if isinstance(it, int) and it > max_iter:
|
||||
max_iter = it
|
||||
# Track the latest queen phase from QUEEN_PHASE_CHANGED events
|
||||
if evt.get("type") == "queen_phase_changed":
|
||||
phase = data.get("phase")
|
||||
if phase:
|
||||
last_phase = phase
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
if max_iter >= 0:
|
||||
iteration_offset = max_iter + 1
|
||||
logger.info(
|
||||
"Session '%s' resuming with iteration_offset=%d (from events.jsonl max)",
|
||||
"Session '%s' resuming with iteration_offset=%d"
|
||||
" (from events.jsonl max), last phase: %s",
|
||||
session.id,
|
||||
iteration_offset,
|
||||
last_phase or "unknown",
|
||||
)
|
||||
except OSError:
|
||||
pass
|
||||
@@ -996,10 +1047,17 @@ class SessionManager:
|
||||
_consolidation_session_dir = queen_dir
|
||||
|
||||
async def _on_compaction(_event) -> None:
|
||||
# Only consolidate on queen compactions — worker and subagent
|
||||
# compactions are frequent and don't warrant a memory update.
|
||||
if getattr(_event, "stream_id", None) != "queen":
|
||||
return
|
||||
from framework.agents.queen.queen_memory import consolidate_queen_memory
|
||||
|
||||
await consolidate_queen_memory(
|
||||
session.id, _consolidation_session_dir, _consolidation_llm
|
||||
asyncio.create_task(
|
||||
consolidate_queen_memory(
|
||||
session.id, _consolidation_session_dir, _consolidation_llm
|
||||
),
|
||||
name=f"queen-memory-consolidation-{session.id}",
|
||||
)
|
||||
|
||||
from framework.runtime.event_bus import EventType as _ET
|
||||
|
||||
@@ -1,26 +1,50 @@
|
||||
"""Hive Agent Skills — discovery, parsing, and injection of SKILL.md packages.
|
||||
"""Hive Agent Skills — discovery, parsing, trust gating, and injection of SKILL.md packages.
|
||||
|
||||
Implements the open Agent Skills standard (agentskills.io) for portable
|
||||
skill discovery and activation, plus built-in default skills for runtime
|
||||
operational discipline.
|
||||
operational discipline, and AS-13 trust gating for project-scope skills.
|
||||
"""
|
||||
|
||||
from framework.skills.catalog import SkillCatalog
|
||||
from framework.skills.config import DefaultSkillConfig, SkillsConfig
|
||||
from framework.skills.defaults import DefaultSkillManager
|
||||
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
|
||||
from framework.skills.installer import (
|
||||
fork_skill,
|
||||
install_from_git,
|
||||
install_from_registry,
|
||||
remove_skill,
|
||||
)
|
||||
from framework.skills.manager import SkillsManager, SkillsManagerConfig
|
||||
from framework.skills.models import TrustStatus
|
||||
from framework.skills.parser import ParsedSkill, parse_skill_md
|
||||
from framework.skills.registry import RegistryClient
|
||||
from framework.skills.skill_errors import SkillError, SkillErrorCode, log_skill_error
|
||||
from framework.skills.trust import TrustedRepoStore, TrustGate
|
||||
from framework.skills.validator import ValidationResult, validate_strict
|
||||
|
||||
__all__ = [
|
||||
"DefaultSkillConfig",
|
||||
"DefaultSkillManager",
|
||||
"DiscoveryConfig",
|
||||
"ParsedSkill",
|
||||
"RegistryClient",
|
||||
"SkillCatalog",
|
||||
"SkillDiscovery",
|
||||
"SkillError",
|
||||
"SkillErrorCode",
|
||||
"SkillsConfig",
|
||||
"SkillsManager",
|
||||
"SkillsManagerConfig",
|
||||
"TrustGate",
|
||||
"TrustedRepoStore",
|
||||
"TrustStatus",
|
||||
"ValidationResult",
|
||||
"fork_skill",
|
||||
"install_from_git",
|
||||
"install_from_registry",
|
||||
"log_skill_error",
|
||||
"parse_skill_md",
|
||||
"remove_skill",
|
||||
"validate_strict",
|
||||
]
|
||||
|
||||
@@ -20,3 +20,5 @@ What to extract: URLs and key snippets (not full pages), relevant API fields
|
||||
|
||||
Before transitioning to the next phase/node, write a handoff summary to
|
||||
`_handoff_context` with everything the next phase needs to know.
|
||||
|
||||
You will receive an alert when context reaches {{warn_at_usage_ratio_pct}}% — preserve immediately.
|
||||
|
||||
@@ -14,5 +14,5 @@ When a tool call fails:
|
||||
2. Decide — transient: retry once. Structural fixable: fix and retry.
|
||||
Structural unfixable: record as failed, move to next item.
|
||||
Blocking all progress: record escalation note.
|
||||
3. Adapt — if same tool failed 3+ times, stop using it and find alternative.
|
||||
3. Adapt — if same tool failed {{max_retries_per_tool}}+ times, stop using it and find alternative.
|
||||
Update plan in notes. Never silently drop the failed item.
|
||||
|
||||
@@ -8,7 +8,7 @@ metadata:
|
||||
|
||||
## Operational Protocol: Quality Self-Assessment
|
||||
|
||||
Every 5 iterations, self-assess:
|
||||
Every {{assessment_interval}} iterations, self-assess:
|
||||
|
||||
1. On-task? Still working toward the stated objective?
|
||||
2. Thorough? Cutting corners compared to earlier?
|
||||
|
||||
@@ -10,6 +10,7 @@ import logging
|
||||
from xml.sax.saxutils import escape
|
||||
|
||||
from framework.skills.parser import ParsedSkill
|
||||
from framework.skills.skill_errors import SkillErrorCode, log_skill_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -76,6 +77,7 @@ class SkillCatalog:
|
||||
lines.append(f" <name>{escape(skill.name)}</name>")
|
||||
lines.append(f" <description>{escape(skill.description)}</description>")
|
||||
lines.append(f" <location>{escape(skill.location)}</location>")
|
||||
lines.append(f" <base_dir>{escape(skill.base_dir)}</base_dir>")
|
||||
lines.append(" </skill>")
|
||||
lines.append("</available_skills>")
|
||||
|
||||
@@ -96,7 +98,14 @@ class SkillCatalog:
|
||||
for name in skill_names:
|
||||
skill = self.get(name)
|
||||
if skill is None:
|
||||
logger.warning("Pre-activated skill '%s' not found in catalog", name)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"warning",
|
||||
SkillErrorCode.SKILL_NOT_FOUND,
|
||||
what=f"Pre-activated skill '{name}' not found in catalog",
|
||||
why="The skill was listed for pre-activation but was not discovered.",
|
||||
fix=f"Check that a SKILL.md for '{name}' exists in a scanned directory.",
|
||||
)
|
||||
continue
|
||||
if self.is_activated(name):
|
||||
continue # Already activated, skip duplicate
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,15 +8,67 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.skills.config import SkillsConfig
|
||||
from framework.skills.parser import ParsedSkill, parse_skill_md
|
||||
from framework.skills.skill_errors import SkillErrorCode, log_skill_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default skills directory relative to this module
|
||||
_DEFAULT_SKILLS_DIR = Path(__file__).parent / "_default_skills"
|
||||
|
||||
# Default config values per skill — used for {{placeholder}} substitution
|
||||
_SKILL_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"hive.quality-monitor": {"assessment_interval": 5},
|
||||
"hive.error-recovery": {"max_retries_per_tool": 3},
|
||||
"hive.context-preservation": {"warn_at_usage_ratio_pct": 45},
|
||||
"hive.batch-ledger": {"checkpoint_every_n": 5},
|
||||
}
|
||||
|
||||
# Keywords that indicate a batch processing scenario (DS-12)
|
||||
_BATCH_KEYWORDS: tuple[str, ...] = (
|
||||
"list of",
|
||||
"collection of",
|
||||
"set of",
|
||||
"batch of",
|
||||
"each item",
|
||||
"for each",
|
||||
"process all",
|
||||
"records",
|
||||
"entries",
|
||||
"rows",
|
||||
"items",
|
||||
)
|
||||
|
||||
_BATCH_INIT_NUDGE = (
|
||||
"Note: your input appears to describe a batch operation. "
|
||||
"Initialize `_batch_ledger` with the total item count before processing."
|
||||
)
|
||||
|
||||
|
||||
def is_batch_scenario(text: str) -> bool:
|
||||
"""Return True if *text* contains batch-processing indicators (DS-12)."""
|
||||
lower = text.lower()
|
||||
return any(kw in lower for kw in _BATCH_KEYWORDS)
|
||||
|
||||
|
||||
def _apply_overrides(skill_name: str, body: str, overrides: dict[str, Any]) -> str:
|
||||
"""Substitute {{placeholder}} values in a skill body using overrides + defaults."""
|
||||
defaults = _SKILL_DEFAULTS.get(skill_name, {})
|
||||
# Convert float warn_at_usage_ratio → warn_at_usage_ratio_pct for the placeholder
|
||||
if "warn_at_usage_ratio" in overrides:
|
||||
overrides = dict(overrides)
|
||||
overrides.setdefault(
|
||||
"warn_at_usage_ratio_pct", int(float(overrides["warn_at_usage_ratio"]) * 100)
|
||||
)
|
||||
values = {**defaults, **overrides}
|
||||
for key, val in values.items():
|
||||
body = body.replace(f"{{{{{key}}}}}", str(val))
|
||||
return body
|
||||
|
||||
|
||||
# Ordered list of default skills (name → directory)
|
||||
SKILL_REGISTRY: dict[str, str] = {
|
||||
"hive.note-taking": "note-taking",
|
||||
@@ -60,12 +112,14 @@ class DefaultSkillManager:
|
||||
self._config = config or SkillsConfig()
|
||||
self._skills: dict[str, ParsedSkill] = {}
|
||||
self._loaded = False
|
||||
self._error_count = 0
|
||||
|
||||
def load(self) -> None:
|
||||
"""Load all enabled default skill SKILL.md files."""
|
||||
if self._loaded:
|
||||
return
|
||||
|
||||
error_count = 0
|
||||
for skill_name, dir_name in SKILL_REGISTRY.items():
|
||||
if not self._config.is_default_enabled(skill_name):
|
||||
logger.info("Default skill '%s' disabled by config", skill_name)
|
||||
@@ -73,17 +127,34 @@ class DefaultSkillManager:
|
||||
|
||||
skill_path = _DEFAULT_SKILLS_DIR / dir_name / "SKILL.md"
|
||||
if not skill_path.is_file():
|
||||
logger.error("Default skill SKILL.md not found: %s", skill_path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_NOT_FOUND,
|
||||
what=f"Default skill SKILL.md not found: '{skill_path}'",
|
||||
why=f"The framework skill '{skill_name}' is missing its SKILL.md file.",
|
||||
fix="Reinstall the hive framework — this file is part of the package.",
|
||||
)
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
parsed = parse_skill_md(skill_path, source_scope="framework")
|
||||
if parsed is None:
|
||||
logger.error("Failed to parse default skill: %s", skill_path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_PARSE_ERROR,
|
||||
what=f"Failed to parse default skill '{skill_name}'",
|
||||
why=f"parse_skill_md returned None for '{skill_path}'.",
|
||||
fix="Reinstall the hive framework — this file may be corrupted.",
|
||||
)
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
self._skills[skill_name] = parsed
|
||||
|
||||
self._loaded = True
|
||||
self._error_count = error_count
|
||||
|
||||
def build_protocols_prompt(self) -> str:
|
||||
"""Build the combined operational protocols section.
|
||||
@@ -103,8 +174,10 @@ class DefaultSkillManager:
|
||||
skill = self._skills.get(skill_name)
|
||||
if skill is None:
|
||||
continue
|
||||
# Use the full body — each SKILL.md contains exactly one protocol section
|
||||
parts.append(skill.body)
|
||||
# Apply config overrides to {{placeholder}} values before injection
|
||||
overrides = self._config.get_default_overrides(skill_name)
|
||||
body = _apply_overrides(skill_name, skill.body, overrides)
|
||||
parts.append(body)
|
||||
|
||||
if len(parts) <= 1:
|
||||
return ""
|
||||
@@ -127,8 +200,23 @@ class DefaultSkillManager:
|
||||
"""Log which default skills are active and their configuration."""
|
||||
if not self._skills:
|
||||
logger.info("Default skills: all disabled")
|
||||
return
|
||||
|
||||
# DX-3: Per-skill structured startup log
|
||||
for skill_name in SKILL_REGISTRY:
|
||||
if skill_name in self._skills:
|
||||
overrides = self._config.get_default_overrides(skill_name)
|
||||
status = f"loaded overrides={overrides}" if overrides else "loaded"
|
||||
elif not self._config.is_default_enabled(skill_name):
|
||||
status = "disabled"
|
||||
else:
|
||||
status = "error"
|
||||
logger.info(
|
||||
"skill_startup name=%s scope=framework status=%s",
|
||||
skill_name,
|
||||
status,
|
||||
)
|
||||
|
||||
# Original active skills log line (preserved for backward compatibility)
|
||||
active = []
|
||||
for skill_name in SKILL_REGISTRY:
|
||||
if skill_name in self._skills:
|
||||
@@ -138,7 +226,21 @@ class DefaultSkillManager:
|
||||
else:
|
||||
active.append(skill_name)
|
||||
|
||||
logger.info("Default skills active: %s", ", ".join(active))
|
||||
if active:
|
||||
logger.info("Default skills active: %s", ", ".join(active))
|
||||
|
||||
# DX-3: Summary line with error count
|
||||
total = len(SKILL_REGISTRY)
|
||||
active_count = len(self._skills)
|
||||
error_count = getattr(self, "_error_count", 0)
|
||||
disabled_count = total - active_count - error_count
|
||||
logger.info(
|
||||
"Skills: %d default (%d active, %d disabled, %d error)",
|
||||
total,
|
||||
active_count,
|
||||
disabled_count,
|
||||
error_count,
|
||||
)
|
||||
|
||||
@property
|
||||
def active_skill_names(self) -> list[str]:
|
||||
@@ -149,3 +251,28 @@ class DefaultSkillManager:
|
||||
def active_skills(self) -> dict[str, ParsedSkill]:
|
||||
"""All active default skills keyed by name."""
|
||||
return dict(self._skills)
|
||||
|
||||
@property
|
||||
def batch_init_nudge(self) -> str | None:
|
||||
"""Nudge text to prepend to system prompt when batch input detected (DS-12).
|
||||
|
||||
Returns None if ``hive.batch-ledger`` is disabled or auto_detect_batch is False.
|
||||
"""
|
||||
if "hive.batch-ledger" not in self._skills:
|
||||
return None
|
||||
overrides = self._config.get_default_overrides("hive.batch-ledger")
|
||||
if overrides.get("auto_detect_batch") is False:
|
||||
return None
|
||||
return _BATCH_INIT_NUDGE
|
||||
|
||||
@property
|
||||
def context_warn_ratio(self) -> float | None:
|
||||
"""Token usage ratio at which to inject a context preservation warning (DS-13).
|
||||
|
||||
Returns None if ``hive.context-preservation`` is disabled.
|
||||
Defaults to 0.45 when the skill is active but no override is set.
|
||||
"""
|
||||
if "hive.context-preservation" not in self._skills:
|
||||
return None
|
||||
overrides = self._config.get_default_overrides("hive.context-preservation")
|
||||
return float(overrides.get("warn_at_usage_ratio", 0.45))
|
||||
|
||||
@@ -11,6 +11,7 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from framework.skills.parser import ParsedSkill, parse_skill_md
|
||||
from framework.skills.skill_errors import SkillErrorCode, log_skill_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -172,11 +173,13 @@ class SkillDiscovery:
|
||||
for skill in skills:
|
||||
if skill.name in seen:
|
||||
existing = seen[skill.name]
|
||||
logger.warning(
|
||||
"Skill name collision: '%s' from %s overrides %s",
|
||||
skill.name,
|
||||
skill.location,
|
||||
existing.location,
|
||||
log_skill_error(
|
||||
logger,
|
||||
"warning",
|
||||
SkillErrorCode.SKILL_COLLISION,
|
||||
what=f"Skill name collision: '{skill.name}'",
|
||||
why=f"'{skill.location}' overrides '{existing.location}'.",
|
||||
fix="Rename one of the conflicting skill directories to use a unique name.",
|
||||
)
|
||||
seen[skill.name] = skill
|
||||
|
||||
|
||||
@@ -0,0 +1,348 @@
|
||||
"""Skill install, remove, and fork operations.
|
||||
|
||||
Handles filesystem operations for the hive skill CLI:
|
||||
- install_from_git: git clone --depth=1 → copy to target directory
|
||||
- install_from_registry: resolve registry entry → delegate to install_from_git
|
||||
- remove_skill: delete a skill from ~/.hive/skills/
|
||||
- fork_skill: copy a skill to a new location with a new name
|
||||
- maybe_show_install_notice: one-time security notice on first install (NFR-5)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from framework.skills.parser import ParsedSkill
|
||||
from framework.skills.skill_errors import SkillError, SkillErrorCode
|
||||
|
||||
# Default install destination for user-scope skills
|
||||
USER_SKILLS_DIR = Path.home() / ".hive" / "skills"
|
||||
|
||||
# Sentinel file for the one-time security notice on first install (NFR-5)
|
||||
INSTALL_NOTICE_SENTINEL = Path.home() / ".hive" / ".install_notice_shown"
|
||||
|
||||
_INSTALL_NOTICE = """\
|
||||
─────────────────────────────────────────────────────────────
|
||||
Security Notice: Installing Third-Party Skills
|
||||
─────────────────────────────────────────────────────────────
|
||||
Skills are instructions executed by AI agents. A malicious
|
||||
skill can manipulate agent behavior, exfiltrate data, or
|
||||
cause unintended actions.
|
||||
|
||||
Only install skills from sources you trust. Review the
|
||||
SKILL.md before running it in a production environment.
|
||||
|
||||
This notice is shown once. Use 'hive skill doctor' to audit
|
||||
installed skills at any time.
|
||||
─────────────────────────────────────────────────────────────
|
||||
"""
|
||||
|
||||
|
||||
def maybe_show_install_notice() -> None:
|
||||
"""Print a one-time security notice before the first skill install (NFR-5).
|
||||
|
||||
Touches a sentinel file in ~/.hive/ after showing the notice so it is
|
||||
only displayed once across all future installs.
|
||||
"""
|
||||
if INSTALL_NOTICE_SENTINEL.exists():
|
||||
return
|
||||
print(_INSTALL_NOTICE, flush=True)
|
||||
try:
|
||||
INSTALL_NOTICE_SENTINEL.parent.mkdir(parents=True, exist_ok=True)
|
||||
INSTALL_NOTICE_SENTINEL.touch()
|
||||
except OSError:
|
||||
pass # If we can't write the sentinel, just show the notice every time
|
||||
|
||||
|
||||
def install_from_git(
|
||||
git_url: str,
|
||||
skill_name: str,
|
||||
subdirectory: str | None = None,
|
||||
version: str | None = None,
|
||||
target_dir: Path | None = None,
|
||||
) -> Path:
|
||||
"""Install a skill from a git repository.
|
||||
|
||||
Clones the repository with --depth=1 into a temporary directory, then
|
||||
copies the skill subdirectory (or repo root) to the target location.
|
||||
|
||||
Args:
|
||||
git_url: Git repository URL to clone.
|
||||
skill_name: Name of the skill — used as the install directory name.
|
||||
subdirectory: Relative path within the repo to the skill directory.
|
||||
If None, the repo root is treated as the skill directory.
|
||||
version: Git ref to checkout (tag, branch, or commit). Defaults to
|
||||
the remote's default branch.
|
||||
target_dir: Where to install the skill. Defaults to
|
||||
~/.hive/skills/<skill_name>/.
|
||||
|
||||
Returns:
|
||||
Path to the installed skill directory (the parent of SKILL.md).
|
||||
|
||||
Raises:
|
||||
SkillError: On any failure (git not found, clone failed, SKILL.md missing).
|
||||
"""
|
||||
if shutil.which("git") is None:
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"Cannot install '{skill_name}' from {git_url}",
|
||||
why="git is not installed or not on PATH.",
|
||||
fix="Install git (https://git-scm.com/) and retry.",
|
||||
)
|
||||
|
||||
dest = (target_dir or USER_SKILLS_DIR) / skill_name
|
||||
if dest.exists():
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"Cannot install '{skill_name}'",
|
||||
why=f"Directory already exists: {dest}",
|
||||
fix=f"Run 'hive skill remove {skill_name}' first, or use a different --name.",
|
||||
)
|
||||
|
||||
tmp_dir = tempfile.mkdtemp(prefix="hive-skill-install-")
|
||||
try:
|
||||
_git_clone_shallow(git_url, Path(tmp_dir), version=version)
|
||||
|
||||
# Locate the skill within the cloned repo
|
||||
source_dir = Path(tmp_dir) / subdirectory if subdirectory else Path(tmp_dir)
|
||||
skill_md = source_dir / "SKILL.md"
|
||||
if not skill_md.exists():
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_NOT_FOUND,
|
||||
what=f"No SKILL.md found in '{subdirectory or '/'}' of {git_url}",
|
||||
why="The expected SKILL.md file is not present at the given path.",
|
||||
fix=(
|
||||
"Check the repository structure and use "
|
||||
"'hive skill install --from <url>' with the correct subdirectory."
|
||||
),
|
||||
)
|
||||
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
_copy_skill_dir(source_dir, dest)
|
||||
return dest
|
||||
|
||||
except SkillError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"Failed to install '{skill_name}' from {git_url}",
|
||||
why=str(exc),
|
||||
fix="Check the URL, your network connection, and git configuration.",
|
||||
) from exc
|
||||
finally:
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def install_from_registry(
|
||||
registry_entry: dict,
|
||||
target_dir: Path | None = None,
|
||||
version: str | None = None,
|
||||
) -> Path:
|
||||
"""Install a skill using a registry index entry.
|
||||
|
||||
Resolves the git_url and subdirectory from the registry entry and
|
||||
delegates to install_from_git.
|
||||
|
||||
Args:
|
||||
registry_entry: A skill entry dict from skill_index.json.
|
||||
target_dir: Override install destination.
|
||||
version: Override version (defaults to entry's 'version' field).
|
||||
|
||||
Returns:
|
||||
Path to the installed skill directory.
|
||||
|
||||
Raises:
|
||||
SkillError: If the registry entry is missing required fields or install fails.
|
||||
"""
|
||||
name = registry_entry.get("name")
|
||||
git_url = registry_entry.get("git_url")
|
||||
|
||||
if not name or not git_url:
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_NOT_FOUND,
|
||||
what="Incomplete registry entry — missing 'name' or 'git_url'.",
|
||||
why="The registry index entry does not contain all required fields.",
|
||||
fix="Report this issue to the registry maintainer.",
|
||||
)
|
||||
|
||||
resolved_version = version or registry_entry.get("version")
|
||||
subdirectory = registry_entry.get("subdirectory")
|
||||
|
||||
return install_from_git(
|
||||
git_url=git_url,
|
||||
skill_name=str(name),
|
||||
subdirectory=subdirectory,
|
||||
version=resolved_version,
|
||||
target_dir=target_dir,
|
||||
)
|
||||
|
||||
|
||||
def remove_skill(name: str, skills_dir: Path | None = None) -> bool:
|
||||
"""Remove an installed skill from the user skills directory.
|
||||
|
||||
Args:
|
||||
name: Skill directory name to remove.
|
||||
skills_dir: Override the search directory (default: ~/.hive/skills/).
|
||||
|
||||
Returns:
|
||||
True if removed, False if not found.
|
||||
|
||||
Raises:
|
||||
SkillError: If the directory exists but cannot be removed.
|
||||
"""
|
||||
target = (skills_dir or USER_SKILLS_DIR) / name
|
||||
if not target.exists():
|
||||
return False
|
||||
try:
|
||||
shutil.rmtree(target)
|
||||
return True
|
||||
except OSError as exc:
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"Failed to remove skill '{name}' at {target}",
|
||||
why=str(exc),
|
||||
fix="Check file permissions and try again.",
|
||||
) from exc
|
||||
|
||||
|
||||
def fork_skill(
|
||||
source: ParsedSkill,
|
||||
new_name: str,
|
||||
target_dir: Path,
|
||||
) -> Path:
|
||||
"""Create a local editable copy of a skill with a new name.
|
||||
|
||||
Copies the skill's base directory to target_dir/new_name/ and rewrites
|
||||
the 'name' field in the copied SKILL.md frontmatter.
|
||||
|
||||
Args:
|
||||
source: The source skill to fork (from SkillDiscovery).
|
||||
new_name: Name for the forked skill.
|
||||
target_dir: Parent directory for the fork (e.g. ~/.hive/skills/).
|
||||
|
||||
Returns:
|
||||
Path to the forked skill directory.
|
||||
|
||||
Raises:
|
||||
SkillError: If the target already exists or the copy fails.
|
||||
"""
|
||||
dest = target_dir / new_name
|
||||
if dest.exists():
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"Cannot fork to '{dest}'",
|
||||
why="Target directory already exists.",
|
||||
fix=f"Choose a different --name or remove '{dest}' first.",
|
||||
)
|
||||
|
||||
source_dir = Path(source.base_dir)
|
||||
try:
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
_copy_skill_dir(source_dir, dest)
|
||||
except OSError as exc:
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"Failed to fork skill '{source.name}' to '{dest}'",
|
||||
why=str(exc),
|
||||
fix="Check file permissions and available disk space.",
|
||||
) from exc
|
||||
|
||||
# Rewrite the name in the forked SKILL.md via YAML round-trip (safe)
|
||||
forked_skill_md = dest / "SKILL.md"
|
||||
if forked_skill_md.exists():
|
||||
_rewrite_name_in_skill_md(forked_skill_md, new_name)
|
||||
|
||||
return dest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _git_clone_shallow(git_url: str, target: Path, version: str | None = None) -> None:
|
||||
"""Clone a git repo at --depth=1 into target directory.
|
||||
|
||||
Args:
|
||||
git_url: Repository URL.
|
||||
target: Destination directory (will be created by git).
|
||||
version: Optional git ref (branch/tag) to clone.
|
||||
|
||||
Raises:
|
||||
SkillError: If the clone fails.
|
||||
"""
|
||||
cmd = ["git", "clone", "--depth=1"]
|
||||
if version:
|
||||
cmd += ["--branch", version]
|
||||
cmd += [git_url, str(target)]
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"git clone timed out for {git_url}",
|
||||
why="The clone operation took longer than 60 seconds.",
|
||||
fix="Check your network connection and retry.",
|
||||
) from None
|
||||
except (FileNotFoundError, OSError) as exc:
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"Cannot run git for {git_url}",
|
||||
why=str(exc),
|
||||
fix="Ensure git is installed and on PATH.",
|
||||
) from exc
|
||||
|
||||
if result.returncode != 0:
|
||||
stderr = result.stderr.strip()
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"git clone failed for {git_url}",
|
||||
why=stderr or f"git exited with code {result.returncode}",
|
||||
fix="Check the URL is correct and the repository is publicly accessible.",
|
||||
)
|
||||
|
||||
|
||||
def _copy_skill_dir(src: Path, dst: Path) -> None:
|
||||
"""Copy a skill directory, ignoring VCS and cache artifacts."""
|
||||
ignore = shutil.ignore_patterns(".git", "__pycache__", "*.pyc", ".venv", "venv", "node_modules")
|
||||
shutil.copytree(src, dst, ignore=ignore)
|
||||
|
||||
|
||||
def _rewrite_name_in_skill_md(skill_md: Path, new_name: str) -> None:
|
||||
"""Rewrite the 'name' field in a SKILL.md frontmatter via YAML round-trip.
|
||||
|
||||
Parses the frontmatter with yaml.safe_load, updates 'name', re-serializes
|
||||
with yaml.dump, and reconstructs the file as:
|
||||
---
|
||||
<yaml>
|
||||
---
|
||||
<body>
|
||||
|
||||
Falls back to no-op if the file can't be parsed (the copy is still usable).
|
||||
"""
|
||||
import yaml
|
||||
|
||||
try:
|
||||
content = skill_md.read_text(encoding="utf-8")
|
||||
parts = content.split("---", 2)
|
||||
if len(parts) < 3:
|
||||
return
|
||||
frontmatter = yaml.safe_load(parts[1].strip())
|
||||
if not isinstance(frontmatter, dict):
|
||||
return
|
||||
frontmatter["name"] = new_name
|
||||
new_yaml = yaml.dump(frontmatter, default_flow_style=False, allow_unicode=True)
|
||||
new_content = f"---\n{new_yaml}---\n{parts[2]}"
|
||||
skill_md.write_text(new_content, encoding="utf-8")
|
||||
except Exception:
|
||||
pass # Degraded: forked copy works, name just isn't updated
|
||||
@@ -42,11 +42,14 @@ class SkillsManagerConfig:
|
||||
When ``None``, community discovery is skipped.
|
||||
skip_community_discovery: Explicitly skip community scanning
|
||||
even when ``project_root`` is set.
|
||||
interactive: Whether trust gating can prompt the user interactively.
|
||||
When ``False``, untrusted project skills are silently skipped.
|
||||
"""
|
||||
|
||||
skills_config: SkillsConfig = field(default_factory=SkillsConfig)
|
||||
project_root: Path | None = None
|
||||
skip_community_discovery: bool = False
|
||||
interactive: bool = True
|
||||
|
||||
|
||||
class SkillsManager:
|
||||
@@ -63,6 +66,8 @@ class SkillsManager:
|
||||
self._loaded = False
|
||||
self._catalog_prompt: str = ""
|
||||
self._protocols_prompt: str = ""
|
||||
self._allowlisted_dirs: list[str] = []
|
||||
self._default_mgr: object = None # DefaultSkillManager, set after load()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Factory for backwards-compat bridge
|
||||
@@ -85,6 +90,8 @@ class SkillsManager:
|
||||
mgr._loaded = True # skip load()
|
||||
mgr._catalog_prompt = skills_catalog_prompt
|
||||
mgr._protocols_prompt = protocols_prompt
|
||||
mgr._allowlisted_dirs = []
|
||||
mgr._default_mgr = None
|
||||
return mgr
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -113,9 +120,18 @@ class SkillsManager:
|
||||
# 1. Community skill discovery (when project_root is available)
|
||||
catalog_prompt = ""
|
||||
if self._config.project_root is not None and not self._config.skip_community_discovery:
|
||||
from framework.skills.trust import TrustGate
|
||||
|
||||
discovery = SkillDiscovery(DiscoveryConfig(project_root=self._config.project_root))
|
||||
discovered = discovery.discover()
|
||||
|
||||
# Trust-gate project-scope skills (AS-13)
|
||||
discovered = TrustGate(interactive=self._config.interactive).filter_and_gate(
|
||||
discovered, project_dir=self._config.project_root
|
||||
)
|
||||
|
||||
catalog = SkillCatalog(discovered)
|
||||
self._allowlisted_dirs = catalog.allowlisted_dirs
|
||||
catalog_prompt = catalog.to_prompt()
|
||||
|
||||
# Pre-activated community skills
|
||||
@@ -132,6 +148,17 @@ class SkillsManager:
|
||||
default_mgr.load()
|
||||
default_mgr.log_active_skills()
|
||||
protocols_prompt = default_mgr.build_protocols_prompt()
|
||||
self._default_mgr = default_mgr
|
||||
# DX-3: Community skill startup summary
|
||||
if self._config.project_root is not None and not self._config.skip_community_discovery:
|
||||
community_count = len(catalog._skills) if catalog_prompt else 0
|
||||
pre_activated_count = len(skills_config.skills) if skills_config.skills else 0
|
||||
logger.info(
|
||||
"Skills: %d community (%d catalog, %d pre-activated)",
|
||||
community_count,
|
||||
community_count,
|
||||
pre_activated_count,
|
||||
)
|
||||
|
||||
# 3. Cache
|
||||
self._catalog_prompt = catalog_prompt
|
||||
@@ -160,6 +187,25 @@ class SkillsManager:
|
||||
"""Default skill operational protocols for system prompt injection."""
|
||||
return self._protocols_prompt
|
||||
|
||||
@property
|
||||
def allowlisted_dirs(self) -> list[str]:
|
||||
"""Skill base directories for Tier 3 resource access (AS-6)."""
|
||||
return self._allowlisted_dirs
|
||||
|
||||
@property
|
||||
def batch_init_nudge(self) -> str | None:
|
||||
"""Batch init nudge text for DS-12 auto-detection, or None if disabled."""
|
||||
if self._default_mgr is None:
|
||||
return None
|
||||
return self._default_mgr.batch_init_nudge # type: ignore[union-attr]
|
||||
|
||||
@property
|
||||
def context_warn_ratio(self) -> float | None:
|
||||
"""Token usage ratio for DS-13 context preservation warning, or None if disabled."""
|
||||
if self._default_mgr is None:
|
||||
return None
|
||||
return self._default_mgr.context_warn_ratio # type: ignore[union-attr]
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
return self._loaded
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Data models for the Hive skill system (Agent Skills standard)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class SkillScope(StrEnum):
|
||||
"""Where a skill was discovered."""
|
||||
|
||||
PROJECT = "project"
|
||||
USER = "user"
|
||||
FRAMEWORK = "framework"
|
||||
|
||||
|
||||
class TrustStatus(StrEnum):
|
||||
"""Trust state of a skill entry."""
|
||||
|
||||
TRUSTED = "trusted"
|
||||
PENDING_CONSENT = "pending_consent"
|
||||
DENIED = "denied"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillEntry:
|
||||
"""In-memory record for a discovered skill (PRD §4.2)."""
|
||||
|
||||
name: str
|
||||
"""Skill name from SKILL.md frontmatter."""
|
||||
|
||||
description: str
|
||||
"""Skill description from SKILL.md frontmatter."""
|
||||
|
||||
location: Path
|
||||
"""Absolute path to SKILL.md."""
|
||||
|
||||
base_dir: Path
|
||||
"""Parent directory of SKILL.md (skill root)."""
|
||||
|
||||
source_scope: SkillScope
|
||||
"""Which scope this skill was found in."""
|
||||
|
||||
trust_status: TrustStatus = TrustStatus.TRUSTED
|
||||
"""Trust state; project-scope skills start as PENDING_CONSENT before gating."""
|
||||
|
||||
# Optional frontmatter fields
|
||||
license: str | None = None
|
||||
compatibility: list[str] = field(default_factory=list)
|
||||
allowed_tools: list[str] = field(default_factory=list)
|
||||
metadata: dict = field(default_factory=dict)
|
||||
@@ -13,6 +13,8 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.skills.skill_errors import SkillErrorCode, log_skill_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum name length before a warning is logged
|
||||
@@ -74,17 +76,38 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
logger.error("Failed to read %s: %s", path, exc)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"Failed to read '{path}'",
|
||||
why=str(exc),
|
||||
fix="Check the file exists and has read permissions.",
|
||||
)
|
||||
return None
|
||||
|
||||
if not content.strip():
|
||||
logger.error("Empty SKILL.md: %s", path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_PARSE_ERROR,
|
||||
what=f"Invalid SKILL.md at '{path}'",
|
||||
why="The file exists but contains no content.",
|
||||
fix="Add valid YAML frontmatter and a markdown body to the SKILL.md.",
|
||||
)
|
||||
return None
|
||||
|
||||
# Split on --- delimiters (first two occurrences)
|
||||
parts = content.split("---", 2)
|
||||
if len(parts) < 3:
|
||||
logger.error("SKILL.md missing YAML frontmatter delimiters (---): %s", path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_PARSE_ERROR,
|
||||
what=f"Invalid SKILL.md at '{path}'",
|
||||
why="Missing YAML frontmatter (---).",
|
||||
fix="Wrap the frontmatter with --- on its own line at the top and bottom.",
|
||||
)
|
||||
return None
|
||||
|
||||
# parts[0] is content before first --- (should be empty or whitespace)
|
||||
@@ -94,7 +117,14 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
|
||||
body = parts[2].strip()
|
||||
|
||||
if not raw_yaml:
|
||||
logger.error("Empty YAML frontmatter in %s", path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_PARSE_ERROR,
|
||||
what=f"Invalid SKILL.md at '{path}'",
|
||||
why="The --- delimiters are present but the YAML block is empty.",
|
||||
fix="Add at least 'name' and 'description' fields to the frontmatter.",
|
||||
)
|
||||
return None
|
||||
|
||||
# Parse YAML
|
||||
@@ -108,19 +138,47 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
|
||||
try:
|
||||
fixed = _try_fix_yaml(raw_yaml)
|
||||
frontmatter = yaml.safe_load(fixed)
|
||||
logger.warning("Fixed YAML parse issues in %s (unquoted colons)", path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"warning",
|
||||
SkillErrorCode.SKILL_YAML_FIXUP,
|
||||
what=f"Auto-fixed YAML in '{path}'",
|
||||
why="Unquoted colon values detected in frontmatter.",
|
||||
fix='Wrap values containing colons in quotes e.g. description: "Use for: research"',
|
||||
)
|
||||
except yaml.YAMLError as exc:
|
||||
logger.error("Unparseable YAML in %s: %s", path, exc)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_PARSE_ERROR,
|
||||
what=f"Invalid SKILL.md at '{path}'",
|
||||
why=str(exc),
|
||||
fix="Validate the YAML frontmatter at https://yaml-online-parser.appspot.com/",
|
||||
)
|
||||
return None
|
||||
|
||||
if not isinstance(frontmatter, dict):
|
||||
logger.error("YAML frontmatter is not a mapping in %s", path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_PARSE_ERROR,
|
||||
what=f"Invalid SKILL.md at '{path}'",
|
||||
why="YAML frontmatter is not a key-value mapping.",
|
||||
fix="Ensure the frontmatter is valid YAML with key: value pairs.",
|
||||
)
|
||||
return None
|
||||
|
||||
# Required: description
|
||||
description = frontmatter.get("description")
|
||||
if not description or not str(description).strip():
|
||||
logger.error("Missing or empty 'description' in %s — skipping skill", path)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"error",
|
||||
SkillErrorCode.SKILL_MISSING_DESCRIPTION,
|
||||
what=f"Missing 'description' in '{path}'",
|
||||
why="The 'description' field is required but is absent or empty.",
|
||||
fix="Add a non-empty 'description' field to the YAML frontmatter.",
|
||||
)
|
||||
return None
|
||||
|
||||
# Required: name (fallback to parent directory name)
|
||||
@@ -128,7 +186,14 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
|
||||
parent_dir_name = path.parent.name
|
||||
if not name or not str(name).strip():
|
||||
name = parent_dir_name
|
||||
logger.warning("Missing 'name' in %s — using directory name '%s'", path, name)
|
||||
log_skill_error(
|
||||
logger,
|
||||
"warning",
|
||||
SkillErrorCode.SKILL_NAME_MISMATCH,
|
||||
what=f"Missing 'name' in '{path}' — using directory name '{name}'",
|
||||
why="The 'name' field is absent from the YAML frontmatter.",
|
||||
fix=f"Add 'name: {name}' to the frontmatter to make this explicit.",
|
||||
)
|
||||
else:
|
||||
name = str(name).strip()
|
||||
|
||||
@@ -137,13 +202,24 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
|
||||
logger.warning("Skill name exceeds %d chars in %s: '%s'", _MAX_NAME_LENGTH, path, name)
|
||||
|
||||
if name != parent_dir_name and not name.endswith(f".{parent_dir_name}"):
|
||||
logger.warning(
|
||||
"Skill name '%s' doesn't match parent directory '%s' in %s",
|
||||
name,
|
||||
parent_dir_name,
|
||||
path,
|
||||
log_skill_error(
|
||||
logger,
|
||||
"warning",
|
||||
SkillErrorCode.SKILL_NAME_MISMATCH,
|
||||
what=f"Name mismatch in '{path}'",
|
||||
why=f"Skill name '{name}' doesn't match directory '{parent_dir_name}'.",
|
||||
fix=f"Rename the directory to '{name}' or set name to '{parent_dir_name}'.",
|
||||
)
|
||||
|
||||
# Coerce compatibility / allowed-tools to list[str] — many SKILL.md files
|
||||
# in the wild use a plain string instead of a YAML list.
|
||||
raw_compat = frontmatter.get("compatibility")
|
||||
if isinstance(raw_compat, str):
|
||||
raw_compat = [raw_compat]
|
||||
raw_tools = frontmatter.get("allowed-tools")
|
||||
if isinstance(raw_tools, str):
|
||||
raw_tools = [raw_tools]
|
||||
|
||||
return ParsedSkill(
|
||||
name=name,
|
||||
description=str(description).strip(),
|
||||
@@ -152,7 +228,7 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
|
||||
source_scope=source_scope,
|
||||
body=body,
|
||||
license=frontmatter.get("license"),
|
||||
compatibility=frontmatter.get("compatibility"),
|
||||
compatibility=raw_compat,
|
||||
metadata=frontmatter.get("metadata"),
|
||||
allowed_tools=frontmatter.get("allowed-tools"),
|
||||
allowed_tools=raw_tools,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
"""Registry client for the Hive community skill registry.
|
||||
|
||||
Fetches the skill index from the hive-skill-registry GitHub repo, caches it
|
||||
locally, and provides search and resolution utilities.
|
||||
|
||||
The registry repo (Phase 3) may not exist yet. All public methods degrade
|
||||
gracefully — returning None or [] on any network or parse failure.
|
||||
|
||||
Configure a custom registry URL via the HIVE_REGISTRY_URL environment variable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from urllib.error import URLError
|
||||
from urllib.request import urlopen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default registry index URL (Phase 3 repo, may not exist yet)
|
||||
_DEFAULT_REGISTRY_URL = (
|
||||
"https://raw.githubusercontent.com/hive-skill-registry/"
|
||||
"hive-skill-registry/main/skill_index.json"
|
||||
)
|
||||
|
||||
_CACHE_DIR = Path.home() / ".hive" / "registry_cache"
|
||||
_CACHE_INDEX_PATH = _CACHE_DIR / "skill_index.json"
|
||||
_CACHE_METADATA_PATH = _CACHE_DIR / "metadata.json"
|
||||
_CACHE_TTL_SECONDS = 3600 # 1 hour
|
||||
|
||||
|
||||
class RegistryClient:
|
||||
"""Client for the Hive community skill registry.
|
||||
|
||||
All public methods return None / [] on any failure — never raise.
|
||||
Network errors, parse failures, and missing registries are all
|
||||
treated as graceful degradation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
registry_url: str | None = None,
|
||||
cache_dir: Path | None = None,
|
||||
) -> None:
|
||||
self._url = registry_url or os.environ.get("HIVE_REGISTRY_URL", _DEFAULT_REGISTRY_URL)
|
||||
cache_root = cache_dir or _CACHE_DIR
|
||||
self._index_path = cache_root / "skill_index.json"
|
||||
self._metadata_path = cache_root / "metadata.json"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def fetch_index(self, force_refresh: bool = False) -> dict | None:
|
||||
"""Return the registry index dict.
|
||||
|
||||
Uses the local cache if it is fresh (within TTL) unless
|
||||
force_refresh=True. Returns None on any failure.
|
||||
"""
|
||||
if not force_refresh and self._is_cache_fresh():
|
||||
cached = self._load_cache()
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
raw = self._http_fetch(self._url)
|
||||
if raw is None:
|
||||
# Network unavailable — fall back to stale cache if present
|
||||
stale = self._load_cache()
|
||||
if stale is not None:
|
||||
logger.debug("registry: network unavailable, using stale cache")
|
||||
return stale
|
||||
|
||||
try:
|
||||
data = json.loads(raw.decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
|
||||
logger.warning("registry: failed to parse index JSON: %s", exc)
|
||||
return self._load_cache()
|
||||
|
||||
if not isinstance(data, dict):
|
||||
logger.warning("registry: index is not a JSON object")
|
||||
return self._load_cache()
|
||||
|
||||
self._save_cache(data)
|
||||
return data
|
||||
|
||||
def search(self, query: str) -> list[dict]:
|
||||
"""Search registry skills by name, description, or tags.
|
||||
|
||||
Case-insensitive substring match. Returns [] if index unavailable.
|
||||
"""
|
||||
index = self.fetch_index()
|
||||
if not index:
|
||||
return []
|
||||
skills = index.get("skills", [])
|
||||
if not isinstance(skills, list):
|
||||
return []
|
||||
q = query.lower()
|
||||
results = []
|
||||
for entry in skills:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
name = str(entry.get("name", "")).lower()
|
||||
description = str(entry.get("description", "")).lower()
|
||||
tags = " ".join(str(t) for t in entry.get("tags", [])).lower()
|
||||
if q in name or q in description or q in tags:
|
||||
results.append(entry)
|
||||
return results
|
||||
|
||||
def get_skill_entry(self, name: str) -> dict | None:
|
||||
"""Look up a single skill by exact name. Returns None if not found."""
|
||||
index = self.fetch_index()
|
||||
if not index:
|
||||
return None
|
||||
for entry in index.get("skills", []):
|
||||
if isinstance(entry, dict) and entry.get("name") == name:
|
||||
return entry
|
||||
return None
|
||||
|
||||
def get_pack(self, pack_name: str) -> list[str] | None:
|
||||
"""Return the list of skill names in a starter pack.
|
||||
|
||||
Returns None if the pack is not found or the index is unavailable.
|
||||
"""
|
||||
index = self.fetch_index()
|
||||
if not index:
|
||||
return None
|
||||
for pack in index.get("packs", []):
|
||||
if isinstance(pack, dict) and pack.get("name") == pack_name:
|
||||
skills = pack.get("skills", [])
|
||||
if isinstance(skills, list):
|
||||
return [s for s in skills if isinstance(s, str)]
|
||||
return None
|
||||
|
||||
def resolve_git_url(self, name: str) -> tuple[str, str | None] | None:
|
||||
"""Return (git_url, subdirectory) for a skill name.
|
||||
|
||||
Returns None if the skill is not in the registry or the index
|
||||
is unavailable.
|
||||
"""
|
||||
entry = self.get_skill_entry(name)
|
||||
if not entry:
|
||||
return None
|
||||
git_url = entry.get("git_url")
|
||||
if not git_url:
|
||||
return None
|
||||
subdirectory = entry.get("subdirectory") or None
|
||||
return str(git_url), subdirectory
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cache internals
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _load_cache(self) -> dict | None:
|
||||
"""Read cached index from disk. Returns None if absent or unreadable."""
|
||||
try:
|
||||
data = json.loads(self._index_path.read_text(encoding="utf-8"))
|
||||
return data if isinstance(data, dict) else None
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
except Exception as exc:
|
||||
logger.debug("registry: could not read cache: %s", exc)
|
||||
return None
|
||||
|
||||
def _save_cache(self, data: dict) -> None:
|
||||
"""Write index to disk atomically (.tmp then rename)."""
|
||||
try:
|
||||
self._index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = self._index_path.with_suffix(".tmp")
|
||||
tmp.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
||||
tmp.replace(self._index_path)
|
||||
# Update metadata
|
||||
meta = {"last_fetched": datetime.now(tz=UTC).isoformat()}
|
||||
meta_tmp = self._metadata_path.with_suffix(".tmp")
|
||||
meta_tmp.write_text(json.dumps(meta, indent=2), encoding="utf-8")
|
||||
meta_tmp.replace(self._metadata_path)
|
||||
except Exception as exc:
|
||||
logger.debug("registry: could not write cache: %s", exc)
|
||||
|
||||
def _is_cache_fresh(self) -> bool:
|
||||
"""Return True if the cached index was fetched within the TTL."""
|
||||
try:
|
||||
meta = json.loads(self._metadata_path.read_text(encoding="utf-8"))
|
||||
last_fetched = datetime.fromisoformat(meta["last_fetched"])
|
||||
age = (datetime.now(tz=UTC) - last_fetched).total_seconds()
|
||||
return age < _CACHE_TTL_SECONDS
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _http_fetch(self, url: str, timeout: int = 10) -> bytes | None:
|
||||
"""Fetch URL contents. Returns None on any network error — never raises."""
|
||||
try:
|
||||
with urlopen(url, timeout=timeout) as resp: # noqa: S310
|
||||
return resp.read()
|
||||
except URLError as exc:
|
||||
logger.debug("registry: HTTP fetch failed for %s: %s", url, exc)
|
||||
return None
|
||||
except TimeoutError as exc:
|
||||
logger.debug("registry: HTTP fetch timed out for %s: %s", url, exc)
|
||||
return None
|
||||
except Exception as exc:
|
||||
logger.debug("registry: unexpected error fetching %s: %s", url, exc)
|
||||
return None
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Structured error codes and diagnostics for the Hive skill system.
|
||||
|
||||
Implements DX-1 (structured error codes) and DX-2 (what/why/fix format)
|
||||
from the skill system PRD §7.5.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SkillErrorCode(Enum):
|
||||
"""Standardized error codes for skill system operations (DX-1)."""
|
||||
|
||||
SKILL_NOT_FOUND = "SKILL_NOT_FOUND"
|
||||
SKILL_PARSE_ERROR = "SKILL_PARSE_ERROR"
|
||||
SKILL_ACTIVATION_FAILED = "SKILL_ACTIVATION_FAILED"
|
||||
SKILL_MISSING_DESCRIPTION = "SKILL_MISSING_DESCRIPTION"
|
||||
SKILL_YAML_FIXUP = "SKILL_YAML_FIXUP"
|
||||
SKILL_NAME_MISMATCH = "SKILL_NAME_MISMATCH"
|
||||
SKILL_COLLISION = "SKILL_COLLISION"
|
||||
|
||||
|
||||
class SkillError(Exception):
|
||||
"""Structured exception for skill system errors (DX-2).
|
||||
|
||||
Raised in strict validation paths. Also used as the base
|
||||
format contract for log_skill_error() log messages.
|
||||
"""
|
||||
|
||||
def __init__(self, code: SkillErrorCode, what: str, why: str, fix: str):
|
||||
self.code = code
|
||||
self.what = what
|
||||
self.why = why
|
||||
self.fix = fix
|
||||
self.message = (
|
||||
f"[{self.code.value}]\nWhat failed: {self.what}\nWhy: {self.why}\nFix: {self.fix}"
|
||||
)
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
def log_skill_error(
|
||||
logger: logging.Logger,
|
||||
level: str,
|
||||
code: SkillErrorCode,
|
||||
what: str,
|
||||
why: str,
|
||||
fix: str,
|
||||
) -> None:
|
||||
"""Emit a structured skill diagnostic log with consistent format (DX-2).
|
||||
|
||||
Args:
|
||||
logger: The module logger to emit to.
|
||||
level: Log level string — 'error', 'warning', or 'info'.
|
||||
code: Structured error code.
|
||||
what: What failed (specific skill name and path).
|
||||
why: Root cause.
|
||||
fix: Concrete next step for the developer.
|
||||
"""
|
||||
msg = f"[{code.value}] What failed: {what} | Why: {why} | Fix: {fix}"
|
||||
getattr(logger, level)(
|
||||
msg,
|
||||
extra={
|
||||
"skill_error_code": code.value,
|
||||
"what": what,
|
||||
"why": why,
|
||||
"fix": fix,
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,477 @@
|
||||
"""Trust gating for project-level skills (PRD AS-13).
|
||||
|
||||
Project-level skills from untrusted repositories require explicit user consent
|
||||
before their instructions are loaded into the agent's system prompt.
|
||||
Framework and user-scope skills are always trusted.
|
||||
|
||||
Trusted repos are persisted at ~/.hive/trusted_repos.json.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from framework.skills.parser import ParsedSkill
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Env var to bypass trust gating in CI/headless pipelines (opt-in).
|
||||
_ENV_TRUST_ALL = "HIVE_TRUST_PROJECT_SKILLS"
|
||||
|
||||
# Env var for comma-separated own-remote glob patterns (e.g. "github.com/myorg/*").
|
||||
_ENV_OWN_REMOTES = "HIVE_OWN_REMOTES"
|
||||
|
||||
_TRUSTED_REPOS_PATH = Path.home() / ".hive" / "trusted_repos.json"
|
||||
_NOTICE_SENTINEL_PATH = Path.home() / ".hive" / ".skill_trust_notice_shown"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trusted repo store
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrustedRepoEntry:
|
||||
repo_key: str
|
||||
added_at: datetime
|
||||
project_path: str = ""
|
||||
|
||||
|
||||
class TrustedRepoStore:
|
||||
"""Persists permanently-trusted repo keys to ~/.hive/trusted_repos.json."""
|
||||
|
||||
def __init__(self, path: Path | None = None) -> None:
|
||||
self._path = path or _TRUSTED_REPOS_PATH
|
||||
self._entries: dict[str, TrustedRepoEntry] = {}
|
||||
self._loaded = False
|
||||
|
||||
def is_trusted(self, repo_key: str) -> bool:
|
||||
self._ensure_loaded()
|
||||
return repo_key in self._entries
|
||||
|
||||
def trust(self, repo_key: str, project_path: str = "") -> None:
|
||||
self._ensure_loaded()
|
||||
self._entries[repo_key] = TrustedRepoEntry(
|
||||
repo_key=repo_key,
|
||||
added_at=datetime.now(tz=UTC),
|
||||
project_path=project_path,
|
||||
)
|
||||
self._save()
|
||||
logger.info("skill_trust_store: trusted repo_key=%s", repo_key)
|
||||
|
||||
def revoke(self, repo_key: str) -> bool:
|
||||
self._ensure_loaded()
|
||||
if repo_key in self._entries:
|
||||
del self._entries[repo_key]
|
||||
self._save()
|
||||
logger.info("skill_trust_store: revoked repo_key=%s", repo_key)
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_entries(self) -> list[TrustedRepoEntry]:
|
||||
self._ensure_loaded()
|
||||
return list(self._entries.values())
|
||||
|
||||
def _ensure_loaded(self) -> None:
|
||||
if not self._loaded:
|
||||
self._load()
|
||||
self._loaded = True
|
||||
|
||||
def _load(self) -> None:
|
||||
try:
|
||||
data = json.loads(self._path.read_text(encoding="utf-8"))
|
||||
for raw in data.get("entries", []):
|
||||
repo_key = raw.get("repo_key", "")
|
||||
if not repo_key:
|
||||
continue
|
||||
try:
|
||||
added_at = datetime.fromisoformat(raw["added_at"])
|
||||
except (KeyError, ValueError):
|
||||
added_at = datetime.now(tz=UTC)
|
||||
self._entries[repo_key] = TrustedRepoEntry(
|
||||
repo_key=repo_key,
|
||||
added_at=added_at,
|
||||
project_path=raw.get("project_path", ""),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"skill_trust_store: could not read %s (%s); treating as empty",
|
||||
self._path,
|
||||
e,
|
||||
)
|
||||
|
||||
def _save(self) -> None:
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
data = {
|
||||
"version": 1,
|
||||
"entries": [
|
||||
{
|
||||
"repo_key": e.repo_key,
|
||||
"added_at": e.added_at.isoformat(),
|
||||
"project_path": e.project_path,
|
||||
}
|
||||
for e in self._entries.values()
|
||||
],
|
||||
}
|
||||
# Atomic write: write to .tmp then rename
|
||||
tmp = self._path.with_suffix(".tmp")
|
||||
tmp.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
||||
tmp.replace(self._path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trust classification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ProjectTrustClassification(StrEnum):
|
||||
ALWAYS_TRUSTED = "always_trusted"
|
||||
TRUSTED_BY_USER = "trusted_by_user"
|
||||
UNTRUSTED = "untrusted"
|
||||
|
||||
|
||||
class ProjectTrustDetector:
|
||||
"""Classifies a project directory as trusted or untrusted.
|
||||
|
||||
Algorithm (PRD §4.1 trust note):
|
||||
1. No project_dir → ALWAYS_TRUSTED
|
||||
2. No .git directory → ALWAYS_TRUSTED (not a git repo)
|
||||
3. No remote 'origin' → ALWAYS_TRUSTED (local-only repo)
|
||||
4. Remote URL → repo_key; in TrustedRepoStore → TRUSTED_BY_USER
|
||||
5. Localhost remote → ALWAYS_TRUSTED
|
||||
6. ~/.hive/own_remotes match → ALWAYS_TRUSTED
|
||||
7. HIVE_OWN_REMOTES env match → ALWAYS_TRUSTED
|
||||
8. None of the above → UNTRUSTED
|
||||
"""
|
||||
|
||||
def __init__(self, store: TrustedRepoStore | None = None) -> None:
|
||||
self._store = store or TrustedRepoStore()
|
||||
|
||||
def classify(self, project_dir: Path | None) -> tuple[ProjectTrustClassification, str]:
|
||||
"""Return (classification, repo_key).
|
||||
|
||||
repo_key is empty string for ALWAYS_TRUSTED cases without a remote.
|
||||
"""
|
||||
if project_dir is None or not project_dir.exists():
|
||||
return ProjectTrustClassification.ALWAYS_TRUSTED, ""
|
||||
|
||||
if not (project_dir / ".git").exists():
|
||||
return ProjectTrustClassification.ALWAYS_TRUSTED, ""
|
||||
|
||||
remote_url = self._get_remote_origin(project_dir)
|
||||
if not remote_url:
|
||||
return ProjectTrustClassification.ALWAYS_TRUSTED, ""
|
||||
|
||||
repo_key = _normalize_remote_url(remote_url)
|
||||
|
||||
# Explicitly trusted by user
|
||||
if self._store.is_trusted(repo_key):
|
||||
return ProjectTrustClassification.TRUSTED_BY_USER, repo_key
|
||||
|
||||
# Localhost remotes are always trusted
|
||||
if _is_localhost_remote(remote_url):
|
||||
return ProjectTrustClassification.ALWAYS_TRUSTED, repo_key
|
||||
|
||||
# User-configured own-remote patterns
|
||||
if self._matches_own_remotes(repo_key):
|
||||
return ProjectTrustClassification.ALWAYS_TRUSTED, repo_key
|
||||
|
||||
return ProjectTrustClassification.UNTRUSTED, repo_key
|
||||
|
||||
def _get_remote_origin(self, project_dir: Path) -> str:
|
||||
"""Run git remote get-url origin. Returns empty string on any failure."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "-C", str(project_dir), "remote", "get-url", "origin"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout.strip()
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning(
|
||||
"skill_trust: git remote lookup timed out for %s; treating as trusted",
|
||||
project_dir,
|
||||
)
|
||||
except (FileNotFoundError, OSError):
|
||||
pass # git not found or other OS error
|
||||
return ""
|
||||
|
||||
def _matches_own_remotes(self, repo_key: str) -> bool:
|
||||
"""Check repo_key against user-configured own-remote glob patterns."""
|
||||
import fnmatch
|
||||
|
||||
patterns: list[str] = []
|
||||
|
||||
# From env var
|
||||
env_patterns = _ENV_OWN_REMOTES
|
||||
import os
|
||||
|
||||
raw = os.environ.get(env_patterns, "")
|
||||
if raw:
|
||||
patterns.extend(p.strip() for p in raw.split(",") if p.strip())
|
||||
|
||||
# From ~/.hive/own_remotes file
|
||||
own_remotes_file = Path.home() / ".hive" / "own_remotes"
|
||||
if own_remotes_file.is_file():
|
||||
try:
|
||||
for line in own_remotes_file.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#"):
|
||||
patterns.append(line)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return any(fnmatch.fnmatch(repo_key, p) for p in patterns)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# URL helpers (public so CLI can reuse)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _normalize_remote_url(url: str) -> str:
|
||||
"""Normalize a git remote URL to a canonical ``host/org/repo`` key.
|
||||
|
||||
Examples:
|
||||
git@github.com:org/repo.git → github.com/org/repo
|
||||
https://github.com/org/repo → github.com/org/repo
|
||||
ssh://git@github.com/org/repo.git → github.com/org/repo
|
||||
"""
|
||||
url = url.strip()
|
||||
|
||||
# SCP-style SSH: git@github.com:org/repo.git
|
||||
if url.startswith("git@") and ":" in url and "://" not in url:
|
||||
url = url[4:] # strip git@
|
||||
url = url.replace(":", "/", 1)
|
||||
elif "://" in url:
|
||||
parsed = urlparse(url)
|
||||
host = parsed.hostname or ""
|
||||
path = parsed.path.lstrip("/")
|
||||
url = f"{host}/{path}"
|
||||
|
||||
# Strip .git suffix
|
||||
if url.endswith(".git"):
|
||||
url = url[:-4]
|
||||
|
||||
return url.lower().strip("/")
|
||||
|
||||
|
||||
def _is_localhost_remote(remote_url: str) -> bool:
|
||||
"""Return True if the remote points to a local host."""
|
||||
local_hosts = {"localhost", "127.0.0.1", "::1"}
|
||||
try:
|
||||
if "://" in remote_url:
|
||||
parsed = urlparse(remote_url)
|
||||
return (parsed.hostname or "").lower() in local_hosts
|
||||
# SCP-style: git@localhost:org/repo
|
||||
if "@" in remote_url:
|
||||
host_part = remote_url.split("@", 1)[1].split(":")[0]
|
||||
return host_part.lower() in local_hosts
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trust gate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TrustGate:
|
||||
"""Filters skill list, running consent flow for untrusted project-scope skills.
|
||||
|
||||
Framework and user-scope skills are always allowed through.
|
||||
Project-scope skills from untrusted repos require consent.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
store: TrustedRepoStore | None = None,
|
||||
detector: ProjectTrustDetector | None = None,
|
||||
interactive: bool = True,
|
||||
print_fn: Callable[[str], None] | None = None,
|
||||
input_fn: Callable[[str], str] | None = None,
|
||||
) -> None:
|
||||
self._store = store or TrustedRepoStore()
|
||||
self._detector = detector or ProjectTrustDetector(self._store)
|
||||
self._interactive = interactive
|
||||
self._print = print_fn or print
|
||||
self._input = input_fn or input
|
||||
|
||||
def filter_and_gate(
|
||||
self,
|
||||
skills: list[ParsedSkill],
|
||||
project_dir: Path | None,
|
||||
) -> list[ParsedSkill]:
|
||||
"""Return the subset of skills that are trusted for loading.
|
||||
|
||||
- Framework and user-scope skills: always included.
|
||||
- Project-scope skills: classified; consent prompt shown if untrusted.
|
||||
"""
|
||||
import os
|
||||
|
||||
# Separate project skills from always-trusted scopes
|
||||
always_trusted = [s for s in skills if s.source_scope != "project"]
|
||||
project_skills = [s for s in skills if s.source_scope == "project"]
|
||||
|
||||
if not project_skills:
|
||||
return always_trusted
|
||||
|
||||
# Env-var CI override: trust all project skills for this invocation
|
||||
if os.environ.get(_ENV_TRUST_ALL, "").strip() == "1":
|
||||
logger.info(
|
||||
"skill_trust: %s=1 set; trusting %d project skill(s) without consent",
|
||||
_ENV_TRUST_ALL,
|
||||
len(project_skills),
|
||||
)
|
||||
return always_trusted + project_skills
|
||||
|
||||
classification, repo_key = self._detector.classify(project_dir)
|
||||
|
||||
if classification in (
|
||||
ProjectTrustClassification.ALWAYS_TRUSTED,
|
||||
ProjectTrustClassification.TRUSTED_BY_USER,
|
||||
):
|
||||
logger.info(
|
||||
"skill_trust: project skills trusted classification=%s repo=%s count=%d",
|
||||
classification,
|
||||
repo_key or "(no remote)",
|
||||
len(project_skills),
|
||||
)
|
||||
return always_trusted + project_skills
|
||||
|
||||
# UNTRUSTED — need consent
|
||||
if not self._interactive or not sys.stdin.isatty():
|
||||
logger.warning(
|
||||
"skill_trust: skipping %d project-scope skill(s) from untrusted repo "
|
||||
"'%s' (non-interactive mode). "
|
||||
"To trust permanently run: hive skill trust %s",
|
||||
len(project_skills),
|
||||
repo_key,
|
||||
project_dir or ".",
|
||||
)
|
||||
logger.info(
|
||||
"skill_trust_decision repo=%s skills=%d decision=denied mode=headless",
|
||||
repo_key,
|
||||
len(project_skills),
|
||||
)
|
||||
return always_trusted
|
||||
|
||||
# Interactive consent flow
|
||||
decision = self._run_consent_flow(project_skills, project_dir, repo_key)
|
||||
|
||||
logger.info(
|
||||
"skill_trust_decision repo=%s skills=%d decision=%s mode=interactive",
|
||||
repo_key,
|
||||
len(project_skills),
|
||||
decision,
|
||||
)
|
||||
|
||||
if decision == "session":
|
||||
return always_trusted + project_skills
|
||||
|
||||
if decision == "permanent":
|
||||
self._store.trust(repo_key, project_path=str(project_dir or ""))
|
||||
return always_trusted + project_skills
|
||||
|
||||
# denied
|
||||
return always_trusted
|
||||
|
||||
def _run_consent_flow(
|
||||
self,
|
||||
project_skills: list[ParsedSkill],
|
||||
project_dir: Path | None,
|
||||
repo_key: str,
|
||||
) -> str:
|
||||
"""Show the security notice (once) and consent prompt.
|
||||
Return 'session' | 'permanent' | 'denied'."""
|
||||
from framework.credentials.setup import Colors
|
||||
|
||||
if not sys.stdout.isatty():
|
||||
Colors.disable()
|
||||
|
||||
self._maybe_show_security_notice(Colors)
|
||||
self._print_consent_prompt(project_skills, project_dir, repo_key, Colors)
|
||||
return self._prompt_consent(Colors)
|
||||
|
||||
def _maybe_show_security_notice(self, Colors) -> None: # noqa: N803
|
||||
"""Show the one-time security notice if not already shown (NFR-5)."""
|
||||
if _NOTICE_SENTINEL_PATH.exists():
|
||||
return
|
||||
self._print("")
|
||||
self._print(
|
||||
f"{Colors.YELLOW}Security notice:{Colors.NC} Skills inject instructions "
|
||||
"into the agent's system prompt."
|
||||
)
|
||||
self._print(
|
||||
" Only load skills from sources you trust. "
|
||||
"Registry skills at tier 'verified' or 'official' have been audited."
|
||||
)
|
||||
self._print("")
|
||||
try:
|
||||
_NOTICE_SENTINEL_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
_NOTICE_SENTINEL_PATH.touch()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _print_consent_prompt(
|
||||
self,
|
||||
project_skills: list[ParsedSkill],
|
||||
project_dir: Path | None,
|
||||
repo_key: str,
|
||||
Colors, # noqa: N803
|
||||
) -> None:
|
||||
p = self._print
|
||||
p("")
|
||||
p(f"{Colors.YELLOW}{'=' * 60}{Colors.NC}")
|
||||
p(f"{Colors.BOLD} SKILL TRUST REQUIRED{Colors.NC}")
|
||||
p(f"{Colors.YELLOW}{'=' * 60}{Colors.NC}")
|
||||
p("")
|
||||
proj_label = str(project_dir) if project_dir else "this project"
|
||||
p(
|
||||
f" The project at {Colors.CYAN}{proj_label}{Colors.NC} wants to load "
|
||||
f"{len(project_skills)} skill(s)"
|
||||
)
|
||||
p(" that will inject instructions into the agent's system prompt.")
|
||||
if repo_key:
|
||||
p(f" Source: {Colors.BOLD}{repo_key}{Colors.NC}")
|
||||
p("")
|
||||
p(" Skills requesting access:")
|
||||
for skill in project_skills:
|
||||
p(f" {Colors.CYAN}•{Colors.NC} {Colors.BOLD}{skill.name}{Colors.NC}")
|
||||
p(f' "{skill.description}"')
|
||||
p(f" {Colors.DIM}{skill.location}{Colors.NC}")
|
||||
p("")
|
||||
p(" Options:")
|
||||
p(f" {Colors.CYAN}1){Colors.NC} Trust this session only")
|
||||
p(f" {Colors.CYAN}2){Colors.NC} Trust permanently — remember for future runs")
|
||||
p(
|
||||
f" {Colors.DIM}3) Deny"
|
||||
f" — skip all project-scope skills from this repo{Colors.NC}"
|
||||
)
|
||||
p(f"{Colors.YELLOW}{'─' * 60}{Colors.NC}")
|
||||
|
||||
def _prompt_consent(self, Colors) -> str: # noqa: N803
|
||||
"""Prompt until a valid choice is entered. Returns 'session'|'permanent'|'denied'."""
|
||||
mapping = {"1": "session", "2": "permanent", "3": "denied"}
|
||||
while True:
|
||||
try:
|
||||
choice = self._input("Select option (1-3): ").strip()
|
||||
if choice in mapping:
|
||||
return mapping[choice]
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return "denied"
|
||||
self._print(f"{Colors.RED}Invalid choice. Enter 1, 2, or 3.{Colors.NC}")
|
||||
@@ -0,0 +1,178 @@
|
||||
"""Strict SKILL.md validation for contributor tooling (hive skill validate).
|
||||
|
||||
Unlike the lenient parser used at runtime, this module applies hard-error rules
|
||||
that match the Agent Skills specification exactly. Intended for contributor
|
||||
tooling, CI gates, and hive skill doctor.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import stat
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from framework.skills.parser import _MAX_NAME_LENGTH
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Result of a strict SKILL.md validation run."""
|
||||
|
||||
passed: bool
|
||||
errors: list[str] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def validate_strict(path: Path) -> ValidationResult:
|
||||
"""Run all strict checks against a SKILL.md file.
|
||||
|
||||
Applies hard-error rules that go beyond the lenient runtime parser:
|
||||
- name must be explicit (no directory-name fallback)
|
||||
- YAML must parse without fixup
|
||||
- name/directory mismatch is an error, not a warning
|
||||
- empty body is an error
|
||||
- scripts must be executable
|
||||
|
||||
Args:
|
||||
path: Path to the SKILL.md file to validate.
|
||||
|
||||
Returns:
|
||||
ValidationResult with passed=True if no errors, plus any warnings.
|
||||
"""
|
||||
errors: list[str] = []
|
||||
warnings: list[str] = []
|
||||
|
||||
# 1. File exists and is readable
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
except FileNotFoundError:
|
||||
return ValidationResult(passed=False, errors=[f"File not found: {path}"])
|
||||
except PermissionError:
|
||||
return ValidationResult(passed=False, errors=[f"Permission denied reading: {path}"])
|
||||
except OSError as exc:
|
||||
return ValidationResult(passed=False, errors=[f"Cannot read file: {exc}"])
|
||||
|
||||
# 2. File not empty
|
||||
if not content.strip():
|
||||
return ValidationResult(passed=False, errors=["File is empty."])
|
||||
|
||||
# 3. YAML frontmatter present
|
||||
parts = content.split("---", 2)
|
||||
if len(parts) < 3:
|
||||
return ValidationResult(
|
||||
passed=False,
|
||||
errors=["Missing YAML frontmatter — wrap frontmatter with --- delimiters."],
|
||||
)
|
||||
|
||||
raw_yaml = parts[1].strip()
|
||||
body = parts[2].strip()
|
||||
|
||||
if not raw_yaml:
|
||||
return ValidationResult(
|
||||
passed=False,
|
||||
errors=["Frontmatter delimiters present but YAML block is empty."],
|
||||
)
|
||||
|
||||
# 4. YAML parses WITHOUT fixup (strict: unquoted colons are an error)
|
||||
import yaml
|
||||
|
||||
frontmatter: dict | None = None
|
||||
try:
|
||||
frontmatter = yaml.safe_load(raw_yaml)
|
||||
except yaml.YAMLError as exc:
|
||||
errors.append(
|
||||
f"YAML parse error: {exc}. "
|
||||
'Wrap values containing colons in quotes, e.g. description: "Use for: research".'
|
||||
)
|
||||
return ValidationResult(passed=False, errors=errors, warnings=warnings)
|
||||
|
||||
if not isinstance(frontmatter, dict):
|
||||
return ValidationResult(
|
||||
passed=False,
|
||||
errors=["Frontmatter is not a YAML key-value mapping."],
|
||||
)
|
||||
|
||||
# 5. description present and non-empty
|
||||
description = frontmatter.get("description")
|
||||
if not description or not str(description).strip():
|
||||
errors.append("Missing required field: 'description' must be present and non-empty.")
|
||||
|
||||
# 6. name present and non-empty (no directory-name fallback in strict mode)
|
||||
name = frontmatter.get("name")
|
||||
if not name or not str(name).strip():
|
||||
errors.append(
|
||||
"Missing required field: 'name' must be present. "
|
||||
"Add 'name: your-skill-name' to the frontmatter."
|
||||
)
|
||||
else:
|
||||
name = str(name).strip()
|
||||
parent_dir_name = path.parent.name
|
||||
|
||||
# 7. name length <= 64 chars
|
||||
if len(name) > _MAX_NAME_LENGTH:
|
||||
errors.append(
|
||||
f"Skill name '{name}' is {len(name)} characters — "
|
||||
f"maximum is {_MAX_NAME_LENGTH}. Shorten the name."
|
||||
)
|
||||
|
||||
# 8. name matches parent directory (dot-namespace prefix allowed: hive.X with dir X)
|
||||
if name != parent_dir_name and not name.endswith(f".{parent_dir_name}"):
|
||||
errors.append(
|
||||
f"Name '{name}' does not match directory '{parent_dir_name}'. "
|
||||
f"Rename the directory to '{name}' or set name to '{parent_dir_name}'."
|
||||
)
|
||||
|
||||
# 9. body non-empty
|
||||
if not body:
|
||||
errors.append(
|
||||
"Skill body (instructions) is empty. "
|
||||
"Add markdown instructions after the closing --- delimiter."
|
||||
)
|
||||
|
||||
# 10. license present — warning only
|
||||
if not frontmatter.get("license"):
|
||||
warnings.append("No 'license' field — consider adding a license (e.g. MIT, Apache-2.0).")
|
||||
|
||||
# 11. Scripts in scripts/ exist and are executable (POSIX only —
|
||||
# Windows does not use POSIX permission bits)
|
||||
base_dir = path.parent
|
||||
scripts_dir = base_dir / "scripts"
|
||||
if scripts_dir.is_dir() and os.name != "nt":
|
||||
for script_path in sorted(scripts_dir.iterdir()):
|
||||
if script_path.is_file():
|
||||
if not (script_path.stat().st_mode & (stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)):
|
||||
errors.append(
|
||||
f"Script not executable: {script_path.name}. Run: chmod +x {script_path}"
|
||||
)
|
||||
|
||||
# 12. allowed-tools entries are non-empty strings — warning if malformed
|
||||
allowed_tools = frontmatter.get("allowed-tools")
|
||||
if allowed_tools is not None:
|
||||
if not isinstance(allowed_tools, list):
|
||||
warnings.append("'allowed-tools' should be a list of strings.")
|
||||
else:
|
||||
for tool in allowed_tools:
|
||||
if not isinstance(tool, str) or not tool.strip():
|
||||
warnings.append(f"'allowed-tools' entry {tool!r} is not a non-empty string.")
|
||||
|
||||
# 13. compatibility is a list of strings — error if malformed
|
||||
compatibility = frontmatter.get("compatibility")
|
||||
if compatibility is not None:
|
||||
if not isinstance(compatibility, list):
|
||||
errors.append("'compatibility' must be a list of strings.")
|
||||
else:
|
||||
for item in compatibility:
|
||||
if not isinstance(item, str):
|
||||
errors.append(f"'compatibility' entry {item!r} is not a string.")
|
||||
|
||||
# 14. metadata is a dict — error if malformed
|
||||
metadata = frontmatter.get("metadata")
|
||||
if metadata is not None and not isinstance(metadata, dict):
|
||||
errors.append("'metadata' must be a YAML mapping (dict), not a scalar or list.")
|
||||
|
||||
return ValidationResult(
|
||||
passed=len(errors) == 0,
|
||||
errors=errors,
|
||||
warnings=warnings,
|
||||
)
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Storage backends for runtime data."""
|
||||
|
||||
from framework.storage.backend import FileStorage
|
||||
from framework.storage.concurrent import ConcurrentStorage
|
||||
from framework.storage.conversation_store import FileConversationStore
|
||||
|
||||
__all__ = ["FileStorage", "FileConversationStore"]
|
||||
__all__ = ["ConcurrentStorage", "FileConversationStore"]
|
||||
|
||||
@@ -1,266 +0,0 @@
|
||||
"""
|
||||
File-based storage backend for runtime data.
|
||||
|
||||
DEPRECATED: This storage backend is deprecated for new sessions.
|
||||
New sessions use unified storage at sessions/{session_id}/state.json.
|
||||
This module is kept for backward compatibility with old run data only.
|
||||
|
||||
Uses Pydantic's built-in serialization.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from framework.schemas.run import Run, RunStatus, RunSummary
|
||||
from framework.utils.io import atomic_write
|
||||
|
||||
|
||||
class FileStorage:
|
||||
"""
|
||||
DEPRECATED: File-based storage for old runs only.
|
||||
|
||||
New sessions use unified storage at sessions/{session_id}/state.json.
|
||||
This class is kept for backward compatibility with old run data.
|
||||
|
||||
Old directory structure (deprecated):
|
||||
{base_path}/
|
||||
runs/ # DEPRECATED - no longer written
|
||||
{run_id}.json
|
||||
summaries/ # DEPRECATED - no longer written
|
||||
{run_id}.json
|
||||
indexes/ # DEPRECATED - no longer written or read
|
||||
by_goal/
|
||||
{goal_id}.json
|
||||
by_status/
|
||||
{status}.json
|
||||
by_node/
|
||||
{node_id}.json
|
||||
"""
|
||||
|
||||
def __init__(self, base_path: str | Path):
|
||||
self.base_path = Path(base_path)
|
||||
self._ensure_dirs()
|
||||
|
||||
def _ensure_dirs(self) -> None:
|
||||
"""Create directory structure if it doesn't exist.
|
||||
|
||||
DEPRECATED: All directories (runs/, summaries/, indexes/) are deprecated.
|
||||
New sessions use unified storage at sessions/{session_id}/state.json.
|
||||
This method is now a no-op. Tests should not rely on this.
|
||||
"""
|
||||
# No-op: do not create deprecated directories
|
||||
pass
|
||||
|
||||
def _validate_key(self, key: str) -> None:
|
||||
"""
|
||||
Validate key to prevent path traversal attacks.
|
||||
|
||||
Args:
|
||||
key: The key to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If key contains path traversal or dangerous patterns
|
||||
"""
|
||||
if not key or key.strip() == "":
|
||||
raise ValueError("Key cannot be empty")
|
||||
|
||||
# Block path separators
|
||||
if "/" in key or "\\" in key:
|
||||
raise ValueError(f"Invalid key format: path separators not allowed in '{key}'")
|
||||
|
||||
# Block parent directory references
|
||||
if ".." in key or key.startswith("."):
|
||||
raise ValueError(f"Invalid key format: path traversal detected in '{key}'")
|
||||
|
||||
# Block absolute paths
|
||||
if key.startswith("/") or (len(key) > 1 and key[1] == ":"):
|
||||
raise ValueError(f"Invalid key format: absolute paths not allowed in '{key}'")
|
||||
|
||||
# Block null bytes (Unix path injection)
|
||||
if "\x00" in key:
|
||||
raise ValueError("Invalid key format: null bytes not allowed")
|
||||
|
||||
# Block other dangerous special characters
|
||||
dangerous_chars = {"<", ">", "|", "&", "$", "`", "'", '"'}
|
||||
if any(char in key for char in dangerous_chars):
|
||||
raise ValueError(f"Invalid key format: contains dangerous characters in '{key}'")
|
||||
|
||||
# === RUN OPERATIONS ===
|
||||
|
||||
def save_run(self, run: Run) -> None:
|
||||
"""Save a run to storage.
|
||||
|
||||
DEPRECATED: This method is now a no-op.
|
||||
New sessions use unified storage at sessions/{session_id}/state.json.
|
||||
Tests should not rely on FileStorage - use unified session storage instead.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"FileStorage.save_run() is deprecated. "
|
||||
"New sessions use unified storage at sessions/{session_id}/state.json. "
|
||||
"This write has been skipped.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
# No-op: do not write to deprecated locations
|
||||
|
||||
def load_run(self, run_id: str) -> Run | None:
|
||||
"""Load a run from storage."""
|
||||
run_path = self.base_path / "runs" / f"{run_id}.json"
|
||||
if not run_path.exists():
|
||||
return None
|
||||
with open(run_path, encoding="utf-8") as f:
|
||||
return Run.model_validate_json(f.read())
|
||||
|
||||
def load_summary(self, run_id: str) -> RunSummary | None:
|
||||
"""Load just the summary (faster than full run)."""
|
||||
summary_path = self.base_path / "summaries" / f"{run_id}.json"
|
||||
if not summary_path.exists():
|
||||
# Fall back to computing from full run
|
||||
run = self.load_run(run_id)
|
||||
if run:
|
||||
return RunSummary.from_run(run)
|
||||
return None
|
||||
|
||||
with open(summary_path, encoding="utf-8") as f:
|
||||
return RunSummary.model_validate_json(f.read())
|
||||
|
||||
def delete_run(self, run_id: str) -> bool:
|
||||
"""Delete a run from storage."""
|
||||
run_path = self.base_path / "runs" / f"{run_id}.json"
|
||||
summary_path = self.base_path / "summaries" / f"{run_id}.json"
|
||||
|
||||
if not run_path.exists():
|
||||
return False
|
||||
|
||||
# Load run to get index keys
|
||||
run = self.load_run(run_id)
|
||||
if run:
|
||||
self._remove_from_index("by_goal", run.goal_id, run_id)
|
||||
self._remove_from_index("by_status", run.status.value, run_id)
|
||||
for node_id in run.metrics.nodes_executed:
|
||||
self._remove_from_index("by_node", node_id, run_id)
|
||||
|
||||
run_path.unlink()
|
||||
if summary_path.exists():
|
||||
summary_path.unlink()
|
||||
|
||||
return True
|
||||
|
||||
# === QUERY OPERATIONS ===
|
||||
|
||||
def get_runs_by_goal(self, goal_id: str) -> list[str]:
|
||||
"""Get all run IDs for a goal.
|
||||
|
||||
DEPRECATED: Indexes are deprecated. For new sessions, scan sessions/*/state.json instead.
|
||||
This method only returns old run IDs from deprecated indexes.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"FileStorage.get_runs_by_goal() is deprecated. "
|
||||
"For new sessions, scan sessions/*/state.json instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self._get_index("by_goal", goal_id)
|
||||
|
||||
def get_runs_by_status(self, status: str | RunStatus) -> list[str]:
|
||||
"""Get all run IDs with a status.
|
||||
|
||||
DEPRECATED: Indexes are deprecated. For new sessions, scan sessions/*/state.json instead.
|
||||
This method only returns old run IDs from deprecated indexes.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"FileStorage.get_runs_by_status() is deprecated. "
|
||||
"For new sessions, scan sessions/*/state.json instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if isinstance(status, RunStatus):
|
||||
status = status.value
|
||||
return self._get_index("by_status", status)
|
||||
|
||||
def get_runs_by_node(self, node_id: str) -> list[str]:
|
||||
"""Get all run IDs that executed a node.
|
||||
|
||||
DEPRECATED: Indexes are deprecated. For new sessions, scan sessions/*/state.json instead.
|
||||
This method only returns old run IDs from deprecated indexes.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"FileStorage.get_runs_by_node() is deprecated. "
|
||||
"For new sessions, scan sessions/*/state.json instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self._get_index("by_node", node_id)
|
||||
|
||||
def list_all_runs(self) -> list[str]:
|
||||
"""List all run IDs."""
|
||||
runs_dir = self.base_path / "runs"
|
||||
return [f.stem for f in runs_dir.glob("*.json")]
|
||||
|
||||
def list_all_goals(self) -> list[str]:
|
||||
"""List all goal IDs that have runs.
|
||||
|
||||
DEPRECATED: Indexes are deprecated. For new sessions, scan sessions/*/state.json instead.
|
||||
This method only returns goals from old run IDs in deprecated indexes.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"FileStorage.list_all_goals() is deprecated. "
|
||||
"For new sessions, scan sessions/*/state.json instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
goals_dir = self.base_path / "indexes" / "by_goal"
|
||||
if not goals_dir.exists():
|
||||
return []
|
||||
return [f.stem for f in goals_dir.glob("*.json")]
|
||||
|
||||
# === INDEX OPERATIONS ===
|
||||
|
||||
def _get_index(self, index_type: str, key: str) -> list[str]:
|
||||
"""Get values from an index."""
|
||||
self._validate_key(key) # Prevent path traversal
|
||||
index_path = self.base_path / "indexes" / index_type / f"{key}.json"
|
||||
if not index_path.exists():
|
||||
return []
|
||||
with open(index_path, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
def _add_to_index(self, index_type: str, key: str, value: str) -> None:
|
||||
"""Add a value to an index."""
|
||||
self._validate_key(key) # Prevent path traversal
|
||||
index_path = self.base_path / "indexes" / index_type / f"{key}.json"
|
||||
values = self._get_index(index_type, key) # Already validated in _get_index
|
||||
if value not in values:
|
||||
values.append(value)
|
||||
with atomic_write(index_path) as f:
|
||||
json.dump(values, f, indent=2)
|
||||
|
||||
def _remove_from_index(self, index_type: str, key: str, value: str) -> None:
|
||||
"""Remove a value from an index."""
|
||||
self._validate_key(key) # Prevent path traversal
|
||||
index_path = self.base_path / "indexes" / index_type / f"{key}.json"
|
||||
values = self._get_index(index_type, key) # Already validated in _get_index
|
||||
if value in values:
|
||||
values.remove(value)
|
||||
with atomic_write(index_path) as f:
|
||||
json.dump(values, f, indent=2)
|
||||
|
||||
# === UTILITY ===
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get storage statistics."""
|
||||
return {
|
||||
"total_runs": len(self.list_all_runs()),
|
||||
"total_goals": len(self.list_all_goals()),
|
||||
"storage_path": str(self.base_path),
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
Concurrent Storage - Thread-safe storage backend with file locking.
|
||||
|
||||
Wraps FileStorage with:
|
||||
Provides:
|
||||
- Async file locking for atomic writes
|
||||
- Write batching for performance
|
||||
- Read caching for concurrent access
|
||||
@@ -16,8 +16,8 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from framework.schemas.run import Run, RunStatus, RunSummary
|
||||
from framework.storage.backend import FileStorage
|
||||
from framework.schemas.run import Run, RunSummary
|
||||
from framework.utils.io import atomic_write
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -41,7 +41,6 @@ class ConcurrentStorage:
|
||||
- Async file locking to prevent concurrent write corruption
|
||||
- Write batching to reduce I/O overhead
|
||||
- Read caching for frequently accessed data
|
||||
- Compatible API with FileStorage
|
||||
|
||||
Example:
|
||||
storage = ConcurrentStorage("/path/to/storage")
|
||||
@@ -75,7 +74,6 @@ class ConcurrentStorage:
|
||||
max_locks: Maximum number of active file locks to track strongly
|
||||
"""
|
||||
self.base_path = Path(base_path)
|
||||
self._base_storage = FileStorage(base_path)
|
||||
|
||||
# Caching
|
||||
self._cache: dict[str, CacheEntry] = {}
|
||||
@@ -157,6 +155,93 @@ class ConcurrentStorage:
|
||||
|
||||
return lock
|
||||
|
||||
# === KEY VALIDATION ===
|
||||
|
||||
@staticmethod
|
||||
def _validate_key(key: str) -> None:
|
||||
"""Validate key to prevent path traversal attacks.
|
||||
|
||||
Args:
|
||||
key: The key to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If key contains path traversal or dangerous patterns
|
||||
"""
|
||||
if not key or key.strip() == "":
|
||||
raise ValueError("Key cannot be empty")
|
||||
|
||||
if "/" in key or "\\" in key:
|
||||
raise ValueError(f"Invalid key format: path separators not allowed in '{key}'")
|
||||
|
||||
if ".." in key or key.startswith("."):
|
||||
raise ValueError(f"Invalid key format: path traversal detected in '{key}'")
|
||||
|
||||
if key.startswith("/") or (len(key) > 1 and key[1] == ":"):
|
||||
raise ValueError(f"Invalid key format: absolute paths not allowed in '{key}'")
|
||||
|
||||
if "\x00" in key:
|
||||
raise ValueError("Invalid key format: null bytes not allowed")
|
||||
|
||||
dangerous_chars = {"<", ">", "|", "&", "$", "`", "'", '"'}
|
||||
if any(char in key for char in dangerous_chars):
|
||||
raise ValueError(f"Invalid key format: contains dangerous characters in '{key}'")
|
||||
|
||||
# === FILE OPERATIONS (formerly in FileStorage) ===
|
||||
|
||||
def _save_run_sync(self, run: Run) -> None:
|
||||
"""Persist a run to disk as ``runs/{run_id}.json``.
|
||||
|
||||
Uses an atomic write (temp-file + rename) so a mid-write crash
|
||||
never leaves a partially written file on disk.
|
||||
"""
|
||||
self._validate_key(run.id)
|
||||
runs_dir = self.base_path / "runs"
|
||||
runs_dir.mkdir(parents=True, exist_ok=True)
|
||||
run_path = runs_dir / f"{run.id}.json"
|
||||
with atomic_write(run_path) as f:
|
||||
f.write(run.model_dump_json(indent=2))
|
||||
|
||||
def _load_run_sync(self, run_id: str) -> Run | None:
|
||||
"""Load a run from storage."""
|
||||
run_path = self.base_path / "runs" / f"{run_id}.json"
|
||||
if not run_path.exists():
|
||||
return None
|
||||
with open(run_path, encoding="utf-8") as f:
|
||||
return Run.model_validate_json(f.read())
|
||||
|
||||
def _load_summary_sync(self, run_id: str) -> RunSummary | None:
|
||||
"""Load just the summary (faster than full run)."""
|
||||
self._validate_key(run_id)
|
||||
summary_path = self.base_path / "summaries" / f"{run_id}.json"
|
||||
if not summary_path.exists():
|
||||
run = self._load_run_sync(run_id)
|
||||
if run:
|
||||
return RunSummary.from_run(run)
|
||||
return None
|
||||
with open(summary_path, encoding="utf-8") as f:
|
||||
return RunSummary.model_validate_json(f.read())
|
||||
|
||||
def _delete_run_sync(self, run_id: str) -> bool:
|
||||
"""Delete a run from storage."""
|
||||
run_path = self.base_path / "runs" / f"{run_id}.json"
|
||||
summary_path = self.base_path / "summaries" / f"{run_id}.json"
|
||||
|
||||
if not run_path.exists():
|
||||
return False
|
||||
|
||||
run_path.unlink()
|
||||
if summary_path.exists():
|
||||
summary_path.unlink()
|
||||
|
||||
return True
|
||||
|
||||
def _list_all_runs_sync(self) -> list[str]:
|
||||
"""List all run IDs."""
|
||||
runs_dir = self.base_path / "runs"
|
||||
if not runs_dir.exists():
|
||||
return []
|
||||
return [f.stem for f in runs_dir.glob("*.json")]
|
||||
|
||||
# === RUN OPERATIONS (Async, Thread-Safe) ===
|
||||
|
||||
async def save_run(self, run: Run, immediate: bool = False) -> None:
|
||||
@@ -180,40 +265,17 @@ class ConcurrentStorage:
|
||||
await self._write_queue.put(("run", run))
|
||||
|
||||
async def _save_run_locked(self, run: Run) -> None:
|
||||
"""Save a run with file locking, including index locks."""
|
||||
"""Save a run with file locking."""
|
||||
lock_key = f"run:{run.id}"
|
||||
|
||||
# Helper to get lock
|
||||
async def get_lock(k):
|
||||
return await self._get_lock(k)
|
||||
|
||||
# Acquire main lock
|
||||
run_lock = await get_lock(lock_key)
|
||||
run_lock = await self._get_lock(lock_key)
|
||||
|
||||
async with run_lock:
|
||||
# 2. Acquire index locks
|
||||
index_lock_keys = [
|
||||
f"index:by_goal:{run.goal_id}",
|
||||
f"index:by_status:{run.status.value}",
|
||||
]
|
||||
for node_id in run.metrics.nodes_executed:
|
||||
index_lock_keys.append(f"index:by_node:{node_id}")
|
||||
|
||||
# Collect index locks
|
||||
index_locks = [await get_lock(k) for k in index_lock_keys]
|
||||
|
||||
# Recursive acquisition
|
||||
async def with_locks(locks, callback):
|
||||
if not locks:
|
||||
return await callback()
|
||||
async with locks[0]:
|
||||
return await with_locks(locks[1:], callback)
|
||||
|
||||
async def perform_save():
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, self._base_storage.save_run, run)
|
||||
await loop.run_in_executor(None, self._save_run_sync, run)
|
||||
|
||||
await with_locks(index_locks, perform_save)
|
||||
await perform_save()
|
||||
|
||||
async def load_run(self, run_id: str, use_cache: bool = True) -> Run | None:
|
||||
"""
|
||||
@@ -225,7 +287,11 @@ class ConcurrentStorage:
|
||||
|
||||
Returns:
|
||||
Run object or None if not found
|
||||
|
||||
Raises:
|
||||
ValueError: If run_id contains path traversal characters.
|
||||
"""
|
||||
self._validate_key(run_id)
|
||||
if use_cache:
|
||||
cache_key = f"run:{run_id}"
|
||||
cached = self._cache.get(cache_key)
|
||||
@@ -240,7 +306,7 @@ class ConcurrentStorage:
|
||||
lock_key = f"run:{run_id}"
|
||||
async with await self._get_lock(lock_key):
|
||||
loop = asyncio.get_event_loop()
|
||||
run = await loop.run_in_executor(None, self._base_storage.load_run, run_id)
|
||||
run = await loop.run_in_executor(None, self._load_run_sync, run_id)
|
||||
|
||||
# Update cache
|
||||
if run:
|
||||
@@ -249,7 +315,12 @@ class ConcurrentStorage:
|
||||
return run
|
||||
|
||||
async def load_summary(self, run_id: str, use_cache: bool = True) -> RunSummary | None:
|
||||
"""Load just the summary (faster than full run)."""
|
||||
"""Load just the summary (faster than full run).
|
||||
|
||||
Raises:
|
||||
ValueError: If run_id contains path traversal characters.
|
||||
"""
|
||||
self._validate_key(run_id)
|
||||
cache_key = f"summary:{run_id}"
|
||||
|
||||
# Check cache
|
||||
@@ -262,7 +333,7 @@ class ConcurrentStorage:
|
||||
lock_key = f"summary:{run_id}"
|
||||
async with await self._get_lock(lock_key):
|
||||
loop = asyncio.get_event_loop()
|
||||
summary = await loop.run_in_executor(None, self._base_storage.load_summary, run_id)
|
||||
summary = await loop.run_in_executor(None, self._load_summary_sync, run_id)
|
||||
|
||||
# Update cache
|
||||
if summary:
|
||||
@@ -271,11 +342,16 @@ class ConcurrentStorage:
|
||||
return summary
|
||||
|
||||
async def delete_run(self, run_id: str) -> bool:
|
||||
"""Delete a run from storage."""
|
||||
"""Delete a run from storage.
|
||||
|
||||
Raises:
|
||||
ValueError: If run_id contains path traversal characters.
|
||||
"""
|
||||
self._validate_key(run_id)
|
||||
lock_key = f"run:{run_id}"
|
||||
async with await self._get_lock(lock_key):
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(None, self._base_storage.delete_run, run_id)
|
||||
result = await loop.run_in_executor(None, self._delete_run_sync, run_id)
|
||||
|
||||
# Clear cache
|
||||
self._cache.pop(f"run:{run_id}", None)
|
||||
@@ -283,37 +359,10 @@ class ConcurrentStorage:
|
||||
|
||||
return result
|
||||
|
||||
# === QUERY OPERATIONS (Async, with Locking) ===
|
||||
|
||||
async def get_runs_by_goal(self, goal_id: str) -> list[str]:
|
||||
"""Get all run IDs for a goal."""
|
||||
async with await self._get_lock(f"index:by_goal:{goal_id}"):
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self._base_storage.get_runs_by_goal, goal_id)
|
||||
|
||||
async def get_runs_by_status(self, status: str | RunStatus) -> list[str]:
|
||||
"""Get all run IDs with a status."""
|
||||
if isinstance(status, RunStatus):
|
||||
status = status.value
|
||||
async with await self._get_lock(f"index:by_status:{status}"):
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self._base_storage.get_runs_by_status, status)
|
||||
|
||||
async def get_runs_by_node(self, node_id: str) -> list[str]:
|
||||
"""Get all run IDs that executed a node."""
|
||||
async with await self._get_lock(f"index:by_node:{node_id}"):
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self._base_storage.get_runs_by_node, node_id)
|
||||
|
||||
async def list_all_runs(self) -> list[str]:
|
||||
"""List all run IDs."""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self._base_storage.list_all_runs)
|
||||
|
||||
async def list_all_goals(self) -> list[str]:
|
||||
"""List all goal IDs that have runs."""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self._base_storage.list_all_goals)
|
||||
return await loop.run_in_executor(None, self._list_all_runs_sync)
|
||||
|
||||
# === BATCH OPERATIONS ===
|
||||
|
||||
@@ -411,10 +460,11 @@ class ConcurrentStorage:
|
||||
async def get_stats(self) -> dict:
|
||||
"""Get storage statistics."""
|
||||
loop = asyncio.get_event_loop()
|
||||
base_stats = await loop.run_in_executor(None, self._base_storage.get_stats)
|
||||
all_runs = await loop.run_in_executor(None, self._list_all_runs_sync)
|
||||
|
||||
return {
|
||||
**base_stats,
|
||||
"total_runs": len(all_runs),
|
||||
"storage_path": str(self.base_path),
|
||||
"cache": self.get_cache_stats(),
|
||||
"pending_writes": self._write_queue.qsize(),
|
||||
"running": self._running,
|
||||
@@ -423,10 +473,21 @@ class ConcurrentStorage:
|
||||
# === SYNC API (for backward compatibility) ===
|
||||
|
||||
def save_run_sync(self, run: Run) -> None:
|
||||
"""Synchronous save (uses base storage directly with lock)."""
|
||||
# Use threading lock for sync operations
|
||||
self._base_storage.save_run(run)
|
||||
"""Synchronous save — persists a run to disk immediately."""
|
||||
self._validate_key(run.id)
|
||||
# Invalidate summary cache since the run data is changing
|
||||
self._cache.pop(f"summary:{run.id}", None)
|
||||
|
||||
self._save_run_sync(run)
|
||||
|
||||
# Refresh run cache
|
||||
self._cache[f"run:{run.id}"] = CacheEntry(run, time.time())
|
||||
|
||||
def load_run_sync(self, run_id: str) -> Run | None:
|
||||
"""Synchronous load (uses base storage directly)."""
|
||||
return self._base_storage.load_run(run_id)
|
||||
"""Synchronous load.
|
||||
|
||||
Raises:
|
||||
ValueError: If run_id contains path traversal characters.
|
||||
"""
|
||||
self._validate_key(run_id)
|
||||
return self._load_run_sync(run_id)
|
||||
|
||||
@@ -62,8 +62,14 @@ class SessionStore:
|
||||
|
||||
Returns:
|
||||
Path to session directory
|
||||
|
||||
Raises:
|
||||
ValueError: If session_id resolves outside the sessions directory
|
||||
"""
|
||||
return self.sessions_dir / session_id
|
||||
resolved = (self.sessions_dir / session_id).resolve()
|
||||
if not resolved.is_relative_to(self.sessions_dir.resolve()):
|
||||
raise ValueError(f"Invalid session ID: {session_id}")
|
||||
return resolved
|
||||
|
||||
def get_state_path(self, session_id: str) -> Path:
|
||||
"""
|
||||
|
||||
@@ -73,7 +73,9 @@ class DebugTool:
|
||||
|
||||
Args:
|
||||
test_storage: Storage for test and result data
|
||||
runtime_storage: Optional FileStorage for Runtime data
|
||||
runtime_storage: Optional storage backend for Runtime data.
|
||||
Must expose a synchronous ``load_run_sync(run_id)`` method
|
||||
(e.g. ``ConcurrentStorage``).
|
||||
"""
|
||||
self.test_storage = test_storage
|
||||
self.runtime_storage = runtime_storage
|
||||
@@ -233,7 +235,9 @@ class DebugTool:
|
||||
return {}
|
||||
|
||||
try:
|
||||
run = self.runtime_storage.load_run(run_id)
|
||||
# Use the synchronous loader — _get_runtime_data is not async
|
||||
# and ConcurrentStorage.load_run() is a coroutine.
|
||||
run = self.runtime_storage.load_run_sync(run_id)
|
||||
if not run:
|
||||
return {"error": f"Run {run_id} not found"}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
File-based storage backend for test data.
|
||||
|
||||
Follows the same pattern as framework/storage/backend.py (FileStorage),
|
||||
Follows the same pattern as framework/storage/concurrent.py (ConcurrentStorage),
|
||||
storing tests as JSON files with indexes for efficient querying.
|
||||
"""
|
||||
|
||||
|
||||
@@ -118,6 +118,8 @@ class QueenPhaseState:
|
||||
|
||||
# Default skill operational protocols — appended to every phase prompt
|
||||
protocols_prompt: str = ""
|
||||
# Community skills catalog (XML) — appended after protocols
|
||||
skills_catalog_prompt: str = ""
|
||||
|
||||
def get_current_tools(self) -> list:
|
||||
"""Return tools for the current phase."""
|
||||
@@ -144,6 +146,8 @@ class QueenPhaseState:
|
||||
|
||||
memory = format_for_injection()
|
||||
parts = [base]
|
||||
if self.skills_catalog_prompt:
|
||||
parts.append(self.skills_catalog_prompt)
|
||||
if self.protocols_prompt:
|
||||
parts.append(self.protocols_prompt)
|
||||
if memory:
|
||||
|
||||
@@ -49,7 +49,7 @@ def recall_diary(query: str = "", days_back: int = 7) -> str:
|
||||
"""
|
||||
from datetime import date, timedelta
|
||||
|
||||
from framework.agents.queen.queen_memory import read_episodic_memory
|
||||
from framework.agents.queen.queen_memory import format_memory_date, read_episodic_memory
|
||||
|
||||
days_back = max(1, min(days_back, 30))
|
||||
today = date.today()
|
||||
@@ -70,7 +70,7 @@ def recall_diary(query: str = "", days_back: int = 7) -> str:
|
||||
if not matched:
|
||||
continue
|
||||
content = "### ".join(matched)
|
||||
label = d.strftime("%B %-d, %Y")
|
||||
label = format_memory_date(d)
|
||||
if d == today:
|
||||
label = f"Today — {label}"
|
||||
entry = f"## {label}\n\n{content}"
|
||||
|
||||
@@ -2,6 +2,7 @@ import { Routes, Route } from "react-router-dom";
|
||||
import Home from "./pages/home";
|
||||
import MyAgents from "./pages/my-agents";
|
||||
import Workspace from "./pages/workspace";
|
||||
import NotFound from "./pages/not-found";
|
||||
|
||||
function App() {
|
||||
return (
|
||||
@@ -9,6 +10,7 @@ function App() {
|
||||
<Route path="/" element={<Home />} />
|
||||
<Route path="/my-agents" element={<MyAgents />} />
|
||||
<Route path="/workspace" element={<Workspace />} />
|
||||
<Route path="*" element={<NotFound />} />
|
||||
</Routes>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -34,8 +34,8 @@ export const executionApi = {
|
||||
graph_id: graphId,
|
||||
}),
|
||||
|
||||
chat: (sessionId: string, message: string) =>
|
||||
api.post<ChatResult>(`/sessions/${sessionId}/chat`, { message }),
|
||||
chat: (sessionId: string, message: string, images?: { type: string; image_url: { url: string } }[]) =>
|
||||
api.post<ChatResult>(`/sessions/${sessionId}/chat`, { message, ...(images?.length ? { images } : {}) }),
|
||||
|
||||
/** Queue context for the queen without triggering an LLM response. */
|
||||
queenContext: (sessionId: string, message: string) =>
|
||||
|
||||
@@ -81,6 +81,10 @@ export const sessionsApi = {
|
||||
eventsHistory: (sessionId: string) =>
|
||||
api.get<{ events: AgentEvent[]; session_id: string }>(`/sessions/${sessionId}/events/history`),
|
||||
|
||||
/** Open the session's data folder in the OS file manager. */
|
||||
revealFolder: (sessionId: string) =>
|
||||
api.post<{ path: string }>(`/sessions/${sessionId}/reveal`),
|
||||
|
||||
/** List all queen sessions on disk — live + cold (post-restart). */
|
||||
history: () =>
|
||||
api.get<{ sessions: Array<{ session_id: string; cold: boolean; live: boolean; has_messages: boolean; created_at: number; agent_name?: string | null; agent_path?: string | null }> }>("/sessions/history"),
|
||||
|
||||
@@ -14,6 +14,8 @@ export interface LiveSession {
|
||||
intro_message?: string;
|
||||
/** Queen operating phase — "planning", "building", "staging", or "running" */
|
||||
queen_phase?: "planning" | "building" | "staging" | "running";
|
||||
/** Whether the queen's LLM supports image content in messages */
|
||||
queen_supports_images?: boolean;
|
||||
/** Present in 409 conflict responses when worker is still loading */
|
||||
loading?: boolean;
|
||||
}
|
||||
@@ -324,6 +326,7 @@ export type EventTypeName =
|
||||
| "node_retry"
|
||||
| "edge_traversed"
|
||||
| "context_compacted"
|
||||
| "context_usage_updated"
|
||||
| "webhook_received"
|
||||
| "custom"
|
||||
| "escalation_requested"
|
||||
|
||||
@@ -1,8 +1,32 @@
|
||||
import { memo, useState, useRef, useEffect } from "react";
|
||||
import { Send, Square, Crown, Cpu, Check, Loader2 } from "lucide-react";
|
||||
import { memo, useState, useRef, useEffect, useMemo } from "react";
|
||||
import {
|
||||
Send,
|
||||
Square,
|
||||
Crown,
|
||||
Cpu,
|
||||
Check,
|
||||
Loader2,
|
||||
Paperclip,
|
||||
X,
|
||||
} from "lucide-react";
|
||||
|
||||
export interface ImageContent {
|
||||
type: "image_url";
|
||||
image_url: { url: string };
|
||||
}
|
||||
|
||||
export interface ContextUsageEntry {
|
||||
usagePct: number;
|
||||
messageCount: number;
|
||||
estimatedTokens: number;
|
||||
maxTokens: number;
|
||||
}
|
||||
import MarkdownContent from "@/components/MarkdownContent";
|
||||
import QuestionWidget from "@/components/QuestionWidget";
|
||||
import MultiQuestionWidget from "@/components/MultiQuestionWidget";
|
||||
import ParallelSubagentBubble, {
|
||||
type SubagentGroup,
|
||||
} from "@/components/ParallelSubagentBubble";
|
||||
|
||||
export interface ChatMessage {
|
||||
id: string;
|
||||
@@ -10,7 +34,13 @@ export interface ChatMessage {
|
||||
agentColor: string;
|
||||
content: string;
|
||||
timestamp: string;
|
||||
type?: "system" | "agent" | "user" | "tool_status" | "worker_input_request" | "run_divider";
|
||||
type?:
|
||||
| "system"
|
||||
| "agent"
|
||||
| "user"
|
||||
| "tool_status"
|
||||
| "worker_input_request"
|
||||
| "run_divider";
|
||||
role?: "queen" | "worker";
|
||||
/** Which worker thread this message belongs to (worker agent name) */
|
||||
thread?: string;
|
||||
@@ -18,11 +48,17 @@ export interface ChatMessage {
|
||||
createdAt?: number;
|
||||
/** Queen phase active when this message was created */
|
||||
phase?: "planning" | "building" | "staging" | "running";
|
||||
/** Images attached to a user message */
|
||||
images?: ImageContent[];
|
||||
/** Backend node_id that produced this message — used for subagent grouping */
|
||||
nodeId?: string;
|
||||
/** Backend execution_id for this message */
|
||||
executionId?: string;
|
||||
}
|
||||
|
||||
interface ChatPanelProps {
|
||||
messages: ChatMessage[];
|
||||
onSend: (message: string, thread: string) => void;
|
||||
onSend: (message: string, thread: string, images?: ImageContent[]) => void;
|
||||
isWaiting?: boolean;
|
||||
/** When true a worker is thinking (not yet streaming) */
|
||||
isWorkerWaiting?: boolean;
|
||||
@@ -31,6 +67,8 @@ interface ChatPanelProps {
|
||||
activeThread: string;
|
||||
/** When true, the input is disabled (e.g. during loading) */
|
||||
disabled?: boolean;
|
||||
/** When false, the image attach button is hidden (model lacks vision support) */
|
||||
supportsImages?: boolean;
|
||||
/** Called when user clicks the stop button to cancel the queen's current turn */
|
||||
onCancel?: () => void;
|
||||
/** Pending question from ask_user — replaces textarea when present */
|
||||
@@ -38,7 +76,9 @@ interface ChatPanelProps {
|
||||
/** Options for the pending question */
|
||||
pendingOptions?: string[] | null;
|
||||
/** Multiple questions from ask_user_multiple */
|
||||
pendingQuestions?: { id: string; prompt: string; options?: string[] }[] | null;
|
||||
pendingQuestions?:
|
||||
| { id: string; prompt: string; options?: string[] }[]
|
||||
| null;
|
||||
/** Called when user submits an answer to the pending question */
|
||||
onQuestionSubmit?: (answer: string, isOther: boolean) => void;
|
||||
/** Called when user submits answers to multiple questions */
|
||||
@@ -47,6 +87,8 @@ interface ChatPanelProps {
|
||||
onQuestionDismiss?: () => void;
|
||||
/** Queen operating phase — shown as a tag on queen messages */
|
||||
queenPhase?: "planning" | "building" | "staging" | "running";
|
||||
/** Context window usage for queen and workers */
|
||||
contextUsage?: Record<string, ContextUsageEntry>;
|
||||
}
|
||||
|
||||
const queenColor = "hsl(45,95%,58%)";
|
||||
@@ -72,7 +114,8 @@ const TOOL_HEX = [
|
||||
|
||||
function toolHex(name: string): string {
|
||||
let hash = 0;
|
||||
for (let i = 0; i < name.length; i++) hash = (hash * 31 + name.charCodeAt(i)) | 0;
|
||||
for (let i = 0; i < name.length; i++)
|
||||
hash = (hash * 31 + name.charCodeAt(i)) | 0;
|
||||
return TOOL_HEX[Math.abs(hash) % TOOL_HEX.length];
|
||||
}
|
||||
|
||||
@@ -120,12 +163,18 @@ function ToolActivityRow({ content }: { content: string }) {
|
||||
<span
|
||||
key={`run-${p.name}`}
|
||||
className="inline-flex items-center gap-1 text-[11px] px-2.5 py-0.5 rounded-full"
|
||||
style={{ color: hex, backgroundColor: `${hex}18`, border: `1px solid ${hex}35` }}
|
||||
style={{
|
||||
color: hex,
|
||||
backgroundColor: `${hex}18`,
|
||||
border: `1px solid ${hex}35`,
|
||||
}}
|
||||
>
|
||||
<Loader2 className="w-2.5 h-2.5 animate-spin" />
|
||||
{p.name}
|
||||
{p.count > 1 && (
|
||||
<span className="text-[10px] font-medium opacity-70">×{p.count}</span>
|
||||
<span className="text-[10px] font-medium opacity-70">
|
||||
×{p.count}
|
||||
</span>
|
||||
)}
|
||||
</span>
|
||||
);
|
||||
@@ -136,7 +185,11 @@ function ToolActivityRow({ content }: { content: string }) {
|
||||
<span
|
||||
key={`done-${p.name}`}
|
||||
className="inline-flex items-center gap-1 text-[11px] px-2.5 py-0.5 rounded-full"
|
||||
style={{ color: hex, backgroundColor: `${hex}18`, border: `1px solid ${hex}35` }}
|
||||
style={{
|
||||
color: hex,
|
||||
backgroundColor: `${hex}18`,
|
||||
border: `1px solid ${hex}35`,
|
||||
}}
|
||||
>
|
||||
<Check className="w-2.5 h-2.5" />
|
||||
{p.name}
|
||||
@@ -151,109 +204,249 @@ function ToolActivityRow({ content }: { content: string }) {
|
||||
);
|
||||
}
|
||||
|
||||
const MessageBubble = memo(function MessageBubble({ msg, queenPhase }: { msg: ChatMessage; queenPhase?: "planning" | "building" | "staging" | "running" }) {
|
||||
const isUser = msg.type === "user";
|
||||
const isQueen = msg.role === "queen";
|
||||
const color = getColor(msg.agent, msg.role);
|
||||
const MessageBubble = memo(
|
||||
function MessageBubble({
|
||||
msg,
|
||||
queenPhase,
|
||||
}: {
|
||||
msg: ChatMessage;
|
||||
queenPhase?: "planning" | "building" | "staging" | "running";
|
||||
}) {
|
||||
const isUser = msg.type === "user";
|
||||
const isQueen = msg.role === "queen";
|
||||
const color = getColor(msg.agent, msg.role);
|
||||
|
||||
if (msg.type === "run_divider") {
|
||||
return (
|
||||
<div className="flex items-center gap-3 py-2 my-1">
|
||||
<div className="flex-1 h-px bg-border/60" />
|
||||
<span className="text-[10px] text-muted-foreground font-medium uppercase tracking-wider">
|
||||
{msg.content}
|
||||
</span>
|
||||
<div className="flex-1 h-px bg-border/60" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (msg.type === "system") {
|
||||
return (
|
||||
<div className="flex justify-center py-1">
|
||||
<span className="text-[11px] text-muted-foreground bg-muted/60 px-3 py-1.5 rounded-full">
|
||||
{msg.content}
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (msg.type === "tool_status") {
|
||||
return <ToolActivityRow content={msg.content} />;
|
||||
}
|
||||
|
||||
if (isUser) {
|
||||
return (
|
||||
<div className="flex justify-end">
|
||||
<div className="max-w-[75%] bg-primary text-primary-foreground text-sm leading-relaxed rounded-2xl rounded-br-md px-4 py-3">
|
||||
<p className="whitespace-pre-wrap break-words">{msg.content}</p>
|
||||
if (msg.type === "run_divider") {
|
||||
return (
|
||||
<div className="flex items-center gap-3 py-2 my-1">
|
||||
<div className="flex-1 h-px bg-border/60" />
|
||||
<span className="text-[10px] text-muted-foreground font-medium uppercase tracking-wider">
|
||||
{msg.content}
|
||||
</span>
|
||||
<div className="flex-1 h-px bg-border/60" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex gap-3">
|
||||
<div
|
||||
className={`flex-shrink-0 ${isQueen ? "w-9 h-9" : "w-7 h-7"} rounded-xl flex items-center justify-center`}
|
||||
style={{
|
||||
backgroundColor: `${color}18`,
|
||||
border: `1.5px solid ${color}35`,
|
||||
boxShadow: isQueen ? `0 0 12px ${color}20` : undefined,
|
||||
}}
|
||||
>
|
||||
{isQueen ? (
|
||||
<Crown className="w-4 h-4" style={{ color }} />
|
||||
) : (
|
||||
<Cpu className="w-3.5 h-3.5" style={{ color }} />
|
||||
)}
|
||||
</div>
|
||||
<div className={`flex-1 min-w-0 ${isQueen ? "max-w-[85%]" : "max-w-[75%]"}`}>
|
||||
<div className="flex items-center gap-2 mb-1">
|
||||
<span className={`font-medium ${isQueen ? "text-sm" : "text-xs"}`} style={{ color }}>
|
||||
{msg.agent}
|
||||
</span>
|
||||
<span
|
||||
className={`text-[10px] font-medium px-1.5 py-0.5 rounded-md ${
|
||||
isQueen ? "bg-primary/15 text-primary" : "bg-muted text-muted-foreground"
|
||||
}`}
|
||||
>
|
||||
{isQueen
|
||||
? ((msg.phase ?? queenPhase) === "running"
|
||||
? "running"
|
||||
: (msg.phase ?? queenPhase) === "staging"
|
||||
? "staging"
|
||||
: (msg.phase ?? queenPhase) === "planning"
|
||||
? "planning"
|
||||
: "building")
|
||||
: "Worker"}
|
||||
if (msg.type === "system") {
|
||||
return (
|
||||
<div className="flex justify-center py-1">
|
||||
<span className="text-[11px] text-muted-foreground bg-muted/60 px-3 py-1.5 rounded-full">
|
||||
{msg.content}
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (msg.type === "tool_status") {
|
||||
return <ToolActivityRow content={msg.content} />;
|
||||
}
|
||||
|
||||
if (isUser) {
|
||||
return (
|
||||
<div className="flex justify-end">
|
||||
<div className="max-w-[75%] bg-primary text-primary-foreground text-sm leading-relaxed rounded-2xl rounded-br-md px-4 py-3">
|
||||
{msg.images && msg.images.length > 0 && (
|
||||
<div className="flex flex-wrap gap-2 mb-2">
|
||||
{msg.images.map((img, i) => (
|
||||
<img
|
||||
key={i}
|
||||
src={img.image_url.url}
|
||||
alt={`attachment ${i + 1}`}
|
||||
className="max-h-48 max-w-full rounded-lg object-contain"
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
{msg.content && (
|
||||
<p className="whitespace-pre-wrap break-words">{msg.content}</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex gap-3">
|
||||
<div
|
||||
className={`flex-shrink-0 ${isQueen ? "w-9 h-9" : "w-7 h-7"} rounded-xl flex items-center justify-center`}
|
||||
style={{
|
||||
backgroundColor: `${color}18`,
|
||||
border: `1.5px solid ${color}35`,
|
||||
boxShadow: isQueen ? `0 0 12px ${color}20` : undefined,
|
||||
}}
|
||||
>
|
||||
{isQueen ? (
|
||||
<Crown className="w-4 h-4" style={{ color }} />
|
||||
) : (
|
||||
<Cpu className="w-3.5 h-3.5" style={{ color }} />
|
||||
)}
|
||||
</div>
|
||||
<div
|
||||
className={`text-sm leading-relaxed rounded-2xl rounded-tl-md px-4 py-3 ${
|
||||
isQueen ? "border border-primary/20 bg-primary/5" : "bg-muted/60"
|
||||
}`}
|
||||
className={`flex-1 min-w-0 ${isQueen ? "max-w-[85%]" : "max-w-[75%]"}`}
|
||||
>
|
||||
<MarkdownContent content={msg.content} />
|
||||
<div className="flex items-center gap-2 mb-1">
|
||||
<span
|
||||
className={`font-medium ${isQueen ? "text-sm" : "text-xs"}`}
|
||||
style={{ color }}
|
||||
>
|
||||
{msg.agent}
|
||||
</span>
|
||||
<span
|
||||
className={`text-[10px] font-medium px-1.5 py-0.5 rounded-md ${
|
||||
isQueen
|
||||
? "bg-primary/15 text-primary"
|
||||
: "bg-muted text-muted-foreground"
|
||||
}`}
|
||||
>
|
||||
{isQueen
|
||||
? (msg.phase ?? queenPhase) === "running"
|
||||
? "running"
|
||||
: (msg.phase ?? queenPhase) === "staging"
|
||||
? "staging"
|
||||
: (msg.phase ?? queenPhase) === "planning"
|
||||
? "planning"
|
||||
: "building"
|
||||
: "Worker"}
|
||||
</span>
|
||||
</div>
|
||||
<div
|
||||
className={`text-sm leading-relaxed rounded-2xl rounded-tl-md px-4 py-3 ${
|
||||
isQueen ? "border border-primary/20 bg-primary/5" : "bg-muted/60"
|
||||
}`}
|
||||
>
|
||||
<MarkdownContent content={msg.content} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}, (prev, next) => prev.msg.id === next.msg.id && prev.msg.content === next.msg.content && prev.msg.phase === next.msg.phase && prev.queenPhase === next.queenPhase);
|
||||
);
|
||||
},
|
||||
(prev, next) =>
|
||||
prev.msg.id === next.msg.id &&
|
||||
prev.msg.content === next.msg.content &&
|
||||
prev.msg.phase === next.msg.phase &&
|
||||
prev.queenPhase === next.queenPhase,
|
||||
);
|
||||
|
||||
export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting, isBusy, activeThread, disabled, onCancel, pendingQuestion, pendingOptions, pendingQuestions, onQuestionSubmit, onMultiQuestionSubmit, onQuestionDismiss, queenPhase }: ChatPanelProps) {
|
||||
export default function ChatPanel({
|
||||
messages,
|
||||
onSend,
|
||||
isWaiting,
|
||||
isWorkerWaiting,
|
||||
isBusy,
|
||||
activeThread,
|
||||
disabled,
|
||||
onCancel,
|
||||
pendingQuestion,
|
||||
pendingOptions,
|
||||
pendingQuestions,
|
||||
onQuestionSubmit,
|
||||
onMultiQuestionSubmit,
|
||||
onQuestionDismiss,
|
||||
queenPhase,
|
||||
contextUsage,
|
||||
supportsImages = true,
|
||||
}: ChatPanelProps) {
|
||||
const [input, setInput] = useState("");
|
||||
const [pendingImages, setPendingImages] = useState<ImageContent[]>([]);
|
||||
const [readMap, setReadMap] = useState<Record<string, number>>({});
|
||||
const bottomRef = useRef<HTMLDivElement>(null);
|
||||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
const stickToBottom = useRef(true);
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const threadMessages = messages.filter((m) => {
|
||||
if (m.type === "system" && !m.thread) return false;
|
||||
return m.thread === activeThread;
|
||||
if (m.thread !== activeThread) return false;
|
||||
// Hide queen messages whose content is whitespace-only — these are
|
||||
// tool-use-only turns that have no visible text. During live operation
|
||||
// tool pills provide context, but on resume the pills are gone so
|
||||
// the empty bubble is meaningless.
|
||||
if (m.role === "queen" && !m.type && (!m.content || !m.content.trim()))
|
||||
return false;
|
||||
return true;
|
||||
});
|
||||
|
||||
// Group subagent messages into parallel bubbles.
|
||||
// A subagent message has nodeId containing ":subagent:".
|
||||
// The run only ends on hard boundaries (user messages, run_dividers)
|
||||
// so interleaved queen/tool/system messages don't fragment the bubble.
|
||||
type RenderItem =
|
||||
| { kind: "message"; msg: ChatMessage }
|
||||
| { kind: "parallel"; groupId: string; groups: SubagentGroup[] };
|
||||
|
||||
const renderItems = useMemo<RenderItem[]>(() => {
|
||||
const items: RenderItem[] = [];
|
||||
let i = 0;
|
||||
while (i < threadMessages.length) {
|
||||
const msg = threadMessages[i];
|
||||
const isSubagent = msg.nodeId?.includes(":subagent:");
|
||||
if (!isSubagent) {
|
||||
items.push({ kind: "message", msg });
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Start a subagent run. Collect all subagent messages, allowing
|
||||
// non-subagent messages in between (they render as normal items
|
||||
// before the bubble). Only break on hard boundaries.
|
||||
const subagentMsgs: ChatMessage[] = [];
|
||||
const interleaved: { idx: number; msg: ChatMessage }[] = [];
|
||||
const firstId = msg.id;
|
||||
|
||||
while (i < threadMessages.length) {
|
||||
const m = threadMessages[i];
|
||||
const isSa = m.nodeId?.includes(":subagent:");
|
||||
|
||||
if (isSa) {
|
||||
subagentMsgs.push(m);
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Hard boundary — stop the run
|
||||
if (m.type === "user" || m.type === "run_divider") break;
|
||||
|
||||
// Worker message from a non-subagent node means the graph has
|
||||
// moved on to the next stage. Close the bubble even if some
|
||||
// subagents are still streaming in the background.
|
||||
if (m.role === "worker" && m.nodeId && !m.nodeId.includes(":subagent:"))
|
||||
break;
|
||||
|
||||
// Soft interruption (queen output, system, tool_status without
|
||||
// nodeId) — render it normally but keep the subagent run going
|
||||
interleaved.push({ idx: items.length + interleaved.length, msg: m });
|
||||
i++;
|
||||
}
|
||||
|
||||
// Emit interleaved messages first (before the bubble)
|
||||
for (const { msg: im } of interleaved) {
|
||||
items.push({ kind: "message", msg: im });
|
||||
}
|
||||
|
||||
// Build the single parallel bubble from all collected subagent msgs
|
||||
if (subagentMsgs.length > 0) {
|
||||
const byNode = new Map<string, ChatMessage[]>();
|
||||
for (const m of subagentMsgs) {
|
||||
const nid = m.nodeId!;
|
||||
if (!byNode.has(nid)) byNode.set(nid, []);
|
||||
byNode.get(nid)!.push(m);
|
||||
}
|
||||
const groups: SubagentGroup[] = [];
|
||||
for (const [nodeId, msgs] of byNode) {
|
||||
groups.push({
|
||||
nodeId,
|
||||
messages: msgs,
|
||||
contextUsage: contextUsage?.[nodeId],
|
||||
});
|
||||
}
|
||||
items.push({ kind: "parallel", groupId: `par-${firstId}`, groups });
|
||||
}
|
||||
}
|
||||
return items;
|
||||
}, [threadMessages, contextUsage]);
|
||||
|
||||
// Mark current thread as read
|
||||
useEffect(() => {
|
||||
const count = messages.filter((m) => m.thread === activeThread).length;
|
||||
@@ -284,26 +477,64 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
|
||||
|
||||
const handleSubmit = (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
if (!input.trim()) return;
|
||||
onSend(input.trim(), activeThread);
|
||||
if (!input.trim() && pendingImages.length === 0) return;
|
||||
onSend(
|
||||
input.trim(),
|
||||
activeThread,
|
||||
pendingImages.length > 0 ? pendingImages : undefined,
|
||||
);
|
||||
setInput("");
|
||||
setPendingImages([]);
|
||||
if (textareaRef.current) textareaRef.current.style.height = "auto";
|
||||
};
|
||||
|
||||
const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const files = Array.from(e.target.files ?? []);
|
||||
if (files.length === 0) return;
|
||||
files.forEach((file) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (ev) => {
|
||||
const url = ev.target?.result as string;
|
||||
setPendingImages((prev) => [
|
||||
...prev,
|
||||
{ type: "image_url", image_url: { url } },
|
||||
]);
|
||||
};
|
||||
reader.readAsDataURL(file);
|
||||
});
|
||||
// Reset so the same file can be re-selected
|
||||
e.target.value = "";
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full min-w-0">
|
||||
{/* Compact sub-header */}
|
||||
<div className="px-5 pt-4 pb-2 flex items-center gap-2">
|
||||
<p className="text-[11px] text-muted-foreground font-medium uppercase tracking-wider">Conversation</p>
|
||||
<p className="text-[11px] text-muted-foreground font-medium uppercase tracking-wider">
|
||||
Conversation
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Messages */}
|
||||
<div ref={scrollRef} onScroll={handleScroll} className="flex-1 overflow-auto px-5 py-4 space-y-3">
|
||||
{threadMessages.map((msg) => (
|
||||
<div key={msg.id}>
|
||||
<MessageBubble msg={msg} queenPhase={queenPhase} />
|
||||
</div>
|
||||
))}
|
||||
<div
|
||||
ref={scrollRef}
|
||||
onScroll={handleScroll}
|
||||
className="flex-1 overflow-auto px-5 py-4 space-y-3"
|
||||
>
|
||||
{renderItems.map((item) =>
|
||||
item.kind === "parallel" ? (
|
||||
<div key={item.groupId}>
|
||||
<ParallelSubagentBubble
|
||||
groupId={item.groupId}
|
||||
groups={item.groups}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<div key={item.msg.id}>
|
||||
<MessageBubble msg={item.msg} queenPhase={queenPhase} />
|
||||
</div>
|
||||
),
|
||||
)}
|
||||
|
||||
{/* Show typing indicator while waiting for first queen response (disabled + empty chat) */}
|
||||
{(isWaiting || (disabled && threadMessages.length === 0)) && (
|
||||
@@ -320,9 +551,18 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
|
||||
</div>
|
||||
<div className="border border-primary/20 bg-primary/5 rounded-2xl rounded-tl-md px-4 py-3">
|
||||
<div className="flex gap-1.5">
|
||||
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "0ms" }} />
|
||||
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "150ms" }} />
|
||||
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "300ms" }} />
|
||||
<span
|
||||
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
|
||||
style={{ animationDelay: "0ms" }}
|
||||
/>
|
||||
<span
|
||||
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
|
||||
style={{ animationDelay: "150ms" }}
|
||||
/>
|
||||
<span
|
||||
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
|
||||
style={{ animationDelay: "300ms" }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -340,9 +580,18 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
|
||||
</div>
|
||||
<div className="bg-muted/60 rounded-2xl rounded-tl-md px-4 py-3">
|
||||
<div className="flex gap-1.5">
|
||||
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "0ms" }} />
|
||||
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "150ms" }} />
|
||||
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "300ms" }} />
|
||||
<span
|
||||
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
|
||||
style={{ animationDelay: "0ms" }}
|
||||
/>
|
||||
<span
|
||||
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
|
||||
style={{ animationDelay: "150ms" }}
|
||||
/>
|
||||
<span
|
||||
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
|
||||
style={{ animationDelay: "300ms" }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -350,8 +599,99 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
|
||||
<div ref={bottomRef} />
|
||||
</div>
|
||||
|
||||
{/* Context window usage bar — sits between messages and input */}
|
||||
{(() => {
|
||||
if (!contextUsage) return null;
|
||||
const queenUsage = contextUsage["__queen__"];
|
||||
const workerEntries = Object.entries(contextUsage).filter(
|
||||
([k]) => k !== "__queen__",
|
||||
);
|
||||
const workerUsage =
|
||||
workerEntries.length > 0
|
||||
? workerEntries.reduce(
|
||||
(best, [, v]) => (v.usagePct > best.usagePct ? v : best),
|
||||
workerEntries[0][1],
|
||||
)
|
||||
: undefined;
|
||||
if (!queenUsage && !workerUsage) return null;
|
||||
return (
|
||||
<div className="flex items-center gap-3 mx-4 px-3 py-1 rounded-lg bg-muted/30 border border-border/20 group/ctx flex-shrink-0">
|
||||
{queenUsage && (
|
||||
<div
|
||||
className="flex items-center gap-2 flex-1 min-w-0"
|
||||
title={`Queen: ${(queenUsage.estimatedTokens / 1000).toFixed(1)}k / ${(queenUsage.maxTokens / 1000).toFixed(0)}k tokens \u00b7 ${queenUsage.messageCount} messages`}
|
||||
>
|
||||
<Crown
|
||||
className="w-3 h-3 flex-shrink-0"
|
||||
style={{ color: "hsl(45,95%,58%)" }}
|
||||
/>
|
||||
<div className="flex-1 h-1.5 rounded-full bg-muted/50 overflow-hidden min-w-[60px]">
|
||||
<div
|
||||
className="h-full rounded-full transition-all duration-500 ease-out"
|
||||
style={{
|
||||
width: `${Math.min(queenUsage.usagePct, 100)}%`,
|
||||
backgroundColor:
|
||||
queenUsage.usagePct >= 90
|
||||
? "hsl(0,65%,55%)"
|
||||
: queenUsage.usagePct >= 70
|
||||
? "hsl(35,90%,55%)"
|
||||
: "hsl(45,95%,58%)",
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<span className="text-[10px] text-muted-foreground/70 flex-shrink-0 tabular-nums">
|
||||
<span className="group-hover/ctx:hidden">
|
||||
{queenUsage.usagePct}%
|
||||
</span>
|
||||
<span className="hidden group-hover/ctx:inline">
|
||||
{(queenUsage.estimatedTokens / 1000).toFixed(1)}k /{" "}
|
||||
{(queenUsage.maxTokens / 1000).toFixed(0)}k
|
||||
</span>
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
{workerUsage && (
|
||||
<div
|
||||
className="flex items-center gap-2 flex-1 min-w-0"
|
||||
title={`Worker: ${(workerUsage.estimatedTokens / 1000).toFixed(1)}k / ${(workerUsage.maxTokens / 1000).toFixed(0)}k tokens \u00b7 ${workerUsage.messageCount} messages`}
|
||||
>
|
||||
<Cpu
|
||||
className="w-3 h-3 flex-shrink-0"
|
||||
style={{ color: "hsl(220,60%,55%)" }}
|
||||
/>
|
||||
<div className="flex-1 h-1.5 rounded-full bg-muted/50 overflow-hidden min-w-[60px]">
|
||||
<div
|
||||
className="h-full rounded-full transition-all duration-500 ease-out"
|
||||
style={{
|
||||
width: `${Math.min(workerUsage.usagePct, 100)}%`,
|
||||
backgroundColor:
|
||||
workerUsage.usagePct >= 90
|
||||
? "hsl(0,65%,55%)"
|
||||
: workerUsage.usagePct >= 70
|
||||
? "hsl(35,90%,55%)"
|
||||
: "hsl(220,60%,55%)",
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<span className="text-[10px] text-muted-foreground/70 flex-shrink-0 tabular-nums">
|
||||
<span className="group-hover/ctx:hidden">
|
||||
{workerUsage.usagePct}%
|
||||
</span>
|
||||
<span className="hidden group-hover/ctx:inline">
|
||||
{(workerUsage.estimatedTokens / 1000).toFixed(1)}k /{" "}
|
||||
{(workerUsage.maxTokens / 1000).toFixed(0)}k
|
||||
</span>
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})()}
|
||||
|
||||
{/* Input area — question widget replaces textarea when a question is pending */}
|
||||
{pendingQuestions && pendingQuestions.length >= 2 && onMultiQuestionSubmit ? (
|
||||
{pendingQuestions &&
|
||||
pendingQuestions.length >= 2 &&
|
||||
onMultiQuestionSubmit ? (
|
||||
<MultiQuestionWidget
|
||||
questions={pendingQuestions}
|
||||
onSubmit={onMultiQuestionSubmit}
|
||||
@@ -366,7 +706,47 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
|
||||
/>
|
||||
) : (
|
||||
<form onSubmit={handleSubmit} className="p-4">
|
||||
{/* Image preview strip */}
|
||||
{pendingImages.length > 0 && (
|
||||
<div className="flex flex-wrap gap-2 mb-2 px-1">
|
||||
{pendingImages.map((img, i) => (
|
||||
<div key={i} className="relative group">
|
||||
<img
|
||||
src={img.image_url.url}
|
||||
alt={`preview ${i + 1}`}
|
||||
className="h-16 w-16 object-cover rounded-lg border border-border"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setPendingImages((prev) => prev.filter((_, j) => j !== i))
|
||||
}
|
||||
className="absolute -top-1.5 -right-1.5 w-4 h-4 rounded-full bg-destructive text-destructive-foreground flex items-center justify-center opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
>
|
||||
<X className="w-2.5 h-2.5" />
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
<div className="flex items-center gap-3 bg-muted/40 rounded-xl px-4 py-2.5 border border-border focus-within:border-primary/40 transition-colors">
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
accept="image/*"
|
||||
multiple
|
||||
className="hidden"
|
||||
onChange={handleFileChange}
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
disabled={disabled || !supportsImages}
|
||||
onClick={() => supportsImages && fileInputRef.current?.click()}
|
||||
className="flex-shrink-0 p-1 rounded-md text-muted-foreground hover:text-foreground disabled:opacity-30 transition-colors"
|
||||
title={supportsImages ? "Attach image" : "Image not supported by the current model"}
|
||||
>
|
||||
<Paperclip className="w-4 h-4" />
|
||||
</button>
|
||||
<textarea
|
||||
ref={textareaRef}
|
||||
rows={1}
|
||||
@@ -383,7 +763,9 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
|
||||
handleSubmit(e);
|
||||
}
|
||||
}}
|
||||
placeholder={disabled ? "Connecting to agent..." : "Message Queen Bee..."}
|
||||
placeholder={
|
||||
disabled ? "Connecting to agent..." : "Message Queen Bee..."
|
||||
}
|
||||
disabled={disabled}
|
||||
className="flex-1 bg-transparent text-sm text-foreground outline-none placeholder:text-muted-foreground disabled:opacity-50 disabled:cursor-not-allowed resize-none overflow-y-auto"
|
||||
/>
|
||||
@@ -398,7 +780,9 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
|
||||
) : (
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!input.trim() || disabled}
|
||||
disabled={
|
||||
(!input.trim() && pendingImages.length === 0) || disabled
|
||||
}
|
||||
className="p-2 rounded-lg bg-primary text-primary-foreground disabled:opacity-30 hover:opacity-90 transition-opacity"
|
||||
>
|
||||
<Send className="w-4 h-4" />
|
||||
|
||||
@@ -28,6 +28,13 @@ export interface SubagentReport {
|
||||
status?: "running" | "complete" | "error";
|
||||
}
|
||||
|
||||
interface ContextUsage {
|
||||
usagePct: number;
|
||||
messageCount: number;
|
||||
estimatedTokens: number;
|
||||
maxTokens: number;
|
||||
}
|
||||
|
||||
interface NodeDetailPanelProps {
|
||||
node: GraphNode | null;
|
||||
nodeSpec?: NodeSpec | null;
|
||||
@@ -38,6 +45,7 @@ interface NodeDetailPanelProps {
|
||||
workerSessionId?: string | null;
|
||||
nodeLogs?: string[];
|
||||
actionPlan?: string;
|
||||
contextUsage?: ContextUsage;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
@@ -309,7 +317,7 @@ const tabs: { id: Tab; label: string; Icon: React.FC<{ className?: string }> }[]
|
||||
{ id: "subagents", label: "Subagents", Icon: ({ className }) => <Bot className={className} /> },
|
||||
];
|
||||
|
||||
export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagentReports, sessionId, graphId, workerSessionId, nodeLogs, actionPlan, onClose }: NodeDetailPanelProps) {
|
||||
export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagentReports, sessionId, graphId, workerSessionId, nodeLogs, actionPlan, contextUsage, onClose }: NodeDetailPanelProps) {
|
||||
const [activeTab, setActiveTab] = useState<Tab>("overview");
|
||||
const [realTools, setRealTools] = useState<ToolInfo[] | null>(null);
|
||||
const [realCriteria, setRealCriteria] = useState<NodeCriteria | null>(null);
|
||||
@@ -389,6 +397,43 @@ export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagent
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Context window usage */}
|
||||
{contextUsage && (
|
||||
<div className="px-4 py-2 border-b border-border/20 flex-shrink-0">
|
||||
<div className="flex items-center gap-2 mb-1">
|
||||
<span className="text-[10px] text-muted-foreground font-medium">Context</span>
|
||||
<span className="text-[10px] text-muted-foreground/70 ml-auto">
|
||||
{(contextUsage.estimatedTokens / 1000).toFixed(1)}k / {(contextUsage.maxTokens / 1000).toFixed(0)}k tokens
|
||||
</span>
|
||||
</div>
|
||||
<div className="w-full h-1.5 rounded-full bg-muted/50 overflow-hidden">
|
||||
<div
|
||||
className="h-full rounded-full transition-all duration-500 ease-out"
|
||||
style={{
|
||||
width: `${Math.min(contextUsage.usagePct, 100)}%`,
|
||||
backgroundColor: contextUsage.usagePct >= 90
|
||||
? "hsl(0,65%,55%)"
|
||||
: contextUsage.usagePct >= 70
|
||||
? "hsl(35,90%,55%)"
|
||||
: "hsl(45,95%,58%)",
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center gap-2 mt-1">
|
||||
<span className="text-[10px] text-muted-foreground/60">{contextUsage.messageCount} messages</span>
|
||||
<span className="text-[10px] font-medium ml-auto" style={{
|
||||
color: contextUsage.usagePct >= 90
|
||||
? "hsl(0,65%,55%)"
|
||||
: contextUsage.usagePct >= 70
|
||||
? "hsl(35,90%,55%)"
|
||||
: "hsl(45,95%,58%)",
|
||||
}}>
|
||||
{contextUsage.usagePct}%
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Tab bar */}
|
||||
<div className="flex border-b border-border/30 flex-shrink-0 px-2 pt-1 overflow-x-auto scrollbar-hide">
|
||||
{tabs.filter(t => t.id !== "subagents" || (nodeSpec?.sub_agents && nodeSpec.sub_agents.length > 0)).map(tab => (
|
||||
|
||||
@@ -0,0 +1,413 @@
|
||||
import { memo, useState, useRef, useEffect } from "react";
|
||||
import { ChevronDown, ChevronUp, Cpu } from "lucide-react";
|
||||
import type { ChatMessage, ContextUsageEntry } from "@/components/ChatPanel";
|
||||
import MarkdownContent from "@/components/MarkdownContent";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const workerColor = "hsl(220,60%,55%)";
|
||||
|
||||
const SUBAGENT_COLORS = [
|
||||
"hsl(220,60%,55%)",
|
||||
"hsl(260,50%,55%)",
|
||||
"hsl(180,50%,45%)",
|
||||
"hsl(30,70%,50%)",
|
||||
"hsl(340,55%,50%)",
|
||||
"hsl(150,45%,45%)",
|
||||
"hsl(45,80%,50%)",
|
||||
"hsl(290,45%,55%)",
|
||||
];
|
||||
|
||||
function colorForIndex(i: number): string {
|
||||
return SUBAGENT_COLORS[i % SUBAGENT_COLORS.length];
|
||||
}
|
||||
|
||||
function subagentLabel(nodeId: string): string {
|
||||
const parts = nodeId.split(":subagent:");
|
||||
const raw = parts.length >= 2 ? parts[1] : nodeId;
|
||||
return raw
|
||||
.replace(/:\d+$/, "") // strip instance suffix like ":3"
|
||||
.replace(/[_-]/g, " ")
|
||||
.replace(/\b\w/g, (c) => c.toUpperCase())
|
||||
.trim();
|
||||
}
|
||||
|
||||
function last<T>(arr: T[]): T | undefined {
|
||||
return arr[arr.length - 1];
|
||||
}
|
||||
|
||||
export interface SubagentGroup {
|
||||
nodeId: string;
|
||||
messages: ChatMessage[];
|
||||
contextUsage?: ContextUsageEntry;
|
||||
}
|
||||
|
||||
interface ParallelSubagentBubbleProps {
|
||||
groups: SubagentGroup[];
|
||||
groupId: string;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Thermometer — vertical context gauge on right edge of each pane
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tool overlay — shown when a tool_status message is active (not all done)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function ToolOverlay({
|
||||
toolName,
|
||||
color,
|
||||
visible,
|
||||
}: {
|
||||
toolName: string;
|
||||
color: string;
|
||||
visible: boolean;
|
||||
}) {
|
||||
return (
|
||||
<div
|
||||
className="absolute inset-0 top-[22px] flex items-center justify-center transition-opacity duration-200 z-10"
|
||||
style={{
|
||||
background: "rgba(8,8,14,0.82)",
|
||||
opacity: visible ? 1 : 0,
|
||||
pointerEvents: visible ? "auto" : "none",
|
||||
}}
|
||||
>
|
||||
<div className="text-center px-3 py-2 rounded-md border" style={{ borderColor: `${color}40` }}>
|
||||
<div className="text-[10px] font-medium" style={{ color }}>
|
||||
{toolName}
|
||||
</div>
|
||||
<div className="text-[11px] mt-0.5" style={{ color }}>
|
||||
{visible ? "..." : "\u2713"}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Single tmux pane
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function MuxPane({
|
||||
group,
|
||||
index,
|
||||
label,
|
||||
isFocused,
|
||||
isZoomed,
|
||||
onClickTitle,
|
||||
}: {
|
||||
group: SubagentGroup;
|
||||
index: number;
|
||||
label: string;
|
||||
isFocused: boolean;
|
||||
isZoomed: boolean;
|
||||
onClickTitle: () => void;
|
||||
}) {
|
||||
const bodyRef = useRef<HTMLDivElement>(null);
|
||||
const stickRef = useRef(true);
|
||||
const color = colorForIndex(index);
|
||||
const pct = group.contextUsage?.usagePct ?? 0;
|
||||
|
||||
const streamMsgs = group.messages.filter((m) => m.type !== "tool_status");
|
||||
const latestContent = last(streamMsgs)?.content ?? "";
|
||||
const msgCount = streamMsgs.length;
|
||||
|
||||
// Detect active tool and finished state from latest tool_status
|
||||
const latestTool = last(
|
||||
group.messages.filter((m) => m.type === "tool_status")
|
||||
);
|
||||
let activeToolName = "";
|
||||
let toolRunning = false;
|
||||
let isFinished = false;
|
||||
if (latestTool) {
|
||||
try {
|
||||
const parsed = JSON.parse(latestTool.content);
|
||||
const tools: { name: string; done: boolean }[] = parsed.tools || [];
|
||||
const allDone = parsed.allDone as boolean | undefined;
|
||||
const running = tools.find((t) => !t.done);
|
||||
if (running) {
|
||||
activeToolName = running.name;
|
||||
toolRunning = true;
|
||||
}
|
||||
// Finished when all tools are done and one of them is set_output
|
||||
// or report_to_parent (terminal tool calls)
|
||||
if (allDone && tools.length > 0) {
|
||||
const hasTerminal = tools.some(
|
||||
(t) =>
|
||||
t.done &&
|
||||
(t.name === "set_output" || t.name === "report_to_parent")
|
||||
);
|
||||
if (hasTerminal) isFinished = true;
|
||||
}
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-scroll
|
||||
useEffect(() => {
|
||||
if (stickRef.current && bodyRef.current) {
|
||||
bodyRef.current.scrollTop = bodyRef.current.scrollHeight;
|
||||
}
|
||||
}, [latestContent]);
|
||||
|
||||
const handleScroll = () => {
|
||||
const el = bodyRef.current;
|
||||
if (!el) return;
|
||||
stickRef.current = el.scrollHeight - el.scrollTop - el.clientHeight < 30;
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex flex-col min-h-0 overflow-hidden relative transition-all duration-200"
|
||||
style={{
|
||||
borderWidth: 1,
|
||||
borderStyle: "solid",
|
||||
borderColor: isFocused && !isFinished ? `${color}60` : "transparent",
|
||||
opacity: isFinished ? 0.4 : isFocused || isZoomed ? 1 : 0.55,
|
||||
...(isZoomed
|
||||
? { gridColumn: "1 / -1", gridRow: "1 / -1", zIndex: 10 }
|
||||
: {}),
|
||||
}}
|
||||
>
|
||||
{/* Title bar */}
|
||||
<div
|
||||
className="flex items-center gap-1.5 px-2 py-[3px] flex-shrink-0 cursor-pointer select-none"
|
||||
style={{ background: "#0e0e16", borderBottom: "1px solid #1a1a2a" }}
|
||||
onClick={onClickTitle}
|
||||
>
|
||||
{isFinished ? (
|
||||
<span className="text-[8px] flex-shrink-0 leading-none" style={{ color: "#4a4" }}>✓</span>
|
||||
) : (
|
||||
<div
|
||||
className="w-[6px] h-[6px] rounded-full flex-shrink-0"
|
||||
style={{ background: color }}
|
||||
/>
|
||||
)}
|
||||
<span className="text-[9px] flex-shrink-0" style={{ color: isFinished ? "#555" : color }}>
|
||||
{label}
|
||||
</span>
|
||||
<span className="flex-1" />
|
||||
<span className="text-[8px] tabular-nums flex-shrink-0" style={{ color: "#555" }}>
|
||||
{msgCount}
|
||||
</span>
|
||||
<div
|
||||
className="w-[36px] h-[3px] rounded-full overflow-hidden flex-shrink-0"
|
||||
style={{ background: "#1a1a2a" }}
|
||||
>
|
||||
<div
|
||||
className="h-full rounded-full transition-all duration-500"
|
||||
style={{
|
||||
width: `${Math.min(pct, 100)}%`,
|
||||
backgroundColor:
|
||||
pct >= 80 ? "hsl(0,65%,55%)" : pct >= 50 ? "hsl(35,90%,55%)" : color,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<span className="text-[8px] tabular-nums flex-shrink-0" style={{ color: "#555" }}>
|
||||
{pct}%
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{/* Body */}
|
||||
<div
|
||||
ref={bodyRef}
|
||||
onScroll={handleScroll}
|
||||
className="flex-1 min-h-0 overflow-y-auto px-2 py-1 text-[10px] leading-[1.7]"
|
||||
style={{ background: "#08080e", color: "#555", fontFamily: "monospace" }}
|
||||
>
|
||||
{latestContent ? (
|
||||
<div style={{ color: "#ccc" }}>
|
||||
<MarkdownContent content={latestContent} />
|
||||
</div>
|
||||
) : (
|
||||
<span style={{ color: "#333" }}>waiting...</span>
|
||||
)}
|
||||
{/* Blinking cursor — hidden when finished */}
|
||||
{!isFinished && (
|
||||
<span
|
||||
className="inline-block w-[6px] h-[11px] align-middle ml-0.5"
|
||||
style={{
|
||||
background: color,
|
||||
animation: "cursorBlink 1s step-end infinite",
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Tool overlay */}
|
||||
<ToolOverlay
|
||||
toolName={activeToolName}
|
||||
color={color}
|
||||
visible={toolRunning}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Main component
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const ParallelSubagentBubble = memo(
|
||||
function ParallelSubagentBubble({ groups }: ParallelSubagentBubbleProps) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const [zoomedIdx, setZoomedIdx] = useState<number | null>(null);
|
||||
|
||||
// Labels with instance numbers for duplicates
|
||||
const labels: string[] = (() => {
|
||||
const countByBase = new Map<string, number>();
|
||||
const bases = groups.map((g) => subagentLabel(g.nodeId));
|
||||
for (const b of bases)
|
||||
countByBase.set(b, (countByBase.get(b) ?? 0) + 1);
|
||||
const idxByBase = new Map<string, number>();
|
||||
return bases.map((b) => {
|
||||
if ((countByBase.get(b) ?? 1) <= 1) return b;
|
||||
const idx = (idxByBase.get(b) ?? 0) + 1;
|
||||
idxByBase.set(b, idx);
|
||||
return `${b} #${idx}`;
|
||||
});
|
||||
})();
|
||||
|
||||
// Latest-active pane
|
||||
const latestIdx = groups.reduce<number>((best, g, i) => {
|
||||
const filtered = g.messages.filter((m) => m.type !== "tool_status");
|
||||
const lm = last(filtered);
|
||||
if (!lm) return best;
|
||||
if (best < 0) return i;
|
||||
const bm = last(
|
||||
groups[best].messages.filter((m) => m.type !== "tool_status")
|
||||
);
|
||||
if (!bm) return i;
|
||||
return (lm.createdAt ?? 0) >= (bm.createdAt ?? 0) ? i : best;
|
||||
}, -1);
|
||||
|
||||
// Per-group finished detection (same logic as MuxPane)
|
||||
const finishedFlags = groups.map((g) => {
|
||||
const lt = last(g.messages.filter((m) => m.type === "tool_status"));
|
||||
if (!lt) return false;
|
||||
try {
|
||||
const p = JSON.parse(lt.content);
|
||||
const tools: { name: string; done: boolean }[] = p.tools || [];
|
||||
if (!p.allDone || tools.length === 0) return false;
|
||||
return tools.some(
|
||||
(t) => t.done && (t.name === "set_output" || t.name === "report_to_parent")
|
||||
);
|
||||
} catch { return false; }
|
||||
});
|
||||
const activeCount = finishedFlags.filter((f) => !f).length;
|
||||
|
||||
if (groups.length === 0) return null;
|
||||
|
||||
// Grid sizing: 2 columns, auto rows capped at a fixed height
|
||||
const rows = Math.ceil(groups.length / 2);
|
||||
const gridHeight = expanded
|
||||
? Math.min(rows * 200, 480)
|
||||
: Math.min(rows * 100, 240);
|
||||
|
||||
return (
|
||||
<div className="flex gap-3">
|
||||
{/* Left icon */}
|
||||
<div
|
||||
className="flex-shrink-0 w-7 h-7 rounded-xl flex items-center justify-center mt-1"
|
||||
style={{
|
||||
backgroundColor: `${workerColor}18`,
|
||||
border: `1.5px solid ${workerColor}35`,
|
||||
}}
|
||||
>
|
||||
<Cpu className="w-3.5 h-3.5" style={{ color: workerColor }} />
|
||||
</div>
|
||||
|
||||
<div className="flex-1 min-w-0 max-w-[90%]">
|
||||
{/* Header */}
|
||||
<div className="flex items-center gap-2 mb-1">
|
||||
<span className="font-medium text-xs" style={{ color: workerColor }}>
|
||||
{groups.length === 1 ? "Sub-agent" : "Parallel Agents"}
|
||||
</span>
|
||||
<span className="text-[10px] font-medium px-1.5 py-0.5 rounded-md bg-muted text-muted-foreground">
|
||||
{activeCount > 0 ? `${activeCount} running` : `${groups.length} done`}
|
||||
</span>
|
||||
<button
|
||||
onClick={() => {
|
||||
setExpanded((v) => !v);
|
||||
setZoomedIdx(null);
|
||||
}}
|
||||
className="ml-auto text-muted-foreground/60 hover:text-muted-foreground transition-colors p-0.5 rounded"
|
||||
title={expanded ? "Collapse" : "Expand"}
|
||||
>
|
||||
{expanded ? (
|
||||
<ChevronUp className="w-3.5 h-3.5" />
|
||||
) : (
|
||||
<ChevronDown className="w-3.5 h-3.5" />
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Mux frame */}
|
||||
<div
|
||||
className="rounded-lg overflow-hidden"
|
||||
style={{
|
||||
border: "2px solid #1a1a2a",
|
||||
background: "#08080e",
|
||||
}}
|
||||
>
|
||||
{/* Grid */}
|
||||
<div
|
||||
className="grid gap-px"
|
||||
style={{
|
||||
gridTemplateColumns:
|
||||
groups.length === 1 ? "1fr" : "1fr 1fr",
|
||||
gridTemplateRows: `repeat(${rows}, 1fr)`,
|
||||
height: gridHeight,
|
||||
background: "#111",
|
||||
}}
|
||||
>
|
||||
{groups.map((group, i) => (
|
||||
<MuxPane
|
||||
key={group.nodeId}
|
||||
group={group}
|
||||
index={i}
|
||||
label={labels[i]}
|
||||
isFocused={latestIdx === i}
|
||||
isZoomed={zoomedIdx === i}
|
||||
onClickTitle={() =>
|
||||
setZoomedIdx(zoomedIdx === i ? null : i)
|
||||
}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
(prev, next) =>
|
||||
prev.groupId === next.groupId &&
|
||||
prev.groups.length === next.groups.length &&
|
||||
prev.groups.every(
|
||||
(g, i) =>
|
||||
g.nodeId === next.groups[i].nodeId &&
|
||||
g.messages.length === next.groups[i].messages.length &&
|
||||
last(g.messages)?.content === last(next.groups[i].messages)?.content &&
|
||||
g.contextUsage?.usagePct === next.groups[i].contextUsage?.usagePct
|
||||
)
|
||||
);
|
||||
|
||||
export default ParallelSubagentBubble;
|
||||
|
||||
// Injected as a global style (keyframes can't be inline)
|
||||
if (typeof document !== "undefined") {
|
||||
const id = "parallel-subagent-keyframes";
|
||||
if (!document.getElementById(id)) {
|
||||
const style = document.createElement("style");
|
||||
style.id = id;
|
||||
style.textContent = `
|
||||
@keyframes cursorBlink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } }
|
||||
@keyframes thermoPulse { 0%, 100% { opacity: 1; } 50% { opacity: 0.4; } }
|
||||
`;
|
||||
document.head.appendChild(style);
|
||||
}
|
||||
}
|
||||
@@ -62,7 +62,7 @@ export function sseEventToChatMessage(
|
||||
const innerSuffix = innerTurn != null && innerTurn > 0 ? `-t${innerTurn}` : "";
|
||||
|
||||
const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
|
||||
if (!snapshot) return null;
|
||||
if (!snapshot.trim()) return null;
|
||||
return {
|
||||
id: `stream-${iterIdKey}${innerSuffix}-${event.node_id}`,
|
||||
agent: agentDisplayName || event.node_id || "Agent",
|
||||
@@ -72,6 +72,8 @@ export function sseEventToChatMessage(
|
||||
role: "worker",
|
||||
thread,
|
||||
createdAt,
|
||||
nodeId: event.node_id || undefined,
|
||||
executionId: event.execution_id || undefined,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -100,7 +102,7 @@ export function sseEventToChatMessage(
|
||||
const llmInnerSuffix = llmInnerTurn != null && llmInnerTurn > 0 ? `-t${llmInnerTurn}` : "";
|
||||
|
||||
const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
|
||||
if (!snapshot) return null;
|
||||
if (!snapshot.trim()) return null;
|
||||
return {
|
||||
id: `stream-${idKey}${llmInnerSuffix}-${event.node_id}`,
|
||||
agent: event.node_id || "Agent",
|
||||
@@ -110,6 +112,8 @@ export function sseEventToChatMessage(
|
||||
role: "worker",
|
||||
thread,
|
||||
createdAt,
|
||||
nodeId: event.node_id || undefined,
|
||||
executionId: event.execution_id || undefined,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
import { Link } from "react-router-dom";
|
||||
|
||||
export default function NotFound() {
|
||||
return (
|
||||
<div className="min-h-screen bg-background flex flex-col items-center justify-center px-6 text-center">
|
||||
<h1 className="text-5xl font-semibold text-foreground">404</h1>
|
||||
<p className="mt-3 text-sm text-muted-foreground">Page not found</p>
|
||||
<p className="mt-1 text-sm text-muted-foreground/80">
|
||||
The page you’re looking for doesn’t exist.
|
||||
</p>
|
||||
<Link
|
||||
to="/"
|
||||
className="mt-6 inline-flex items-center rounded-lg border border-border/40 px-4 py-2 text-sm font-medium text-foreground hover:bg-muted/40 transition-colors"
|
||||
>
|
||||
Back to Home
|
||||
</Link>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
import { useState, useCallback, useRef, useEffect, useMemo } from "react";
|
||||
import ReactDOM from "react-dom";
|
||||
import { useSearchParams, useNavigate } from "react-router-dom";
|
||||
import { Plus, KeyRound, Sparkles, Layers, ChevronLeft, Bot, Loader2, WifiOff, X } from "lucide-react";
|
||||
import { Plus, KeyRound, Sparkles, Layers, ChevronLeft, Bot, Loader2, WifiOff, X, FolderOpen } from "lucide-react";
|
||||
import type { GraphNode, NodeStatus } from "@/components/graph-types";
|
||||
import DraftGraph from "@/components/DraftGraph";
|
||||
import ChatPanel, { type ChatMessage } from "@/components/ChatPanel";
|
||||
@@ -352,6 +352,10 @@ interface AgentBackendState {
|
||||
pendingQuestions: { id: string; prompt: string; options?: string[] }[] | null;
|
||||
/** Whether the pending question came from queen or worker */
|
||||
pendingQuestionSource: "queen" | "worker" | null;
|
||||
/** Per-node context window usage (from context_usage_updated events) */
|
||||
contextUsage: Record<string, { usagePct: number; messageCount: number; estimatedTokens: number; maxTokens: number }>;
|
||||
/** Whether the queen's LLM supports image content — false disables the attach button */
|
||||
queenSupportsImages: boolean;
|
||||
}
|
||||
|
||||
function defaultAgentState(): AgentBackendState {
|
||||
@@ -389,6 +393,8 @@ function defaultAgentState(): AgentBackendState {
|
||||
pendingOptions: null,
|
||||
pendingQuestions: null,
|
||||
pendingQuestionSource: null,
|
||||
contextUsage: {},
|
||||
queenSupportsImages: true,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -630,6 +636,10 @@ export default function Workspace() {
|
||||
// it was created in (avoids stale-closure when phase change and message
|
||||
// events arrive in the same React batch).
|
||||
const queenPhaseRef = useRef<Record<string, string>>({});
|
||||
// Accumulated queen text across inner_turns within the same iteration.
|
||||
// Key: `${agentType}:${execution_id}:${iteration}`, value: { [inner_turn]: snapshot }.
|
||||
// This lets us merge all inner_turn text into one chat bubble per iteration.
|
||||
const queenIterTextRef = useRef<Record<string, Record<number, string>>>({});
|
||||
// Timestamp when designingDraft was set — used to enforce minimum spinner duration.
|
||||
const designingDraftSinceRef = useRef<Record<string, number>>({});
|
||||
const designingDraftTimerRef = useRef<Record<string, ReturnType<typeof setTimeout>>>({});
|
||||
@@ -916,6 +926,7 @@ export default function Workspace() {
|
||||
queenReady: true,
|
||||
queenPhase: qPhase,
|
||||
queenBuilding: qPhase === "building",
|
||||
queenSupportsImages: liveSession.queen_supports_images !== false,
|
||||
// Restore flowchart overlay from persisted events
|
||||
...(restoredFlowchartMap ? { flowchartMap: restoredFlowchartMap } : {}),
|
||||
...(restoredOriginalDraft ? { originalDraft: restoredOriginalDraft, draftGraph: null } : {}),
|
||||
@@ -1115,6 +1126,7 @@ export default function Workspace() {
|
||||
displayName,
|
||||
queenPhase: initialPhase,
|
||||
queenBuilding: initialPhase === "building",
|
||||
queenSupportsImages: session.queen_supports_images !== false,
|
||||
// Restore flowchart overlay from persisted events
|
||||
...(restoredFlowchartMap ? { flowchartMap: restoredFlowchartMap } : {}),
|
||||
...(restoredOriginalDraft ? { originalDraft: restoredOriginalDraft, draftGraph: null } : {}),
|
||||
@@ -1707,14 +1719,29 @@ export default function Workspace() {
|
||||
if (isQueen) console.log('[QUEEN] chatMsg:', chatMsg?.id, chatMsg?.content?.slice(0, 50), 'turn:', currentTurn);
|
||||
if (chatMsg && !suppressQueenMessages) {
|
||||
// Queen emits multiple client_output_delta / llm_text_delta snapshots
|
||||
// across iterations and inner tool-loop turns. Build a stable ID that
|
||||
// groups streaming deltas for the *same* output (same execution +
|
||||
// iteration + inner_turn) into one bubble, while keeping distinct
|
||||
// outputs as separate bubbles so earlier text isn't overwritten.
|
||||
// across iterations and inner tool-loop turns. Merge all inner_turns
|
||||
// within the same iteration into ONE bubble so the queen's multi-step
|
||||
// tool loop (text → tool → text → tool → text) appears as one cohesive
|
||||
// message rather than many small fragments.
|
||||
if (isQueen && (event.type === "client_output_delta" || event.type === "llm_text_delta") && event.execution_id) {
|
||||
const iter = event.data?.iteration ?? 0;
|
||||
const inner = event.data?.inner_turn ?? 0;
|
||||
chatMsg.id = `queen-stream-${event.execution_id}-${iter}-${inner}`;
|
||||
const inner = (event.data?.inner_turn as number) ?? 0;
|
||||
const iterKey = `${agentType}:${event.execution_id}:${iter}`;
|
||||
|
||||
// Store the latest snapshot for this inner_turn
|
||||
if (!queenIterTextRef.current[iterKey]) {
|
||||
queenIterTextRef.current[iterKey] = {};
|
||||
}
|
||||
const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
|
||||
queenIterTextRef.current[iterKey][inner] = snapshot;
|
||||
|
||||
// Concatenate all inner_turn snapshots in order
|
||||
const parts = queenIterTextRef.current[iterKey];
|
||||
const sortedInners = Object.keys(parts).map(Number).sort((a, b) => a - b);
|
||||
chatMsg.content = sortedInners.map(k => parts[k]).join("\n");
|
||||
|
||||
// Single ID per iteration — no inner_turn in the ID
|
||||
chatMsg.id = `queen-stream-${event.execution_id}-${iter}`;
|
||||
}
|
||||
if (isQueen) {
|
||||
chatMsg.role = role;
|
||||
@@ -1989,6 +2016,8 @@ export default function Workspace() {
|
||||
role,
|
||||
thread: agentType,
|
||||
createdAt: eventCreatedAt,
|
||||
nodeId: event.node_id || undefined,
|
||||
executionId: event.execution_id || undefined,
|
||||
});
|
||||
return {
|
||||
...prev,
|
||||
@@ -2060,6 +2089,8 @@ export default function Workspace() {
|
||||
role,
|
||||
thread: agentType,
|
||||
createdAt: eventCreatedAt,
|
||||
nodeId: event.node_id || undefined,
|
||||
executionId: event.execution_id || undefined,
|
||||
});
|
||||
return {
|
||||
...prev,
|
||||
@@ -2136,6 +2167,29 @@ export default function Workspace() {
|
||||
}
|
||||
break;
|
||||
|
||||
case "context_usage_updated": {
|
||||
const streamKey = isQueen ? "__queen__" : (event.node_id || streamId);
|
||||
const usagePct = (event.data?.usage_pct as number) ?? 0;
|
||||
const messageCount = (event.data?.message_count as number) ?? 0;
|
||||
const estimatedTokens = (event.data?.estimated_tokens as number) ?? 0;
|
||||
const maxTokens = (event.data?.max_context_tokens as number) ?? 0;
|
||||
setAgentStates(prev => {
|
||||
const state = prev[agentType];
|
||||
if (!state) return prev;
|
||||
return {
|
||||
...prev,
|
||||
[agentType]: {
|
||||
...state,
|
||||
contextUsage: {
|
||||
...state.contextUsage,
|
||||
[streamKey]: { usagePct, messageCount, estimatedTokens, maxTokens },
|
||||
},
|
||||
},
|
||||
};
|
||||
});
|
||||
}
|
||||
break;
|
||||
|
||||
case "node_action_plan":
|
||||
if (!isQueen && event.node_id) {
|
||||
const plan = (event.data?.plan as string) || "";
|
||||
@@ -2564,7 +2618,7 @@ export default function Workspace() {
|
||||
});
|
||||
|
||||
// --- handleSend ---
|
||||
const handleSend = useCallback((text: string, thread: string) => {
|
||||
const handleSend = useCallback((text: string, thread: string, images?: import("@/components/ChatPanel").ImageContent[]) => {
|
||||
if (!activeSession) return;
|
||||
const state = agentStates[activeWorker];
|
||||
|
||||
@@ -2630,6 +2684,7 @@ export default function Workspace() {
|
||||
const userMsg: ChatMessage = {
|
||||
id: makeId(), agent: "You", agentColor: "",
|
||||
content: text, timestamp: "", type: "user", thread, createdAt: Date.now(),
|
||||
images,
|
||||
};
|
||||
setSessionsByAgent(prev => ({
|
||||
...prev,
|
||||
@@ -2641,7 +2696,7 @@ export default function Workspace() {
|
||||
updateAgentState(activeWorker, { isTyping: true, queenIsTyping: true });
|
||||
|
||||
if (state?.sessionId && state?.ready) {
|
||||
executionApi.chat(state.sessionId, text).catch((err: unknown) => {
|
||||
executionApi.chat(state.sessionId, text, images).catch((err: unknown) => {
|
||||
const errMsg = err instanceof Error ? err.message : String(err);
|
||||
const errorChatMsg: ChatMessage = {
|
||||
id: makeId(), agent: "System", agentColor: "",
|
||||
@@ -3057,6 +3112,16 @@ export default function Workspace() {
|
||||
<KeyRound className="w-3.5 h-3.5" />
|
||||
Credentials
|
||||
</button>
|
||||
{activeAgentState?.sessionId && (
|
||||
<button
|
||||
onClick={() => sessionsApi.revealFolder(activeAgentState.sessionId!).catch(() => {})}
|
||||
className="flex items-center gap-1.5 px-3 py-1.5 rounded-md text-xs font-medium text-muted-foreground hover:text-foreground hover:bg-muted/50 transition-colors flex-shrink-0"
|
||||
title="Open session data folder"
|
||||
>
|
||||
<FolderOpen className="w-3.5 h-3.5" />
|
||||
Data
|
||||
</button>
|
||||
)}
|
||||
</TopBar>
|
||||
|
||||
{/* Main content area */}
|
||||
@@ -3174,6 +3239,8 @@ export default function Workspace() {
|
||||
}
|
||||
onMultiQuestionSubmit={handleMultiQuestionAnswer}
|
||||
onQuestionDismiss={handleQuestionDismiss}
|
||||
contextUsage={activeAgentState?.contextUsage}
|
||||
supportsImages={activeAgentState?.queenSupportsImages ?? true}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
@@ -3377,6 +3444,7 @@ export default function Workspace() {
|
||||
workerSessionId={null}
|
||||
nodeLogs={activeAgentState?.nodeLogs[resolvedSelectedNode.id] || []}
|
||||
actionPlan={activeAgentState?.nodeActionPlans[resolvedSelectedNode.id]}
|
||||
contextUsage={activeAgentState?.contextUsage[resolvedSelectedNode.id]}
|
||||
onClose={() => setSelectedNode(null)}
|
||||
/>
|
||||
)}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user