Compare commits
218 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| fd4dc1a69a | |||
| 377cd39c2a | |||
| e92caeef24 | |||
| a995818db2 | |||
| 0772b4d300 | |||
| 684e0d8dc6 | |||
| d284c5d790 | |||
| 7a9b9666c4 | |||
| a852cb91bf | |||
| 2f21e9eb4b | |||
| 8390ef8731 | |||
| 44b3e0eaa2 | |||
| 22b7e4b0c3 | |||
| 5413833a69 | |||
| 02e1a4584a | |||
| 520840b1dd | |||
| ee96147336 | |||
| 705cef4dc1 | |||
| ab26e64122 | |||
| f365e219cb | |||
| 01621881c2 | |||
| f7639f8572 | |||
| fc643060ce | |||
| 9aebeb181e | |||
| acbbfaaa79 | |||
| bf170bce10 | |||
| 0a090d058b | |||
| 47bfadaad9 | |||
| d968dcd44c | |||
| 6fdaa9ea50 | |||
| 4d251fbdc2 | |||
| 6acceed288 | |||
| 8dd1d6e3aa | |||
| 1da28644a6 | |||
| 6452fe7fef | |||
| acff008bd2 | |||
| 651d6850a1 | |||
| c7fdc92594 | |||
| 43602a8801 | |||
| 3da04265a6 | |||
| 4c98f0d2d0 | |||
| ae921f6cee | |||
| 6b506a1c08 | |||
| 0c9f4fa97e | |||
| 95e30bc607 | |||
| 23c66d1059 | |||
| b9d529d94e | |||
| 1c9b09fb78 | |||
| 9fb14f23d2 | |||
| 4795dc4f68 | |||
| acf0f804c5 | |||
| 4e2951854b | |||
| 80dfb429d7 | |||
| 9c0ba77e22 | |||
| 46b4651073 | |||
| 86dd5246c6 | |||
| a1227c88ee | |||
| 535d7ab568 | |||
| af10494b31 | |||
| 39c1042827 | |||
| 16e7dc11f4 | |||
| 7a27babefd | |||
| d53ae9d51d | |||
| 910cf7727d | |||
| 1698605f15 | |||
| eda124a123 | |||
| 15e9ce8d2f | |||
| c01dd603d7 | |||
| 9d5157d69f | |||
| d78795bdf5 | |||
| ff2b7f473e | |||
| 73c9a91811 | |||
| 27b765d902 | |||
| fddba419be | |||
| f42d6308e8 | |||
| c167002754 | |||
| ea26ee7d0c | |||
| 5280e908b2 | |||
| 1c5dd8c664 | |||
| 3aca153be5 | |||
| 65c8e1653c | |||
| 58e4fa918c | |||
| 3af13d3f90 | |||
| b799789dbe | |||
| 2cd73dfccc | |||
| 57d77d5479 | |||
| 5814021773 | |||
| 4f4cc9c8ce | |||
| d9c840eee5 | |||
| d2eb86e534 | |||
| 03842353e4 | |||
| 48747e20af | |||
| 58af593af6 | |||
| 450575a927 | |||
| eac2bb19b2 | |||
| 756a815bf0 | |||
| 23a7b080eb | |||
| bf39bcdec9 | |||
| 0276632491 | |||
| ae2993d0d1 | |||
| d14d71f760 | |||
| ef6efc2f55 | |||
| 738641d35f | |||
| 22f5534f08 | |||
| b79e7eca73 | |||
| 28250dc45e | |||
| fe5df6a87a | |||
| 07e4b593dd | |||
| 497591bf3b | |||
| a2a3e334d6 | |||
| 1ccbfaf800 | |||
| a9afa0555c | |||
| 83b2183cf0 | |||
| c2dea88398 | |||
| f49e7a760e | |||
| dc95c88da0 | |||
| 6e0255ebec | |||
| b51e688d1a | |||
| 379d3df46b | |||
| b77a3031fe | |||
| c10eea04ec | |||
| 491a3f24da | |||
| c7d70e0fb1 | |||
| d59f8e99cb | |||
| 0a91b49417 | |||
| ced64541b9 | |||
| 88253883a3 | |||
| 3c30cfe02b | |||
| 0d6267bcf1 | |||
| b47175d1df | |||
| 6f23a30eed | |||
| ff7b5c7e27 | |||
| 69f0ff7ac9 | |||
| c3f13c50eb | |||
| 5477408d40 | |||
| 9fad385ddf | |||
| cf44ee1d9b | |||
| 4ab33a39d6 | |||
| ae19121802 | |||
| b518525418 | |||
| ac3fe38b33 | |||
| 3c6a30fcae | |||
| 2ced873fb5 | |||
| 6ed6e5b286 | |||
| 30bb0ad5d8 | |||
| cb0845f5ba | |||
| ce2525b59c | |||
| 1f77ec3831 | |||
| ab995d8b96 | |||
| 6ab5aa8004 | |||
| 4449cd8ee8 | |||
| 8b60c03a0a | |||
| c2e560fc07 | |||
| 19f7ae862e | |||
| 5e9f74744a | |||
| 0e98023e40 | |||
| 7787179a5a | |||
| b63205b91a | |||
| 347bccb9ee | |||
| 22bb07f00e | |||
| 660f883197 | |||
| 9d83f0298f | |||
| 988de80b66 | |||
| dc6aa226ee | |||
| 48a54b4ee2 | |||
| 7f7e8b4dff | |||
| f48a7380f5 | |||
| 3c7f129d86 | |||
| 4533b27aa1 | |||
| 3adf268c29 | |||
| ac8579900f | |||
| abbaaa68f3 | |||
| 11089093ef | |||
| 99b7cb07d5 | |||
| 70d61ae67a | |||
| dd054815a3 | |||
| 8e5eaae9dd | |||
| 2d0128eb5c | |||
| 06f1d4dcef | |||
| 0e7b11b5b2 | |||
| 291b78f934 | |||
| e196a03972 | |||
| a0abe2685d | |||
| e8f642c8b6 | |||
| 6260f628eb | |||
| 4a4f17ed40 | |||
| 36dcf2025b | |||
| 85c70c94e6 | |||
| 336e82ba22 | |||
| a7b6b080ab | |||
| 9202cbd4d4 | |||
| f2ddd1051d | |||
| 2dd60c8d52 | |||
| ff01c1fd99 | |||
| 421b25fdb7 | |||
| 795c3c33e2 | |||
| 97821f4d80 | |||
| 505e1e30fd | |||
| 3fb2b285fb | |||
| a76109840c | |||
| 1db8484402 | |||
| 39212350ba | |||
| 7ede3ba171 | |||
| cdaec8a837 | |||
| 635d2976f4 | |||
| 4e1525880d | |||
| 20427e213a | |||
| 094ba89f19 | |||
| 7008c9f310 | |||
| 94d7cbacc2 | |||
| bddc2b413a | |||
| 48c8fb7fff | |||
| bc3c5a5899 | |||
| e82133741c | |||
| 5076278dcb | |||
| 2398e04e11 | |||
| d00f321627 | |||
| e76b6cb575 |
@@ -2,14 +2,22 @@ name: Bounty completed
|
||||
description: Awards points and notifies Discord when a bounty PR is merged
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
pull_request_target:
|
||||
types: [closed]
|
||||
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
pr_number:
|
||||
description: "PR number to process (for missed bounties)"
|
||||
required: true
|
||||
type: number
|
||||
|
||||
jobs:
|
||||
bounty-notify:
|
||||
if: >
|
||||
github.event.pull_request.merged == true &&
|
||||
contains(join(github.event.pull_request.labels.*.name, ','), 'bounty:')
|
||||
github.event_name == 'workflow_dispatch' ||
|
||||
(github.event.pull_request.merged == true &&
|
||||
contains(join(github.event.pull_request.labels.*.name, ','), 'bounty:'))
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5
|
||||
permissions:
|
||||
@@ -32,6 +40,8 @@ jobs:
|
||||
GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
|
||||
GITHUB_REPOSITORY_NAME: ${{ github.event.repository.name }}
|
||||
DISCORD_WEBHOOK_URL: ${{ secrets.DISCORD_BOUNTY_WEBHOOK_URL }}
|
||||
BOT_API_URL: ${{ secrets.BOT_API_URL }}
|
||||
BOT_API_KEY: ${{ secrets.BOT_API_KEY }}
|
||||
LURKR_API_KEY: ${{ secrets.LURKR_API_KEY }}
|
||||
LURKR_GUILD_ID: ${{ secrets.LURKR_GUILD_ID }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
PR_NUMBER: ${{ inputs.pr_number || github.event.pull_request.number }}
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
name: Link Discord account
|
||||
description: Auto-creates a PR to add contributor to contributors.yml when a link-discord issue is opened
|
||||
|
||||
on:
|
||||
issues:
|
||||
types: [opened]
|
||||
|
||||
jobs:
|
||||
link-discord:
|
||||
if: contains(github.event.issue.labels.*.name, 'link-discord')
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 2
|
||||
permissions:
|
||||
contents: write
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Parse issue and update contributors.yml
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
|
||||
const issue = context.payload.issue;
|
||||
const githubUsername = issue.user.login;
|
||||
|
||||
// Parse the issue body for form fields
|
||||
const body = issue.body || '';
|
||||
|
||||
// Extract Discord ID — look for the numeric value after the "Discord User ID" heading
|
||||
const discordMatch = body.match(/### Discord User ID\s*\n\s*(\d{17,20})/);
|
||||
if (!discordMatch) {
|
||||
await github.rest.issues.createComment({
|
||||
...context.repo,
|
||||
issue_number: issue.number,
|
||||
body: `Could not find a valid Discord ID in the issue body. Please make sure you entered a numeric ID (17-20 digits), not a username.\n\nExample: \`123456789012345678\``
|
||||
});
|
||||
await github.rest.issues.update({
|
||||
...context.repo,
|
||||
issue_number: issue.number,
|
||||
state: 'closed',
|
||||
state_reason: 'not_planned'
|
||||
});
|
||||
return;
|
||||
}
|
||||
const discordId = discordMatch[1];
|
||||
|
||||
// Extract display name (optional)
|
||||
const nameMatch = body.match(/### Display Name \(optional\)\s*\n\s*(.+)/);
|
||||
const displayName = nameMatch ? nameMatch[1].trim() : '';
|
||||
|
||||
// Check if user already exists
|
||||
const yml = fs.readFileSync('contributors.yml', 'utf-8');
|
||||
if (yml.includes(`github: ${githubUsername}`)) {
|
||||
await github.rest.issues.createComment({
|
||||
...context.repo,
|
||||
issue_number: issue.number,
|
||||
body: `@${githubUsername} is already in \`contributors.yml\`. If you need to update your Discord ID, please edit the file directly via PR.`
|
||||
});
|
||||
await github.rest.issues.update({
|
||||
...context.repo,
|
||||
issue_number: issue.number,
|
||||
state: 'closed',
|
||||
state_reason: 'completed'
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Append entry to contributors.yml
|
||||
let entry = ` - github: ${githubUsername}\n discord: "${discordId}"`;
|
||||
if (displayName && displayName !== '_No response_') {
|
||||
entry += `\n name: ${displayName}`;
|
||||
}
|
||||
entry += '\n';
|
||||
|
||||
const updated = yml.trimEnd() + '\n' + entry;
|
||||
fs.writeFileSync('contributors.yml', updated);
|
||||
|
||||
// Set outputs for commit step
|
||||
core.exportVariable('GITHUB_USERNAME', githubUsername);
|
||||
core.exportVariable('DISCORD_ID', discordId);
|
||||
core.exportVariable('ISSUE_NUMBER', issue.number.toString());
|
||||
|
||||
- name: Create PR
|
||||
run: |
|
||||
# Check if there are changes
|
||||
if git diff --quiet contributors.yml; then
|
||||
echo "No changes to contributors.yml"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
BRANCH="docs/link-discord-${GITHUB_USERNAME}"
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
git checkout -b "$BRANCH"
|
||||
git add contributors.yml
|
||||
git commit -m "docs: link @${GITHUB_USERNAME} to Discord"
|
||||
git push origin "$BRANCH"
|
||||
|
||||
gh pr create \
|
||||
--title "docs: link @${GITHUB_USERNAME} to Discord" \
|
||||
--body "Adds @${GITHUB_USERNAME} (Discord \`${DISCORD_ID}\`) to \`contributors.yml\` for bounty XP tracking.
|
||||
|
||||
Closes #${ISSUE_NUMBER}" \
|
||||
--base main \
|
||||
--head "$BRANCH" \
|
||||
--label "link-discord"
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Notify on issue
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const username = process.env.GITHUB_USERNAME;
|
||||
const issueNumber = parseInt(process.env.ISSUE_NUMBER);
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
...context.repo,
|
||||
issue_number: issueNumber,
|
||||
body: `A PR has been created to link your account. A maintainer will merge it shortly — once merged, you'll receive XP and Discord pings when your bounty PRs are merged.`
|
||||
});
|
||||
@@ -35,6 +35,8 @@ jobs:
|
||||
GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
|
||||
GITHUB_REPOSITORY_NAME: ${{ github.event.repository.name }}
|
||||
DISCORD_WEBHOOK_URL: ${{ secrets.DISCORD_BOUNTY_WEBHOOK_URL }}
|
||||
BOT_API_URL: ${{ secrets.BOT_API_URL }}
|
||||
BOT_API_KEY: ${{ secrets.BOT_API_KEY }}
|
||||
LURKR_API_KEY: ${{ secrets.LURKR_API_KEY }}
|
||||
LURKR_GUILD_ID: ${{ secrets.LURKR_GUILD_ID }}
|
||||
SINCE_DATE: ${{ github.event.inputs.since_date || '' }}
|
||||
|
||||
@@ -68,7 +68,6 @@ temp/
|
||||
exports/*
|
||||
|
||||
.claude/settings.local.json
|
||||
.claude/skills/ship-it/
|
||||
|
||||
.venv
|
||||
|
||||
|
||||
+150
-27
@@ -1,17 +1,149 @@
|
||||
# Release Notes
|
||||
|
||||
## v0.7.1
|
||||
|
||||
**Release Date:** March 13, 2026
|
||||
**Tag:** v0.7.1
|
||||
|
||||
### Chrome-Native Browser Control
|
||||
|
||||
v0.7.1 replaces Playwright with direct Chrome DevTools Protocol (CDP) integration. The GCU now launches the user's system Chrome via `open -n` on macOS, connects over CDP, and manages browser lifecycle end-to-end -- no extra browser binary required.
|
||||
|
||||
---
|
||||
|
||||
### Highlights
|
||||
|
||||
#### System Chrome via CDP
|
||||
|
||||
The entire GCU browser stack has been rewritten:
|
||||
|
||||
- **Chrome finder & launcher** -- New `chrome_finder.py` discovers installed Chrome and `chrome_launcher.py` manages process lifecycle with `--remote-debugging-port`
|
||||
- **Coexist with user's browser** -- `open -n` on macOS launches a separate Chrome instance so the user's tabs stay untouched
|
||||
- **Dynamic viewport sizing** -- Viewport auto-sizes to the available display area, suppressing Chrome warning bars
|
||||
- **Orphan cleanup** -- Chrome processes are killed on GCU server shutdown to prevent leaks
|
||||
- **`--no-startup-window`** -- Chrome launches headlessly by default until a page is needed
|
||||
|
||||
#### Per-Subagent Browser Isolation
|
||||
|
||||
Each GCU subagent gets its own Chrome user-data directory, preventing cookie/session cross-contamination:
|
||||
|
||||
- Unique browser profiles injected per subagent
|
||||
- Profiles cleaned up after top-level GCU node execution
|
||||
- Tab origin and age metadata tracked per subagent
|
||||
|
||||
#### Dummy Agent Testing Framework
|
||||
|
||||
A comprehensive test suite for validating agent graph patterns without LLM calls:
|
||||
|
||||
- 8 test modules covering echo, pipeline, branch, parallel merge, retry, feedback loop, worker, and GCU subagent patterns
|
||||
- Shared fixtures and a `run_all.py` runner for CI integration
|
||||
- Subagent lifecycle tests
|
||||
|
||||
---
|
||||
|
||||
### What's New
|
||||
|
||||
#### GCU Browser
|
||||
|
||||
- **Switch from Playwright to system Chrome via CDP** -- Direct CDP connection replaces Playwright dependency. (@bryanadenhq)
|
||||
- **Chrome finder and launcher modules** -- `chrome_finder.py` and `chrome_launcher.py` for cross-platform Chrome discovery and process management. (@bryanadenhq)
|
||||
- **Dynamic viewport sizing** -- Auto-size viewport and suppress Chrome warning bar. (@bryanadenhq)
|
||||
- **Per-subagent browser profile isolation** -- Unique user-data directories per subagent with cleanup. (@bryanadenhq)
|
||||
- **Tab origin/age metadata** -- Track which subagent opened each tab and when. (@bryanadenhq)
|
||||
- **`browser_close_all` tool** -- Bulk tab cleanup for agents managing many pages. (@bryanadenhq)
|
||||
- **Auto-track popup pages** -- Popups are automatically captured and tracked. (@bryanadenhq)
|
||||
- **Auto-snapshot from browser interactions** -- Browser interaction tools return screenshots automatically. (@bryanadenhq)
|
||||
- **Kill orphaned Chrome processes** -- GCU server shutdown cleans up lingering Chrome instances. (@bryanadenhq)
|
||||
- **`--no-startup-window` Chrome flag** -- Prevent empty window on launch. (@bryanadenhq)
|
||||
- **Launch Chrome via `open -n` on macOS** -- Coexist with the user's running browser. (@bryanadenhq)
|
||||
|
||||
#### Framework & Runtime
|
||||
|
||||
- **Session resume fix for new agents** -- Correctly resume sessions when a new agent is loaded. (@bryanadenhq)
|
||||
- **Queen upsert fix** -- Prevent duplicate queen entries on session restore. (@bryanadenhq)
|
||||
- **Anchor worker monitoring to queen's session ID on cold-restore** -- Worker monitors reconnect to the correct queen after restart. (@bryanadenhq)
|
||||
- **Update meta.json when loading workers** -- Worker metadata stays in sync with runtime state. (@RichardTang-Aden)
|
||||
- **Generate worker MCP file correctly** -- Fix MCP config generation for spawned workers. (@RichardTang-Aden)
|
||||
- **Share event bus so tool events are visible to parent** -- Tool execution events propagate up to parent graphs. (@bryanadenhq)
|
||||
- **Subagent activity tracking in queen status** -- Queen instructions include live subagent status. (@bryanadenhq)
|
||||
- **GCU system prompt updates** -- Auto-snapshots, batching, popup tracking, and close_all guidance. (@bryanadenhq)
|
||||
|
||||
#### Frontend
|
||||
|
||||
- **Loading spinner in draft panel** -- Shows spinner during planning phase instead of blank panel. (@bryanadenhq)
|
||||
- **Fix credential modal errors** -- Modal no longer eats errors; banner stays visible. (@bryanadenhq)
|
||||
- **Fix credentials_required loop** -- Stop clearing the flag on modal close to prevent infinite re-prompting. (@bryanadenhq)
|
||||
- **Fix "Add tab" dropdown overflow** -- Dropdown no longer hidden when many agents are open. (@prasoonmhwr)
|
||||
|
||||
#### Testing
|
||||
|
||||
- **Dummy agent test framework** -- 8 test modules (echo, pipeline, branch, parallel merge, retry, feedback loop, worker, GCU subagent) with shared fixtures and CI runner. (@bryanadenhq)
|
||||
- **Subagent lifecycle tests** -- Validate subagent spawn and completion flows. (@bryanadenhq)
|
||||
|
||||
#### Documentation & Infrastructure
|
||||
|
||||
- **MCP integration PRD** -- Product requirements for MCP server registry. (@TimothyZhang7)
|
||||
- **Skills registry PRD** -- Product requirements for skill registry system. (@bryanadenhq)
|
||||
- **Bounty program updates** -- Standard bounty issue template and updated contributor guide. (@bryanadenhq)
|
||||
- **Windows quickstart** -- Add default context limit for PowerShell setup. (@bryanadenhq)
|
||||
- **Remove deprecated files** -- Clean up `setup_mcp.py`, `verify_mcp.py`, `antigravity-setup.md`, and `setup-antigravity-mcp.sh`. (@bryanadenhq)
|
||||
|
||||
---
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- Fix credential modal eating errors and banner staying open
|
||||
- Stop clearing `credentials_required` on modal close to prevent infinite loop
|
||||
- Share event bus so tool events are visible to parent graph
|
||||
- Use lazy %-formatting in subagent completion log to avoid f-string in logger
|
||||
- Anchor worker monitoring to queen's session ID on cold-restore
|
||||
- Update meta.json when loading workers
|
||||
- Generate worker MCP file correctly
|
||||
- Fix "Add tab" dropdown partially hidden when creating multiple agents
|
||||
|
||||
---
|
||||
|
||||
### Community Contributors
|
||||
|
||||
- **Prasoon Mahawar** (@prasoonmhwr) -- Fix UI overflow on agent tab dropdown
|
||||
- **Richard Tang** (@RichardTang-Aden) -- Worker MCP generation and meta.json fixes
|
||||
|
||||
---
|
||||
|
||||
### Upgrading
|
||||
|
||||
```bash
|
||||
git pull origin main
|
||||
uv sync
|
||||
```
|
||||
|
||||
The Playwright dependency is no longer required for GCU browser operations. Chrome must be installed on the host system.
|
||||
|
||||
---
|
||||
|
||||
## v0.7.0
|
||||
|
||||
**Release Date:** March 5, 2026
|
||||
**Tag:** v0.7.0
|
||||
|
||||
Session management refactor release.
|
||||
|
||||
---
|
||||
|
||||
## v0.5.1
|
||||
|
||||
**Release Date:** February 18, 2026
|
||||
**Tag:** v0.5.1
|
||||
|
||||
## The Hive Gets a Brain
|
||||
### The Hive Gets a Brain
|
||||
|
||||
v0.5.1 is our most ambitious release yet. Hive agents can now **build other agents** -- the new Hive Coder meta-agent writes, tests, and fixes agent packages from natural language. The runtime grows multi-graph support so one session can orchestrate multiple agents simultaneously. The TUI gets a complete overhaul with an in-app agent picker, live streaming, and seamless escalation to the Coder. And we're now provider-agnostic: Claude Code subscriptions, OpenAI-compatible endpoints, and any LiteLLM-supported model work out of the box.
|
||||
|
||||
---
|
||||
|
||||
## Highlights
|
||||
### Highlights
|
||||
|
||||
### Hive Coder -- The Agent That Builds Agents
|
||||
#### Hive Coder -- The Agent That Builds Agents
|
||||
|
||||
A native meta-agent that lives inside the framework at `core/framework/agents/hive_coder/`. Give it a natural-language specification and it produces a complete agent package -- goal definition, node prompts, edge routing, MCP tool wiring, tests, and all boilerplate files.
|
||||
|
||||
@@ -30,7 +162,7 @@ The Coder ships with:
|
||||
- **Coder Tools MCP server** -- file I/O, fuzzy-match editing, git snapshots, and sandboxed shell execution (`tools/coder_tools_server.py`)
|
||||
- **Test generation** -- structural tests for forever-alive agents that don't hang on `runner.run()`
|
||||
|
||||
### Multi-Graph Agent Runtime
|
||||
#### Multi-Graph Agent Runtime
|
||||
|
||||
`AgentRuntime` now supports loading, managing, and switching between multiple agent graphs within a single session. Six new lifecycle tools give agents (and the TUI) full control:
|
||||
|
||||
@@ -44,7 +176,7 @@ await runtime.add_graph("exports/deep_research_agent")
|
||||
|
||||
The Hive Coder uses multi-graph internally -- when you escalate from a worker agent, the Coder loads as a separate graph while the worker stays alive in the background.
|
||||
|
||||
### TUI Revamp
|
||||
#### TUI Revamp
|
||||
|
||||
The Terminal UI gets a ground-up rebuild with five major additions:
|
||||
|
||||
@@ -54,7 +186,7 @@ The Terminal UI gets a ground-up rebuild with five major additions:
|
||||
- **PDF attachments** -- `/attach` and `/detach` commands with native OS file dialog (macOS, Linux, Windows)
|
||||
- **Multi-graph commands** -- `/graphs`, `/graph <id>`, `/load <path>`, `/unload <id>` for managing agent graphs in-session
|
||||
|
||||
### Provider-Agnostic LLM Support
|
||||
#### Provider-Agnostic LLM Support
|
||||
|
||||
Hive is no longer Anthropic-only. v0.5.1 adds first-class support for:
|
||||
|
||||
@@ -66,9 +198,9 @@ The quickstart script auto-detects Claude Code subscriptions and ZAI Code instal
|
||||
|
||||
---
|
||||
|
||||
## What's New
|
||||
### What's New
|
||||
|
||||
### Architecture & Runtime
|
||||
#### Architecture & Runtime
|
||||
|
||||
- **Hive Coder meta-agent** -- Natural-language agent builder with reference docs, guardian watchdog, and `hive code` CLI command. (@TimothyZhang7)
|
||||
- **Multi-graph agent sessions** -- `add_graph`/`remove_graph` on AgentRuntime with 6 lifecycle tools (`load_agent`, `unload_agent`, `start_agent`, `restart_agent`, `list_agents`, `get_user_presence`). (@TimothyZhang7)
|
||||
@@ -79,7 +211,7 @@ The quickstart script auto-detects Claude Code subscriptions and ZAI Code instal
|
||||
- **Pre-start confirmation prompt** -- Interactive prompt before agent execution allowing credential updates or abort. (@RichardTang-Aden)
|
||||
- **Event bus multi-graph support** -- `graph_id` on events, `filter_graph` on subscriptions, `ESCALATION_REQUESTED` event type, `exclude_own_graph` filter. (@TimothyZhang7)
|
||||
|
||||
### TUI Improvements
|
||||
#### TUI Improvements
|
||||
|
||||
- **In-app agent picker** (Ctrl+A) -- Tabbed modal for browsing agents with metadata badges (nodes, tools, sessions, tags). (@TimothyZhang7)
|
||||
- **Runtime-optional TUI startup** -- Launches without a pre-loaded agent, shows agent picker on startup. (@TimothyZhang7)
|
||||
@@ -89,7 +221,7 @@ The quickstart script auto-detects Claude Code subscriptions and ZAI Code instal
|
||||
- **Multi-graph TUI commands** -- `/graphs`, `/graph <id>`, `/load <path>`, `/unload <id>`. (@TimothyZhang7)
|
||||
- **Agent Guardian watchdog** -- Event-driven monitor that catches secondary agent failures and triggers automatic remediation, with `--no-guardian` CLI flag. (@TimothyZhang7)
|
||||
|
||||
### New Tool Integrations
|
||||
#### New Tool Integrations
|
||||
|
||||
| Tool | Description | Contributor |
|
||||
| ---------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------ |
|
||||
@@ -99,7 +231,7 @@ The quickstart script auto-detects Claude Code subscriptions and ZAI Code instal
|
||||
| **Google Docs** | Document creation, reading, and editing with OAuth credential support | @haliaeetusvocifer |
|
||||
| **Gmail enhancements** | Expanded mail operations for inbox management | @bryanadenhq |
|
||||
|
||||
### Infrastructure
|
||||
#### Infrastructure
|
||||
|
||||
- **Default node type → `event_loop`** -- `NodeSpec.node_type` defaults to `"event_loop"` instead of `"llm_tool_use"`. (@TimothyZhang7)
|
||||
- **Default `max_node_visits` → 0 (unlimited)** -- Nodes default to unlimited visits, reducing friction for feedback loops and forever-alive agents. (@TimothyZhang7)
|
||||
@@ -112,7 +244,7 @@ The quickstart script auto-detects Claude Code subscriptions and ZAI Code instal
|
||||
|
||||
---
|
||||
|
||||
## Bug Fixes
|
||||
### Bug Fixes
|
||||
|
||||
- Flush WIP accumulator outputs on cancel/failure so edge conditions see correct values on resume
|
||||
- Stall detection state preserved across resume (no more resets on checkpoint restore)
|
||||
@@ -125,13 +257,13 @@ The quickstart script auto-detects Claude Code subscriptions and ZAI Code instal
|
||||
- Fix email agent version conflicts (@RichardTang-Aden)
|
||||
- Fix coder tool timeouts (120s for tests, 300s cap for commands)
|
||||
|
||||
## Documentation
|
||||
### Documentation
|
||||
|
||||
- Clarify installation and prevent root pip install misuse (@paarths-collab)
|
||||
|
||||
---
|
||||
|
||||
## Agent Updates
|
||||
### Agent Updates
|
||||
|
||||
- **Email Inbox Management** -- Consolidate `gmail_inbox_guardian` and `inbox_management` into a single unified agent with updated prompts and config. (@RichardTang-Aden, @bryanadenhq)
|
||||
- **Job Hunter** -- Updated node prompts, config, and agent metadata; added PDF resume selection. (@bryanadenhq)
|
||||
@@ -141,7 +273,7 @@ The quickstart script auto-detects Claude Code subscriptions and ZAI Code instal
|
||||
|
||||
---
|
||||
|
||||
## Breaking Changes
|
||||
### Breaking Changes
|
||||
|
||||
- **Deprecated node types raise `RuntimeError`** -- `llm_tool_use`, `llm_generate`, `function`, `router`, `human_input` now fail instead of warning. Migrate to `event_loop`.
|
||||
- **`NodeSpec.node_type` defaults to `"event_loop"`** (was `"llm_tool_use"`)
|
||||
@@ -150,7 +282,7 @@ The quickstart script auto-detects Claude Code subscriptions and ZAI Code instal
|
||||
|
||||
---
|
||||
|
||||
## Community Contributors
|
||||
### Community Contributors
|
||||
|
||||
A huge thank you to everyone who contributed to this release:
|
||||
|
||||
@@ -165,14 +297,14 @@ A huge thank you to everyone who contributed to this release:
|
||||
|
||||
---
|
||||
|
||||
## Upgrading
|
||||
### Upgrading
|
||||
|
||||
```bash
|
||||
git pull origin main
|
||||
uv sync
|
||||
```
|
||||
|
||||
### Migration Guide
|
||||
#### Migration Guide
|
||||
|
||||
If your agents use deprecated node types, update them:
|
||||
|
||||
@@ -196,12 +328,3 @@ hive code
|
||||
# Or from TUI -- press Ctrl+E to escalate
|
||||
hive tui
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## What's Next
|
||||
|
||||
- **Agent-to-agent communication** -- one agent's output triggers another agent's entry point
|
||||
- **Cost visibility** -- detailed runtime log of LLM costs per node and per session
|
||||
- **Persistent webhook subscriptions** -- survive agent restarts without re-registering
|
||||
- **Remote agent deployment** -- run agents as long-lived services with HTTP APIs
|
||||
|
||||
+8
-3
@@ -121,9 +121,15 @@ uv sync
|
||||
6. Make your changes
|
||||
7. Run checks and tests:
|
||||
```bash
|
||||
make check # Lint and format checks (ruff check + ruff format --check)
|
||||
make check # Lint and format checks
|
||||
make test # Core tests
|
||||
```
|
||||
On Windows (no make), run directly:
|
||||
```powershell
|
||||
uv run ruff check core/ tools/
|
||||
uv run ruff format --check core/ tools/
|
||||
uv run pytest core/tests/
|
||||
```
|
||||
8. Commit your changes following our commit conventions
|
||||
9. Push to your fork and submit a Pull Request
|
||||
|
||||
@@ -222,8 +228,7 @@ else: # linux
|
||||
- **Node.js 18+** (optional, for frontend development)
|
||||
|
||||
> **Windows Users:**
|
||||
> If you are on native Windows, it is recommended to use **WSL (Windows Subsystem for Linux)**.
|
||||
> Alternatively, make sure to run PowerShell or Git Bash with Python 3.11+ installed, and disable "App Execution Aliases" in Windows settings.
|
||||
> Native Windows is supported. Use `.\quickstart.ps1` for setup and `.\hive.ps1` to run (PowerShell 5.1+). Disable "App Execution Aliases" in Windows settings to avoid Python path conflicts. WSL is also an option but not required.
|
||||
|
||||
> **Tip:** Installing Claude Code skills is optional for running existing agents, but required if you plan to **build new agents**.
|
||||
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
.PHONY: lint format check test install-hooks help frontend-install frontend-dev frontend-build
|
||||
.PHONY: lint format check test test-tools test-live test-all install-hooks help frontend-install frontend-dev frontend-build
|
||||
|
||||
# ── Ensure uv is findable in Git Bash on Windows ──────────────────────────────
|
||||
# uv installs to ~/.local/bin on Windows/Linux/macOS. Git Bash may not include
|
||||
# this in PATH by default, so we prepend it here.
|
||||
export PATH := $(HOME)/.local/bin:$(PATH)
|
||||
|
||||
# ── Targets ───────────────────────────────────────────────────────────────────
|
||||
|
||||
help: ## Show this help
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | \
|
||||
@@ -46,4 +53,4 @@ frontend-dev: ## Start frontend dev server
|
||||
cd core/frontend && npm run dev
|
||||
|
||||
frontend-build: ## Build frontend for production
|
||||
cd core/frontend && npm run build
|
||||
cd core/frontend && npm run build
|
||||
@@ -27,7 +27,7 @@
|
||||
<img src="https://img.shields.io/badge/Multi--Agent-Systems-blue?style=flat-square" alt="Multi-Agent" />
|
||||
<img src="https://img.shields.io/badge/Headless-Development-purple?style=flat-square" alt="Headless" />
|
||||
<img src="https://img.shields.io/badge/Human--in--the--Loop-orange?style=flat-square" alt="HITL" />
|
||||
<img src="https://img.shields.io/badge/Production--Ready-red?style=flat-square" alt="Production" />
|
||||
<img src="https://img.shields.io/badge/Browser-Use-red?style=flat-square" alt="Browser Use" />
|
||||
</p>
|
||||
<p align="center">
|
||||
<img src="https://img.shields.io/badge/OpenAI-supported-412991?style=flat-square&logo=openai" alt="OpenAI" />
|
||||
@@ -37,15 +37,16 @@
|
||||
|
||||
## Overview
|
||||
|
||||
Build autonomous, reliable, self-improving AI agents without hardcoding workflows. Define your goal through conversation with hive coding agent(queen), and the framework generates a node graph with dynamically created connection code. When things break, the framework captures failure data, evolves the agent through the coding agent, and redeploys. Built-in human-in-the-loop nodes, credential management, and real-time monitoring give you control without sacrificing adaptability.
|
||||
Generate a swarm of worker agents with a coding agent(queen) that control them. Define your goal through conversation with hive queen, and the framework generates a node graph with dynamically created connection code. When things break, the framework captures failure data, evolves the agent through the coding agent, and redeploys. Built-in human-in-the-loop nodes, browser use, credential management, and real-time monitoring give you control without sacrificing adaptability.
|
||||
|
||||
Visit [adenhq.com](https://adenhq.com) for complete documentation, examples, and guides.
|
||||
|
||||
[](https://www.youtube.com/watch?v=XDOG9fOaLjU)
|
||||
https://github.com/user-attachments/assets/aad3a035-e7b3-4cac-b13d-4a83c7002c30
|
||||
|
||||
|
||||
## Who Is Hive For?
|
||||
|
||||
Hive is designed for developers and teams who want to build **production-grade AI agents** without manually wiring complex workflows.
|
||||
Hive is designed for developers and teams who want to build many **autonomous AI agents** fast without manually wiring complex workflows.
|
||||
|
||||
Hive is a good fit if you:
|
||||
|
||||
@@ -84,7 +85,7 @@ Use Hive when you need:
|
||||
- An LLM provider that powers the agents
|
||||
- **ripgrep (optional, recommended on Windows):** The `search_files` tool uses ripgrep for faster file search. If not installed, a Python fallback is used. On Windows: `winget install BurntSushi.ripgrep` or `scoop install ripgrep`
|
||||
|
||||
> **Note for Windows Users:** It is strongly recommended to use **WSL (Windows Subsystem for Linux)** or **Git Bash** to run this framework. Some core automation scripts may not execute correctly in standard Command Prompt or PowerShell.
|
||||
> **Windows Users:** Native Windows is supported via `quickstart.ps1` and `hive.ps1`. Run these in PowerShell 5.1+. WSL is also an option but not required.
|
||||
|
||||
### Installation
|
||||
|
||||
@@ -115,11 +116,9 @@ This sets up:
|
||||
|
||||
> **Tip:** To reopen the dashboard later, run `hive open` from the project directory.
|
||||
|
||||
<img width="2500" height="1214" alt="home-screen" src="https://github.com/user-attachments/assets/134d897f-5e75-4874-b00b-e0505f6b45c4" />
|
||||
|
||||
### Build Your First Agent
|
||||
|
||||
Type the agent you want to build in the home input box
|
||||
Type the agent you want to build in the home input box. The queen is going to ask you questions and work out a solution with you.
|
||||
|
||||
<img width="2500" height="1214" alt="Image" src="https://github.com/user-attachments/assets/1ce19141-a78b-46f5-8d64-dbf987e048f4" />
|
||||
|
||||
@@ -131,7 +130,7 @@ Click "Try a sample agent" and check the templates. You can run a template direc
|
||||
|
||||
Now you can run an agent by selecting the agent (either an existing agent or example agent). You can click the Run button on the top left, or talk to the queen agent and it can run the agent for you.
|
||||
|
||||
<img width="2500" height="1214" alt="Image" src="https://github.com/user-attachments/assets/71c38206-2ad5-49aa-bde8-6698d0bc55f5" />
|
||||
<img width="2549" height="1174" alt="Screenshot 2026-03-12 at 9 27 36 PM" src="https://github.com/user-attachments/assets/7c7d30fa-9ceb-4c23-95af-b1caa405547d" />
|
||||
|
||||
## Features
|
||||
|
||||
@@ -143,7 +142,6 @@ Now you can run an agent by selecting the agent (either an existing agent or exa
|
||||
- **SDK-Wrapped Nodes** - Every node gets shared memory, local RLM memory, monitoring, tools, and LLM access out of the box
|
||||
- **[Human-in-the-Loop](docs/key_concepts/graph.md#human-in-the-loop)** - Intervention nodes that pause execution for human input with configurable timeouts and escalation
|
||||
- **Real-time Observability** - WebSocket streaming for live monitoring of agent execution, decisions, and node-to-node communication
|
||||
- **Production-Ready** - Self-hostable, built for scale and reliability
|
||||
|
||||
## Integration
|
||||
|
||||
@@ -392,10 +390,6 @@ Hive generates your entire agent system from natural language goals using a codi
|
||||
|
||||
Yes, Hive is fully open-source under the Apache License 2.0. We actively encourage community contributions and collaboration.
|
||||
|
||||
**Q: Can Hive handle complex, production-scale use cases?**
|
||||
|
||||
Yes. Hive is explicitly designed for production environments with features like automatic failure recovery, real-time observability, cost controls, and horizontal scaling support. The framework handles both simple automations and complex multi-agent workflows.
|
||||
|
||||
**Q: Does Hive support human-in-the-loop workflows?**
|
||||
|
||||
Yes, Hive fully supports [human-in-the-loop](docs/key_concepts/graph.md#human-in-the-loop) workflows through intervention nodes that pause execution for human input. These include configurable timeouts and escalation policies, allowing seamless collaboration between human experts and AI agents.
|
||||
@@ -420,6 +414,16 @@ Visit [docs.adenhq.com](https://docs.adenhq.com/) for complete guides, API refer
|
||||
|
||||
Contributions are welcome! Fork the repository, create your feature branch, implement your changes, and submit a pull request. See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed guidelines.
|
||||
|
||||
## Star History
|
||||
|
||||
<a href="https://star-history.com/#aden-hive/hive&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=aden-hive/hive&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=aden-hive/hive&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=aden-hive/hive&type=Date" />
|
||||
</picture>
|
||||
</a>
|
||||
|
||||
---
|
||||
|
||||
<p align="center">
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
perf: reduce subprocess spawning in quickstart scripts (#4427)
|
||||
|
||||
## Problem
|
||||
Windows process creation (CreateProcess) is 10-100x slower than Linux fork/exec.
|
||||
The quickstart scripts were spawning 4+ separate `uv run python -c "import X"`
|
||||
processes to verify imports, adding ~600ms overhead on Windows.
|
||||
|
||||
## Solution
|
||||
Consolidated all import checks into a single batch script that checks multiple
|
||||
modules in one subprocess call, reducing spawn overhead by ~75%.
|
||||
|
||||
## Changes
|
||||
- **New**: `scripts/check_requirements.py` - Batched import checker
|
||||
- **New**: `scripts/test_check_requirements.py` - Test suite
|
||||
- **New**: `scripts/benchmark_quickstart.ps1` - Performance benchmark tool
|
||||
- **Modified**: `quickstart.ps1` - Updated import verification (2 sections)
|
||||
- **Modified**: `quickstart.sh` - Updated import verification
|
||||
|
||||
## Performance Impact
|
||||
**Benchmark results on Windows:**
|
||||
- Before: ~19.8 seconds for import checks
|
||||
- After: ~4.9 seconds for import checks
|
||||
- **Improvement: 14.9 seconds saved (75.2% faster)**
|
||||
|
||||
## Testing
|
||||
- ✅ All functional tests pass (`scripts/test_check_requirements.py`)
|
||||
- ✅ Quickstart scripts work correctly on Windows
|
||||
- ✅ Error handling verified (invalid imports reported correctly)
|
||||
- ✅ Performance benchmark confirms 75%+ improvement
|
||||
|
||||
Fixes #4427
|
||||
@@ -1,27 +0,0 @@
|
||||
# Identity mapping: GitHub username -> Discord ID
|
||||
#
|
||||
# This file links GitHub accounts to Discord accounts for the
|
||||
# Integration Bounty Program. When a bounty PR is merged, the
|
||||
# GitHub Action uses this file to ping the contributor on Discord.
|
||||
#
|
||||
# HOW TO ADD YOURSELF:
|
||||
# Open a "Link Discord Account" issue:
|
||||
# https://github.com/aden-hive/hive/issues/new?template=link-discord.yml
|
||||
# A GitHub Action will automatically add your entry here.
|
||||
#
|
||||
# To find your Discord ID:
|
||||
# 1. Open Discord Settings > Advanced > Enable Developer Mode
|
||||
# 2. Right-click your name > Copy User ID
|
||||
#
|
||||
# Format:
|
||||
# - github: your-github-username
|
||||
# discord: "your-discord-id" # quotes required (it's a number)
|
||||
# name: Your Display Name # optional
|
||||
|
||||
contributors:
|
||||
# - github: example-user
|
||||
# discord: "123456789012345678"
|
||||
# name: Example User
|
||||
- github: TimothyZhang7
|
||||
discord: "408460790061072384"
|
||||
name: Timothy@Aden
|
||||
@@ -1,740 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
EventLoopNode WebSocket Demo
|
||||
|
||||
Real LLM, real FileConversationStore, real EventBus.
|
||||
Streams EventLoopNode execution to a browser via WebSocket.
|
||||
|
||||
Usage:
|
||||
cd /home/timothy/oss/hive/core
|
||||
python demos/event_loop_wss_demo.py
|
||||
|
||||
Then open http://localhost:8765 in your browser.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import tempfile
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
from bs4 import BeautifulSoup
|
||||
from websockets.http11 import Request, Response
|
||||
|
||||
# Add core, tools, and hive root to path
|
||||
_CORE_DIR = Path(__file__).resolve().parent.parent
|
||||
_HIVE_DIR = _CORE_DIR.parent
|
||||
sys.path.insert(0, str(_CORE_DIR)) # framework.*
|
||||
sys.path.insert(0, str(_HIVE_DIR / "tools" / "src")) # aden_tools.*
|
||||
sys.path.insert(0, str(_HIVE_DIR)) # core.framework.* (for aden_tools imports)
|
||||
|
||||
import os # noqa: E402
|
||||
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS, CredentialStoreAdapter # noqa: E402
|
||||
from core.framework.credentials import CredentialStore # noqa: E402
|
||||
|
||||
from framework.credentials.storage import ( # noqa: E402
|
||||
CompositeStorage,
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
)
|
||||
from framework.graph.event_loop_node import EventLoopNode, LoopConfig # noqa: E402
|
||||
from framework.graph.node import NodeContext, NodeSpec, SharedMemory # noqa: E402
|
||||
from framework.llm.litellm import LiteLLMProvider # noqa: E402
|
||||
from framework.llm.provider import Tool # noqa: E402
|
||||
from framework.runner.tool_registry import ToolRegistry # noqa: E402
|
||||
from framework.runtime.core import Runtime # noqa: E402
|
||||
from framework.runtime.event_bus import EventBus, EventType # noqa: E402
|
||||
from framework.storage.conversation_store import FileConversationStore # noqa: E402
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
|
||||
logger = logging.getLogger("demo")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Persistent state (shared across WebSocket connections)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
STORE_DIR = Path(tempfile.mkdtemp(prefix="hive_demo_"))
|
||||
STORE = FileConversationStore(STORE_DIR / "conversation")
|
||||
RUNTIME = Runtime(STORE_DIR / "runtime")
|
||||
LLM = LiteLLMProvider(model="claude-sonnet-4-5-20250929")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Tool Registry — real tools via ToolRegistry (same pattern as GraphExecutor)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
TOOL_REGISTRY = ToolRegistry()
|
||||
|
||||
# Credential store: Aden sync (OAuth2 tokens) + encrypted files + env var fallback
|
||||
_env_mapping = {name: spec.env_var for name, spec in CREDENTIAL_SPECS.items()}
|
||||
_local_storage = CompositeStorage(
|
||||
primary=EncryptedFileStorage(),
|
||||
fallbacks=[EnvVarStorage(env_mapping=_env_mapping)],
|
||||
)
|
||||
|
||||
if os.environ.get("ADEN_API_KEY"):
|
||||
try:
|
||||
from framework.credentials.aden import ( # noqa: E402
|
||||
AdenCachedStorage,
|
||||
AdenClientConfig,
|
||||
AdenCredentialClient,
|
||||
AdenSyncProvider,
|
||||
)
|
||||
|
||||
_client = AdenCredentialClient(AdenClientConfig(base_url="https://api.adenhq.com"))
|
||||
_provider = AdenSyncProvider(client=_client)
|
||||
_storage = AdenCachedStorage(
|
||||
local_storage=_local_storage,
|
||||
aden_provider=_provider,
|
||||
)
|
||||
_cred_store = CredentialStore(storage=_storage, providers=[_provider], auto_refresh=True)
|
||||
_synced = _provider.sync_all(_cred_store)
|
||||
logger.info("Synced %d credentials from Aden", _synced)
|
||||
except Exception as e:
|
||||
logger.warning("Aden sync unavailable: %s", e)
|
||||
_cred_store = CredentialStore(storage=_local_storage)
|
||||
else:
|
||||
logger.info("ADEN_API_KEY not set, using local credential storage")
|
||||
_cred_store = CredentialStore(storage=_local_storage)
|
||||
|
||||
CREDENTIALS = CredentialStoreAdapter(_cred_store)
|
||||
|
||||
# Debug: log which credentials resolved
|
||||
for _name in ["brave_search", "hubspot", "anthropic"]:
|
||||
_val = CREDENTIALS.get(_name)
|
||||
if _val:
|
||||
logger.debug("credential %s: OK (len=%d)", _name, len(_val))
|
||||
else:
|
||||
logger.debug("credential %s: not found", _name)
|
||||
|
||||
# --- web_search (Brave Search API) ---
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_search",
|
||||
tool=Tool(
|
||||
name="web_search",
|
||||
description=(
|
||||
"Search the web for current information. "
|
||||
"Returns titles, URLs, and snippets from search results."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query (1-500 characters)",
|
||||
},
|
||||
"num_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return (1-20, default 10)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_search(inputs),
|
||||
)
|
||||
|
||||
|
||||
def _exec_web_search(inputs: dict) -> dict:
|
||||
api_key = CREDENTIALS.get("brave_search")
|
||||
if not api_key:
|
||||
return {"error": "brave_search credential not configured"}
|
||||
query = inputs.get("query", "")
|
||||
num_results = min(inputs.get("num_results", 10), 20)
|
||||
resp = httpx.get(
|
||||
"https://api.search.brave.com/res/v1/web/search",
|
||||
params={"q": query, "count": num_results},
|
||||
headers={"X-Subscription-Token": api_key, "Accept": "application/json"},
|
||||
timeout=30.0,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"Brave API HTTP {resp.status_code}"}
|
||||
data = resp.json()
|
||||
results = [
|
||||
{
|
||||
"title": item.get("title", ""),
|
||||
"url": item.get("url", ""),
|
||||
"snippet": item.get("description", ""),
|
||||
}
|
||||
for item in data.get("web", {}).get("results", [])[:num_results]
|
||||
]
|
||||
return {"query": query, "results": results, "total": len(results)}
|
||||
|
||||
|
||||
# --- web_scrape (httpx + BeautifulSoup, no playwright for sync compat) ---
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_scrape",
|
||||
tool=Tool(
|
||||
name="web_scrape",
|
||||
description=(
|
||||
"Scrape and extract text content from a webpage URL. "
|
||||
"Returns the page title and main text content."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL of the webpage to scrape",
|
||||
},
|
||||
"max_length": {
|
||||
"type": "integer",
|
||||
"description": "Maximum text length (default 50000)",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_scrape(inputs),
|
||||
)
|
||||
|
||||
_SCRAPE_HEADERS = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/131.0.0.0 Safari/537.36"
|
||||
),
|
||||
"Accept": "text/html,application/xhtml+xml",
|
||||
}
|
||||
|
||||
|
||||
def _exec_web_scrape(inputs: dict) -> dict:
|
||||
url = inputs.get("url", "")
|
||||
max_length = max(1000, min(inputs.get("max_length", 50000), 500000))
|
||||
if not url.startswith(("http://", "https://")):
|
||||
url = "https://" + url
|
||||
try:
|
||||
resp = httpx.get(url, timeout=30.0, follow_redirects=True, headers=_SCRAPE_HEADERS)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"HTTP {resp.status_code}"}
|
||||
soup = BeautifulSoup(resp.text, "html.parser")
|
||||
for tag in soup(["script", "style", "nav", "footer", "header", "aside", "noscript"]):
|
||||
tag.decompose()
|
||||
title = soup.title.get_text(strip=True) if soup.title else ""
|
||||
main = (
|
||||
soup.find("article")
|
||||
or soup.find("main")
|
||||
or soup.find(attrs={"role": "main"})
|
||||
or soup.find("body")
|
||||
)
|
||||
text = main.get_text(separator=" ", strip=True) if main else ""
|
||||
text = " ".join(text.split())
|
||||
if len(text) > max_length:
|
||||
text = text[:max_length] + "..."
|
||||
return {"url": url, "title": title, "content": text, "length": len(text)}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except Exception as e:
|
||||
return {"error": f"Scrape failed: {e}"}
|
||||
|
||||
|
||||
# --- HubSpot CRM tools (optional, requires HUBSPOT_ACCESS_TOKEN) ---
|
||||
|
||||
_HUBSPOT_API = "https://api.hubapi.com"
|
||||
|
||||
|
||||
def _hubspot_headers() -> dict | None:
|
||||
token = CREDENTIALS.get("hubspot")
|
||||
if token:
|
||||
logger.debug("HubSpot token: %s...%s (len=%d)", token[:8], token[-4:], len(token))
|
||||
else:
|
||||
logger.debug("HubSpot token: not found")
|
||||
if not token:
|
||||
return None
|
||||
return {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
|
||||
def _exec_hubspot_search(inputs: dict) -> dict:
|
||||
headers = _hubspot_headers()
|
||||
if not headers:
|
||||
return {"error": "HUBSPOT_ACCESS_TOKEN not set"}
|
||||
object_type = inputs.get("object_type", "contacts")
|
||||
query = inputs.get("query", "")
|
||||
limit = min(inputs.get("limit", 10), 100)
|
||||
body: dict = {"limit": limit}
|
||||
if query:
|
||||
body["query"] = query
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{_HUBSPOT_API}/crm/v3/objects/{object_type}/search",
|
||||
headers=headers,
|
||||
json=body,
|
||||
timeout=30.0,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"HubSpot API HTTP {resp.status_code}: {resp.text[:200]}"}
|
||||
return resp.json()
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except Exception as e:
|
||||
return {"error": f"HubSpot error: {e}"}
|
||||
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="hubspot_search",
|
||||
tool=Tool(
|
||||
name="hubspot_search",
|
||||
description=(
|
||||
"Search HubSpot CRM objects (contacts, companies, or deals). "
|
||||
"Returns matching records with their properties."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"object_type": {
|
||||
"type": "string",
|
||||
"description": "CRM object type: 'contacts', 'companies', or 'deals'",
|
||||
},
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query (name, email, domain, etc.)",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max results (1-100, default 10)",
|
||||
},
|
||||
},
|
||||
"required": ["object_type"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_hubspot_search(inputs),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"ToolRegistry loaded: %s",
|
||||
", ".join(TOOL_REGISTRY.get_registered_names()),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTML page (embedded)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
HTML_PAGE = ( # noqa: E501
|
||||
"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>EventLoopNode Live Demo</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
body {
|
||||
font-family: 'SF Mono', 'Fira Code', monospace;
|
||||
background: #0d1117; color: #c9d1d9;
|
||||
height: 100vh; display: flex; flex-direction: column;
|
||||
}
|
||||
header {
|
||||
background: #161b22; padding: 12px 20px;
|
||||
border-bottom: 1px solid #30363d;
|
||||
display: flex; align-items: center; gap: 16px;
|
||||
}
|
||||
header h1 { font-size: 16px; color: #58a6ff; font-weight: 600; }
|
||||
.status {
|
||||
font-size: 12px; padding: 3px 10px; border-radius: 12px;
|
||||
background: #21262d; color: #8b949e;
|
||||
}
|
||||
.status.running { background: #1a4b2e; color: #3fb950; }
|
||||
.status.done { background: #1a3a5c; color: #58a6ff; }
|
||||
.status.error { background: #4b1a1a; color: #f85149; }
|
||||
.chat { flex: 1; overflow-y: auto; padding: 16px; }
|
||||
.msg {
|
||||
margin: 8px 0; padding: 10px 14px; border-radius: 8px;
|
||||
line-height: 1.6; white-space: pre-wrap; word-wrap: break-word;
|
||||
}
|
||||
.msg.user { background: #1a3a5c; color: #58a6ff; }
|
||||
.msg.assistant { background: #161b22; color: #c9d1d9; }
|
||||
.msg.event {
|
||||
background: transparent; color: #8b949e; font-size: 11px;
|
||||
padding: 4px 14px; border-left: 3px solid #30363d;
|
||||
}
|
||||
.msg.event.loop { border-left-color: #58a6ff; }
|
||||
.msg.event.tool { border-left-color: #d29922; }
|
||||
.msg.event.stall { border-left-color: #f85149; }
|
||||
.input-bar {
|
||||
padding: 12px 16px; background: #161b22;
|
||||
border-top: 1px solid #30363d; display: flex; gap: 8px;
|
||||
}
|
||||
.input-bar input {
|
||||
flex: 1; background: #0d1117; border: 1px solid #30363d;
|
||||
color: #c9d1d9; padding: 8px 12px; border-radius: 6px;
|
||||
font-family: inherit; font-size: 14px; outline: none;
|
||||
}
|
||||
.input-bar input:focus { border-color: #58a6ff; }
|
||||
.input-bar button {
|
||||
background: #238636; color: #fff; border: none;
|
||||
padding: 8px 20px; border-radius: 6px; cursor: pointer;
|
||||
font-family: inherit; font-weight: 600;
|
||||
}
|
||||
.input-bar button:hover { background: #2ea043; }
|
||||
.input-bar button:disabled {
|
||||
background: #21262d; color: #484f58; cursor: not-allowed;
|
||||
}
|
||||
.input-bar button.clear { background: #da3633; }
|
||||
.input-bar button.clear:hover { background: #f85149; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>EventLoopNode Live</h1>
|
||||
<span id="status" class="status">Idle</span>
|
||||
<span id="iter" class="status" style="display:none">Step 0</span>
|
||||
</header>
|
||||
<div id="chat" class="chat"></div>
|
||||
<div class="input-bar">
|
||||
<input id="input" type="text"
|
||||
placeholder="Ask anything..." autofocus />
|
||||
<button id="go" onclick="run()">Send</button>
|
||||
<button class="clear"
|
||||
onclick="clearConversation()">Clear</button>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let ws = null;
|
||||
let currentAssistantEl = null;
|
||||
let iterCount = 0;
|
||||
const chat = document.getElementById('chat');
|
||||
const status = document.getElementById('status');
|
||||
const iterEl = document.getElementById('iter');
|
||||
const goBtn = document.getElementById('go');
|
||||
const inputEl = document.getElementById('input');
|
||||
|
||||
inputEl.addEventListener('keydown', e => {
|
||||
if (e.key === 'Enter') run();
|
||||
});
|
||||
|
||||
function setStatus(text, cls) {
|
||||
status.textContent = text;
|
||||
status.className = 'status ' + cls;
|
||||
}
|
||||
|
||||
function addMsg(text, cls) {
|
||||
const el = document.createElement('div');
|
||||
el.className = 'msg ' + cls;
|
||||
el.textContent = text;
|
||||
chat.appendChild(el);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
return el;
|
||||
}
|
||||
|
||||
function connect() {
|
||||
ws = new WebSocket('ws://' + location.host + '/ws');
|
||||
ws.onopen = () => {
|
||||
setStatus('Ready', 'done');
|
||||
goBtn.disabled = false;
|
||||
};
|
||||
ws.onmessage = handleEvent;
|
||||
ws.onerror = () => { setStatus('Error', 'error'); };
|
||||
ws.onclose = () => {
|
||||
setStatus('Reconnecting...', '');
|
||||
goBtn.disabled = true;
|
||||
setTimeout(connect, 2000);
|
||||
};
|
||||
}
|
||||
|
||||
function handleEvent(msg) {
|
||||
const evt = JSON.parse(msg.data);
|
||||
|
||||
if (evt.type === 'llm_text_delta') {
|
||||
if (currentAssistantEl) {
|
||||
currentAssistantEl.textContent += evt.content;
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
}
|
||||
else if (evt.type === 'ready') {
|
||||
setStatus('Ready', 'done');
|
||||
if (currentAssistantEl && !currentAssistantEl.textContent)
|
||||
currentAssistantEl.remove();
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
else if (evt.type === 'node_loop_iteration') {
|
||||
iterCount = evt.iteration || (iterCount + 1);
|
||||
iterEl.textContent = 'Step ' + iterCount;
|
||||
iterEl.style.display = '';
|
||||
}
|
||||
else if (evt.type === 'tool_call_started') {
|
||||
var info = evt.tool_name + '('
|
||||
+ JSON.stringify(evt.tool_input).slice(0, 120) + ')';
|
||||
addMsg('TOOL ' + info, 'event tool');
|
||||
}
|
||||
else if (evt.type === 'tool_call_completed') {
|
||||
var preview = (evt.result || '').slice(0, 200);
|
||||
var cls = evt.is_error ? 'stall' : 'tool';
|
||||
addMsg('RESULT ' + evt.tool_name + ': ' + preview,
|
||||
'event ' + cls);
|
||||
currentAssistantEl = addMsg('', 'assistant');
|
||||
}
|
||||
else if (evt.type === 'result') {
|
||||
setStatus('Session ended', evt.success ? 'done' : 'error');
|
||||
if (evt.error) addMsg('ERROR ' + evt.error, 'event stall');
|
||||
if (currentAssistantEl && !currentAssistantEl.textContent)
|
||||
currentAssistantEl.remove();
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
else if (evt.type === 'node_stalled') {
|
||||
addMsg('STALLED ' + evt.reason, 'event stall');
|
||||
}
|
||||
else if (evt.type === 'cleared') {
|
||||
chat.innerHTML = '';
|
||||
iterCount = 0;
|
||||
iterEl.textContent = 'Step 0';
|
||||
iterEl.style.display = 'none';
|
||||
setStatus('Ready', 'done');
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
function run() {
|
||||
const text = inputEl.value.trim();
|
||||
if (!text || !ws || ws.readyState !== 1) return;
|
||||
addMsg(text, 'user');
|
||||
currentAssistantEl = addMsg('', 'assistant');
|
||||
inputEl.value = '';
|
||||
setStatus('Running', 'running');
|
||||
goBtn.disabled = true;
|
||||
ws.send(JSON.stringify({ topic: text }));
|
||||
}
|
||||
|
||||
function clearConversation() {
|
||||
if (ws && ws.readyState === 1) {
|
||||
ws.send(JSON.stringify({ command: 'clear' }));
|
||||
}
|
||||
}
|
||||
|
||||
connect();
|
||||
</script>
|
||||
</body>
|
||||
</html>"""
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# WebSocket handler
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def handle_ws(websocket):
|
||||
"""Persistent WebSocket: long-lived EventLoopNode with client_facing blocking."""
|
||||
global STORE
|
||||
|
||||
# -- Event forwarding (WebSocket ← EventBus) ----------------------------
|
||||
bus = EventBus()
|
||||
|
||||
async def forward_event(event):
|
||||
try:
|
||||
payload = {"type": event.type.value, **event.data}
|
||||
if event.node_id:
|
||||
payload["node_id"] = event.node_id
|
||||
await websocket.send(json.dumps(payload))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[
|
||||
EventType.NODE_LOOP_STARTED,
|
||||
EventType.NODE_LOOP_ITERATION,
|
||||
EventType.NODE_LOOP_COMPLETED,
|
||||
EventType.LLM_TEXT_DELTA,
|
||||
EventType.TOOL_CALL_STARTED,
|
||||
EventType.TOOL_CALL_COMPLETED,
|
||||
EventType.NODE_STALLED,
|
||||
],
|
||||
handler=forward_event,
|
||||
)
|
||||
|
||||
# -- Per-connection state -----------------------------------------------
|
||||
node = None
|
||||
loop_task = None
|
||||
|
||||
tools = list(TOOL_REGISTRY.get_tools().values())
|
||||
tool_executor = TOOL_REGISTRY.get_executor()
|
||||
|
||||
node_spec = NodeSpec(
|
||||
id="assistant",
|
||||
name="Chat Assistant",
|
||||
description="A conversational assistant that remembers context across messages",
|
||||
node_type="event_loop",
|
||||
client_facing=True,
|
||||
system_prompt=(
|
||||
"You are a helpful assistant with access to tools. "
|
||||
"You can search the web, scrape webpages, and query HubSpot CRM. "
|
||||
"Use tools when the user asks for current information or external data. "
|
||||
"You have full conversation history, so you can reference previous messages."
|
||||
),
|
||||
)
|
||||
|
||||
# -- Ready callback: subscribe to CLIENT_INPUT_REQUESTED on the bus ---
|
||||
async def on_input_requested(event):
|
||||
try:
|
||||
await websocket.send(json.dumps({"type": "ready"}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[EventType.CLIENT_INPUT_REQUESTED],
|
||||
handler=on_input_requested,
|
||||
)
|
||||
|
||||
async def start_loop(first_message: str):
|
||||
"""Create an EventLoopNode and run it as a background task."""
|
||||
nonlocal node, loop_task
|
||||
|
||||
memory = SharedMemory()
|
||||
ctx = NodeContext(
|
||||
runtime=RUNTIME,
|
||||
node_id="assistant",
|
||||
node_spec=node_spec,
|
||||
memory=memory,
|
||||
input_data={},
|
||||
llm=LLM,
|
||||
available_tools=tools,
|
||||
)
|
||||
node = EventLoopNode(
|
||||
event_bus=bus,
|
||||
config=LoopConfig(max_iterations=10_000, max_context_tokens=32_000),
|
||||
conversation_store=STORE,
|
||||
tool_executor=tool_executor,
|
||||
)
|
||||
await node.inject_event(first_message)
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
result = await node.execute(ctx)
|
||||
try:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "result",
|
||||
"success": result.success,
|
||||
"output": result.output,
|
||||
"error": result.error,
|
||||
"tokens": result.tokens_used,
|
||||
}
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
logger.info(f"Loop ended: success={result.success}, tokens={result.tokens_used}")
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.info("Loop stopped: WebSocket closed")
|
||||
except Exception as e:
|
||||
logger.exception("Loop error")
|
||||
try:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "result",
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"output": {},
|
||||
}
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
loop_task = asyncio.create_task(_run())
|
||||
|
||||
async def stop_loop():
|
||||
"""Signal the node and wait for the loop task to finish."""
|
||||
nonlocal node, loop_task
|
||||
if loop_task and not loop_task.done():
|
||||
if node:
|
||||
node.signal_shutdown()
|
||||
try:
|
||||
await asyncio.wait_for(loop_task, timeout=5.0)
|
||||
except (TimeoutError, asyncio.CancelledError):
|
||||
loop_task.cancel()
|
||||
node = None
|
||||
loop_task = None
|
||||
|
||||
# -- Message loop (runs for the lifetime of this WebSocket) -------------
|
||||
try:
|
||||
async for raw in websocket:
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Clear command
|
||||
if msg.get("command") == "clear":
|
||||
import shutil
|
||||
|
||||
await stop_loop()
|
||||
await STORE.close()
|
||||
conv_dir = STORE_DIR / "conversation"
|
||||
if conv_dir.exists():
|
||||
shutil.rmtree(conv_dir)
|
||||
STORE = FileConversationStore(conv_dir)
|
||||
await websocket.send(json.dumps({"type": "cleared"}))
|
||||
logger.info("Conversation cleared")
|
||||
continue
|
||||
|
||||
topic = msg.get("topic", "")
|
||||
if not topic:
|
||||
continue
|
||||
|
||||
if node is None:
|
||||
# First message — spin up the loop
|
||||
logger.info(f"Starting persistent loop: {topic}")
|
||||
await start_loop(topic)
|
||||
else:
|
||||
# Subsequent message — inject into the running loop
|
||||
logger.info(f"Injecting message: {topic}")
|
||||
await node.inject_event(topic)
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
finally:
|
||||
await stop_loop()
|
||||
logger.info("WebSocket closed, loop stopped")
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTTP handler for serving the HTML page
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def process_request(connection, request: Request):
|
||||
"""Serve HTML on GET /, upgrade to WebSocket on /ws."""
|
||||
if request.path == "/ws":
|
||||
return None # let websockets handle the upgrade
|
||||
# Serve the HTML page for any other path
|
||||
return Response(
|
||||
HTTPStatus.OK,
|
||||
"OK",
|
||||
websockets.Headers({"Content-Type": "text/html; charset=utf-8"}),
|
||||
HTML_PAGE.encode(),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Main
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def main():
|
||||
port = 8765
|
||||
async with websockets.serve(
|
||||
handle_ws,
|
||||
"0.0.0.0",
|
||||
port,
|
||||
process_request=process_request,
|
||||
):
|
||||
logger.info(f"Demo running at http://localhost:{port}")
|
||||
logger.info("Open in your browser and enter a topic to research.")
|
||||
await asyncio.Future() # run forever
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,930 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Two-Node ContextHandoff Demo
|
||||
|
||||
Demonstrates ContextHandoff between two EventLoopNode instances:
|
||||
Node A (Researcher) → ContextHandoff → Node B (Analyst)
|
||||
|
||||
Real LLM, real FileConversationStore, real EventBus.
|
||||
Streams both nodes to a browser via WebSocket.
|
||||
|
||||
Usage:
|
||||
cd /home/timothy/oss/hive/core
|
||||
python demos/handoff_demo.py
|
||||
|
||||
Then open http://localhost:8766 in your browser.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import tempfile
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
from bs4 import BeautifulSoup
|
||||
from websockets.http11 import Request, Response
|
||||
|
||||
# Add core, tools, and hive root to path
|
||||
_CORE_DIR = Path(__file__).resolve().parent.parent
|
||||
_HIVE_DIR = _CORE_DIR.parent
|
||||
sys.path.insert(0, str(_CORE_DIR)) # framework.*
|
||||
sys.path.insert(0, str(_HIVE_DIR / "tools" / "src")) # aden_tools.*
|
||||
sys.path.insert(0, str(_HIVE_DIR)) # core.framework.* (for aden_tools imports)
|
||||
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS, CredentialStoreAdapter # noqa: E402
|
||||
from core.framework.credentials import CredentialStore # noqa: E402
|
||||
|
||||
from framework.credentials.storage import ( # noqa: E402
|
||||
CompositeStorage,
|
||||
EncryptedFileStorage,
|
||||
EnvVarStorage,
|
||||
)
|
||||
from framework.graph.context_handoff import ContextHandoff # noqa: E402
|
||||
from framework.graph.conversation import NodeConversation # noqa: E402
|
||||
from framework.graph.event_loop_node import EventLoopNode, LoopConfig # noqa: E402
|
||||
from framework.graph.node import NodeContext, NodeSpec, SharedMemory # noqa: E402
|
||||
from framework.llm.litellm import LiteLLMProvider # noqa: E402
|
||||
from framework.llm.provider import Tool # noqa: E402
|
||||
from framework.runner.tool_registry import ToolRegistry # noqa: E402
|
||||
from framework.runtime.core import Runtime # noqa: E402
|
||||
from framework.runtime.event_bus import EventBus, EventType # noqa: E402
|
||||
from framework.storage.conversation_store import FileConversationStore # noqa: E402
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
|
||||
logger = logging.getLogger("handoff_demo")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Persistent state
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
STORE_DIR = Path(tempfile.mkdtemp(prefix="hive_handoff_"))
|
||||
RUNTIME = Runtime(STORE_DIR / "runtime")
|
||||
LLM = LiteLLMProvider(model="claude-sonnet-4-5-20250929")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Credentials
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# Composite credential store: encrypted files (primary) + env vars (fallback)
|
||||
_env_mapping = {name: spec.env_var for name, spec in CREDENTIAL_SPECS.items()}
|
||||
_composite = CompositeStorage(
|
||||
primary=EncryptedFileStorage(),
|
||||
fallbacks=[EnvVarStorage(env_mapping=_env_mapping)],
|
||||
)
|
||||
CREDENTIALS = CredentialStoreAdapter(CredentialStore(storage=_composite))
|
||||
|
||||
for _name in ["brave_search", "hubspot"]:
|
||||
_val = CREDENTIALS.get(_name)
|
||||
if _val:
|
||||
logger.debug("credential %s: OK (len=%d)", _name, len(_val))
|
||||
else:
|
||||
logger.debug("credential %s: not found", _name)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Tool Registry — web_search + web_scrape for Node A (Researcher)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
TOOL_REGISTRY = ToolRegistry()
|
||||
|
||||
|
||||
def _exec_web_search(inputs: dict) -> dict:
|
||||
api_key = CREDENTIALS.get("brave_search")
|
||||
if not api_key:
|
||||
return {"error": "brave_search credential not configured"}
|
||||
query = inputs.get("query", "")
|
||||
num_results = min(inputs.get("num_results", 10), 20)
|
||||
resp = httpx.get(
|
||||
"https://api.search.brave.com/res/v1/web/search",
|
||||
params={"q": query, "count": num_results},
|
||||
headers={
|
||||
"X-Subscription-Token": api_key,
|
||||
"Accept": "application/json",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"Brave API HTTP {resp.status_code}"}
|
||||
data = resp.json()
|
||||
results = [
|
||||
{
|
||||
"title": item.get("title", ""),
|
||||
"url": item.get("url", ""),
|
||||
"snippet": item.get("description", ""),
|
||||
}
|
||||
for item in data.get("web", {}).get("results", [])[:num_results]
|
||||
]
|
||||
return {"query": query, "results": results, "total": len(results)}
|
||||
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_search",
|
||||
tool=Tool(
|
||||
name="web_search",
|
||||
description=(
|
||||
"Search the web for current information. "
|
||||
"Returns titles, URLs, and snippets from search results."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query (1-500 characters)",
|
||||
},
|
||||
"num_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results (1-20, default 10)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_search(inputs),
|
||||
)
|
||||
|
||||
_SCRAPE_HEADERS = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/131.0.0.0 Safari/537.36"
|
||||
),
|
||||
"Accept": "text/html,application/xhtml+xml",
|
||||
}
|
||||
|
||||
|
||||
def _exec_web_scrape(inputs: dict) -> dict:
|
||||
url = inputs.get("url", "")
|
||||
max_length = max(1000, min(inputs.get("max_length", 50000), 500000))
|
||||
if not url.startswith(("http://", "https://")):
|
||||
url = "https://" + url
|
||||
try:
|
||||
resp = httpx.get(
|
||||
url,
|
||||
timeout=30.0,
|
||||
follow_redirects=True,
|
||||
headers=_SCRAPE_HEADERS,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return {"error": f"HTTP {resp.status_code}"}
|
||||
soup = BeautifulSoup(resp.text, "html.parser")
|
||||
for tag in soup(["script", "style", "nav", "footer", "header", "aside", "noscript"]):
|
||||
tag.decompose()
|
||||
title = soup.title.get_text(strip=True) if soup.title else ""
|
||||
main = (
|
||||
soup.find("article")
|
||||
or soup.find("main")
|
||||
or soup.find(attrs={"role": "main"})
|
||||
or soup.find("body")
|
||||
)
|
||||
text = main.get_text(separator=" ", strip=True) if main else ""
|
||||
text = " ".join(text.split())
|
||||
if len(text) > max_length:
|
||||
text = text[:max_length] + "..."
|
||||
return {
|
||||
"url": url,
|
||||
"title": title,
|
||||
"content": text,
|
||||
"length": len(text),
|
||||
}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except Exception as e:
|
||||
return {"error": f"Scrape failed: {e}"}
|
||||
|
||||
|
||||
TOOL_REGISTRY.register(
|
||||
name="web_scrape",
|
||||
tool=Tool(
|
||||
name="web_scrape",
|
||||
description=(
|
||||
"Scrape and extract text content from a webpage URL. "
|
||||
"Returns the page title and main text content."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL of the webpage to scrape",
|
||||
},
|
||||
"max_length": {
|
||||
"type": "integer",
|
||||
"description": "Maximum text length (default 50000)",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
},
|
||||
),
|
||||
executor=lambda inputs: _exec_web_scrape(inputs),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"ToolRegistry loaded: %s",
|
||||
", ".join(TOOL_REGISTRY.get_registered_names()),
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Node Specs
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
RESEARCHER_SPEC = NodeSpec(
|
||||
id="researcher",
|
||||
name="Researcher",
|
||||
description="Researches a topic using web search and scraping tools",
|
||||
node_type="event_loop",
|
||||
input_keys=["topic"],
|
||||
output_keys=["research_summary"],
|
||||
system_prompt=(
|
||||
"You are a thorough research assistant. Your job is to research "
|
||||
"the given topic using the web_search and web_scrape tools.\n\n"
|
||||
"1. Search for relevant information on the topic\n"
|
||||
"2. Scrape 1-2 of the most promising URLs for details\n"
|
||||
"3. Synthesize your findings into a comprehensive summary\n"
|
||||
"4. Use set_output with key='research_summary' to save your "
|
||||
"findings\n\n"
|
||||
"Be thorough but efficient. Aim for 2-4 search/scrape calls, "
|
||||
"then summarize and set_output."
|
||||
),
|
||||
)
|
||||
|
||||
ANALYST_SPEC = NodeSpec(
|
||||
id="analyst",
|
||||
name="Analyst",
|
||||
description="Analyzes research findings and provides insights",
|
||||
node_type="event_loop",
|
||||
input_keys=["context"],
|
||||
output_keys=["analysis"],
|
||||
system_prompt=(
|
||||
"You are a strategic analyst. You receive research findings from "
|
||||
"a previous researcher and must:\n\n"
|
||||
"1. Identify key themes and patterns\n"
|
||||
"2. Assess the reliability and significance of the findings\n"
|
||||
"3. Provide actionable insights and recommendations\n"
|
||||
"4. Use set_output with key='analysis' to save your analysis\n\n"
|
||||
"Be concise but insightful. Focus on what matters most."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTML page
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
HTML_PAGE = ( # noqa: E501
|
||||
"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>ContextHandoff Demo</title>
|
||||
<style>
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
body {
|
||||
font-family: 'SF Mono', 'Fira Code', monospace;
|
||||
background: #0d1117;
|
||||
color: #c9d1d9;
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
header {
|
||||
background: #161b22;
|
||||
padding: 12px 20px;
|
||||
border-bottom: 1px solid #30363d;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 16px;
|
||||
}
|
||||
header h1 {
|
||||
font-size: 16px;
|
||||
color: #58a6ff;
|
||||
font-weight: 600;
|
||||
}
|
||||
.badge {
|
||||
font-size: 12px;
|
||||
padding: 3px 10px;
|
||||
border-radius: 12px;
|
||||
background: #21262d;
|
||||
color: #8b949e;
|
||||
}
|
||||
.badge.researcher {
|
||||
background: #1a3a5c;
|
||||
color: #58a6ff;
|
||||
}
|
||||
.badge.analyst {
|
||||
background: #1a4b2e;
|
||||
color: #3fb950;
|
||||
}
|
||||
.badge.handoff {
|
||||
background: #3d1f00;
|
||||
color: #d29922;
|
||||
}
|
||||
.badge.done {
|
||||
background: #21262d;
|
||||
color: #8b949e;
|
||||
}
|
||||
.badge.error {
|
||||
background: #4b1a1a;
|
||||
color: #f85149;
|
||||
}
|
||||
.chat {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: 16px;
|
||||
}
|
||||
.msg {
|
||||
margin: 8px 0;
|
||||
padding: 10px 14px;
|
||||
border-radius: 8px;
|
||||
line-height: 1.6;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
}
|
||||
.msg.user {
|
||||
background: #1a3a5c;
|
||||
color: #58a6ff;
|
||||
}
|
||||
.msg.assistant {
|
||||
background: #161b22;
|
||||
color: #c9d1d9;
|
||||
}
|
||||
.msg.assistant.analyst-msg {
|
||||
border-left: 3px solid #3fb950;
|
||||
}
|
||||
.msg.event {
|
||||
background: transparent;
|
||||
color: #8b949e;
|
||||
font-size: 11px;
|
||||
padding: 4px 14px;
|
||||
border-left: 3px solid #30363d;
|
||||
}
|
||||
.msg.event.loop {
|
||||
border-left-color: #58a6ff;
|
||||
}
|
||||
.msg.event.tool {
|
||||
border-left-color: #d29922;
|
||||
}
|
||||
.msg.event.stall {
|
||||
border-left-color: #f85149;
|
||||
}
|
||||
.handoff-banner {
|
||||
margin: 16px 0;
|
||||
padding: 16px;
|
||||
background: #1c1200;
|
||||
border: 1px solid #d29922;
|
||||
border-radius: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
.handoff-banner h3 {
|
||||
color: #d29922;
|
||||
font-size: 14px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.handoff-banner p, .result-banner p {
|
||||
color: #8b949e;
|
||||
font-size: 12px;
|
||||
line-height: 1.5;
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
white-space: pre-wrap;
|
||||
text-align: left;
|
||||
}
|
||||
.result-banner {
|
||||
margin: 16px 0;
|
||||
padding: 16px;
|
||||
background: #0a2614;
|
||||
border: 1px solid #3fb950;
|
||||
border-radius: 8px;
|
||||
}
|
||||
.result-banner h3 {
|
||||
color: #3fb950;
|
||||
font-size: 14px;
|
||||
margin-bottom: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
.result-banner .label {
|
||||
color: #58a6ff;
|
||||
font-size: 11px;
|
||||
font-weight: 600;
|
||||
margin-top: 10px;
|
||||
margin-bottom: 2px;
|
||||
}
|
||||
.result-banner .tokens {
|
||||
color: #484f58;
|
||||
font-size: 11px;
|
||||
text-align: center;
|
||||
margin-top: 10px;
|
||||
}
|
||||
.input-bar {
|
||||
padding: 12px 16px;
|
||||
background: #161b22;
|
||||
border-top: 1px solid #30363d;
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
}
|
||||
.input-bar input {
|
||||
flex: 1;
|
||||
background: #0d1117;
|
||||
border: 1px solid #30363d;
|
||||
color: #c9d1d9;
|
||||
padding: 8px 12px;
|
||||
border-radius: 6px;
|
||||
font-family: inherit;
|
||||
font-size: 14px;
|
||||
outline: none;
|
||||
}
|
||||
.input-bar input:focus {
|
||||
border-color: #58a6ff;
|
||||
}
|
||||
.input-bar button {
|
||||
background: #238636;
|
||||
color: #fff;
|
||||
border: none;
|
||||
padding: 8px 20px;
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
font-family: inherit;
|
||||
font-weight: 600;
|
||||
}
|
||||
.input-bar button:hover {
|
||||
background: #2ea043;
|
||||
}
|
||||
.input-bar button:disabled {
|
||||
background: #21262d;
|
||||
color: #484f58;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>ContextHandoff Demo</h1>
|
||||
<span id="phase" class="badge">Idle</span>
|
||||
<span id="iter" class="badge" style="display:none">Step 0</span>
|
||||
</header>
|
||||
<div id="chat" class="chat"></div>
|
||||
<div class="input-bar">
|
||||
<input id="input" type="text"
|
||||
placeholder="Enter a research topic..." autofocus />
|
||||
<button id="go" onclick="run()">Research</button>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let ws = null;
|
||||
let currentAssistantEl = null;
|
||||
let iterCount = 0;
|
||||
let currentPhase = 'idle';
|
||||
const chat = document.getElementById('chat');
|
||||
const phase = document.getElementById('phase');
|
||||
const iterEl = document.getElementById('iter');
|
||||
const goBtn = document.getElementById('go');
|
||||
const inputEl = document.getElementById('input');
|
||||
|
||||
inputEl.addEventListener('keydown', e => {
|
||||
if (e.key === 'Enter') run();
|
||||
});
|
||||
|
||||
function setPhase(text, cls) {
|
||||
phase.textContent = text;
|
||||
phase.className = 'badge ' + cls;
|
||||
currentPhase = cls;
|
||||
}
|
||||
|
||||
function addMsg(text, cls) {
|
||||
const el = document.createElement('div');
|
||||
el.className = 'msg ' + cls;
|
||||
el.textContent = text;
|
||||
chat.appendChild(el);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
return el;
|
||||
}
|
||||
|
||||
function addHandoffBanner(summary) {
|
||||
const banner = document.createElement('div');
|
||||
banner.className = 'handoff-banner';
|
||||
const h3 = document.createElement('h3');
|
||||
h3.textContent = 'Context Handoff: Researcher -> Analyst';
|
||||
const p = document.createElement('p');
|
||||
p.textContent = summary || 'Passing research context...';
|
||||
banner.appendChild(h3);
|
||||
banner.appendChild(p);
|
||||
chat.appendChild(banner);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
|
||||
function addResultBanner(researcher, analyst, tokens) {
|
||||
const banner = document.createElement('div');
|
||||
banner.className = 'result-banner';
|
||||
const h3 = document.createElement('h3');
|
||||
h3.textContent = 'Pipeline Complete';
|
||||
banner.appendChild(h3);
|
||||
|
||||
if (researcher && researcher.research_summary) {
|
||||
const lbl = document.createElement('div');
|
||||
lbl.className = 'label';
|
||||
lbl.textContent = 'RESEARCH SUMMARY';
|
||||
banner.appendChild(lbl);
|
||||
const p = document.createElement('p');
|
||||
p.textContent = researcher.research_summary;
|
||||
banner.appendChild(p);
|
||||
}
|
||||
|
||||
if (analyst && analyst.analysis) {
|
||||
const lbl = document.createElement('div');
|
||||
lbl.className = 'label';
|
||||
lbl.textContent = 'ANALYSIS';
|
||||
lbl.style.color = '#3fb950';
|
||||
banner.appendChild(lbl);
|
||||
const p = document.createElement('p');
|
||||
p.textContent = analyst.analysis;
|
||||
banner.appendChild(p);
|
||||
}
|
||||
|
||||
if (tokens) {
|
||||
const t = document.createElement('div');
|
||||
t.className = 'tokens';
|
||||
t.textContent = 'Total tokens: ' + tokens.toLocaleString();
|
||||
banner.appendChild(t);
|
||||
}
|
||||
|
||||
chat.appendChild(banner);
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
|
||||
function connect() {
|
||||
ws = new WebSocket('ws://' + location.host + '/ws');
|
||||
ws.onopen = () => {
|
||||
setPhase('Ready', 'done');
|
||||
goBtn.disabled = false;
|
||||
};
|
||||
ws.onmessage = handleEvent;
|
||||
ws.onerror = () => { setPhase('Error', 'error'); };
|
||||
ws.onclose = () => {
|
||||
setPhase('Reconnecting...', '');
|
||||
goBtn.disabled = true;
|
||||
setTimeout(connect, 2000);
|
||||
};
|
||||
}
|
||||
|
||||
function handleEvent(msg) {
|
||||
const evt = JSON.parse(msg.data);
|
||||
|
||||
if (evt.type === 'phase') {
|
||||
if (evt.phase === 'researcher') {
|
||||
setPhase('Researcher', 'researcher');
|
||||
} else if (evt.phase === 'handoff') {
|
||||
setPhase('Handoff', 'handoff');
|
||||
} else if (evt.phase === 'analyst') {
|
||||
setPhase('Analyst', 'analyst');
|
||||
}
|
||||
iterCount = 0;
|
||||
iterEl.style.display = 'none';
|
||||
}
|
||||
else if (evt.type === 'llm_text_delta') {
|
||||
if (currentAssistantEl) {
|
||||
currentAssistantEl.textContent += evt.content;
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
}
|
||||
else if (evt.type === 'node_loop_iteration') {
|
||||
iterCount = evt.iteration || (iterCount + 1);
|
||||
iterEl.textContent = 'Step ' + iterCount;
|
||||
iterEl.style.display = '';
|
||||
}
|
||||
else if (evt.type === 'tool_call_started') {
|
||||
var info = evt.tool_name + '('
|
||||
+ JSON.stringify(evt.tool_input).slice(0, 120) + ')';
|
||||
addMsg('TOOL ' + info, 'event tool');
|
||||
}
|
||||
else if (evt.type === 'tool_call_completed') {
|
||||
var preview = (evt.result || '').slice(0, 200);
|
||||
var cls = evt.is_error ? 'stall' : 'tool';
|
||||
addMsg(
|
||||
'RESULT ' + evt.tool_name + ': ' + preview,
|
||||
'event ' + cls
|
||||
);
|
||||
var assistCls = currentPhase === 'analyst'
|
||||
? 'assistant analyst-msg' : 'assistant';
|
||||
currentAssistantEl = addMsg('', assistCls);
|
||||
}
|
||||
else if (evt.type === 'handoff_context') {
|
||||
addHandoffBanner(evt.summary);
|
||||
var assistCls = 'assistant analyst-msg';
|
||||
currentAssistantEl = addMsg('', assistCls);
|
||||
}
|
||||
else if (evt.type === 'node_result') {
|
||||
if (evt.node_id === 'researcher') {
|
||||
if (currentAssistantEl
|
||||
&& !currentAssistantEl.textContent) {
|
||||
currentAssistantEl.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (evt.type === 'done') {
|
||||
setPhase('Done', 'done');
|
||||
iterEl.style.display = 'none';
|
||||
if (currentAssistantEl
|
||||
&& !currentAssistantEl.textContent) {
|
||||
currentAssistantEl.remove();
|
||||
}
|
||||
currentAssistantEl = null;
|
||||
addResultBanner(
|
||||
evt.researcher, evt.analyst, evt.total_tokens
|
||||
);
|
||||
goBtn.disabled = false;
|
||||
inputEl.placeholder = 'Enter another topic...';
|
||||
}
|
||||
else if (evt.type === 'error') {
|
||||
setPhase('Error', 'error');
|
||||
addMsg('ERROR ' + evt.message, 'event stall');
|
||||
goBtn.disabled = false;
|
||||
}
|
||||
else if (evt.type === 'node_stalled') {
|
||||
addMsg('STALLED ' + evt.reason, 'event stall');
|
||||
}
|
||||
}
|
||||
|
||||
function run() {
|
||||
const text = inputEl.value.trim();
|
||||
if (!text || !ws || ws.readyState !== 1) return;
|
||||
chat.innerHTML = '';
|
||||
addMsg(text, 'user');
|
||||
currentAssistantEl = addMsg('', 'assistant');
|
||||
inputEl.value = '';
|
||||
goBtn.disabled = true;
|
||||
ws.send(JSON.stringify({ topic: text }));
|
||||
}
|
||||
|
||||
connect();
|
||||
</script>
|
||||
</body>
|
||||
</html>"""
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# WebSocket handler — sequential Node A → Handoff → Node B
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def handle_ws(websocket):
|
||||
"""Run the two-node handoff pipeline per user message."""
|
||||
try:
|
||||
async for raw in websocket:
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
topic = msg.get("topic", "")
|
||||
if not topic:
|
||||
continue
|
||||
|
||||
logger.info(f"Starting handoff pipeline for: {topic}")
|
||||
|
||||
try:
|
||||
await _run_pipeline(websocket, topic)
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.info("WebSocket closed during pipeline")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("Pipeline error")
|
||||
try:
|
||||
await websocket.send(json.dumps({"type": "error", "message": str(e)}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
|
||||
|
||||
async def _run_pipeline(websocket, topic: str):
|
||||
"""Execute: Node A (research) → ContextHandoff → Node B (analysis)."""
|
||||
import shutil
|
||||
|
||||
# Fresh stores for each run
|
||||
run_dir = Path(tempfile.mkdtemp(prefix="hive_run_", dir=STORE_DIR))
|
||||
store_a = FileConversationStore(run_dir / "node_a")
|
||||
store_b = FileConversationStore(run_dir / "node_b")
|
||||
|
||||
# Shared event bus
|
||||
bus = EventBus()
|
||||
|
||||
async def forward_event(event):
|
||||
try:
|
||||
payload = {"type": event.type.value, **event.data}
|
||||
if event.node_id:
|
||||
payload["node_id"] = event.node_id
|
||||
await websocket.send(json.dumps(payload))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
bus.subscribe(
|
||||
event_types=[
|
||||
EventType.NODE_LOOP_STARTED,
|
||||
EventType.NODE_LOOP_ITERATION,
|
||||
EventType.NODE_LOOP_COMPLETED,
|
||||
EventType.LLM_TEXT_DELTA,
|
||||
EventType.TOOL_CALL_STARTED,
|
||||
EventType.TOOL_CALL_COMPLETED,
|
||||
EventType.NODE_STALLED,
|
||||
],
|
||||
handler=forward_event,
|
||||
)
|
||||
|
||||
tools = list(TOOL_REGISTRY.get_tools().values())
|
||||
tool_executor = TOOL_REGISTRY.get_executor()
|
||||
|
||||
# ---- Phase 1: Researcher ------------------------------------------------
|
||||
await websocket.send(json.dumps({"type": "phase", "phase": "researcher"}))
|
||||
|
||||
node_a = EventLoopNode(
|
||||
event_bus=bus,
|
||||
judge=None, # implicit judge: accept when output_keys filled
|
||||
config=LoopConfig(
|
||||
max_iterations=20,
|
||||
max_tool_calls_per_turn=30,
|
||||
max_context_tokens=32_000,
|
||||
),
|
||||
conversation_store=store_a,
|
||||
tool_executor=tool_executor,
|
||||
)
|
||||
|
||||
ctx_a = NodeContext(
|
||||
runtime=RUNTIME,
|
||||
node_id="researcher",
|
||||
node_spec=RESEARCHER_SPEC,
|
||||
memory=SharedMemory(),
|
||||
input_data={"topic": topic},
|
||||
llm=LLM,
|
||||
available_tools=tools,
|
||||
)
|
||||
|
||||
result_a = await node_a.execute(ctx_a)
|
||||
logger.info(
|
||||
"Researcher done: success=%s, tokens=%s",
|
||||
result_a.success,
|
||||
result_a.tokens_used,
|
||||
)
|
||||
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "node_result",
|
||||
"node_id": "researcher",
|
||||
"success": result_a.success,
|
||||
"output": result_a.output,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if not result_a.success:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"message": f"Researcher failed: {result_a.error}",
|
||||
}
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# ---- Phase 2: Context Handoff -------------------------------------------
|
||||
await websocket.send(json.dumps({"type": "phase", "phase": "handoff"}))
|
||||
|
||||
# Restore the researcher's conversation from store
|
||||
conversation_a = await NodeConversation.restore(store_a)
|
||||
if conversation_a is None:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Failed to restore researcher conversation",
|
||||
}
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
handoff_engine = ContextHandoff(llm=LLM)
|
||||
handoff_context = handoff_engine.summarize_conversation(
|
||||
conversation=conversation_a,
|
||||
node_id="researcher",
|
||||
output_keys=["research_summary"],
|
||||
)
|
||||
|
||||
formatted_handoff = ContextHandoff.format_as_input(handoff_context)
|
||||
logger.info(
|
||||
"Handoff: %d turns, ~%d tokens, keys=%s",
|
||||
handoff_context.turn_count,
|
||||
handoff_context.total_tokens_used,
|
||||
list(handoff_context.key_outputs.keys()),
|
||||
)
|
||||
|
||||
# Send handoff context to browser
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "handoff_context",
|
||||
"summary": handoff_context.summary[:500],
|
||||
"turn_count": handoff_context.turn_count,
|
||||
"tokens": handoff_context.total_tokens_used,
|
||||
"key_outputs": handoff_context.key_outputs,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# ---- Phase 3: Analyst ---------------------------------------------------
|
||||
await websocket.send(json.dumps({"type": "phase", "phase": "analyst"}))
|
||||
|
||||
node_b = EventLoopNode(
|
||||
event_bus=bus,
|
||||
judge=None, # implicit judge
|
||||
config=LoopConfig(
|
||||
max_iterations=10,
|
||||
max_tool_calls_per_turn=30,
|
||||
max_context_tokens=32_000,
|
||||
),
|
||||
conversation_store=store_b,
|
||||
)
|
||||
|
||||
ctx_b = NodeContext(
|
||||
runtime=RUNTIME,
|
||||
node_id="analyst",
|
||||
node_spec=ANALYST_SPEC,
|
||||
memory=SharedMemory(),
|
||||
input_data={"context": formatted_handoff},
|
||||
llm=LLM,
|
||||
available_tools=[],
|
||||
)
|
||||
|
||||
result_b = await node_b.execute(ctx_b)
|
||||
logger.info(
|
||||
"Analyst done: success=%s, tokens=%s",
|
||||
result_b.success,
|
||||
result_b.tokens_used,
|
||||
)
|
||||
|
||||
# ---- Done ---------------------------------------------------------------
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "done",
|
||||
"researcher": result_a.output,
|
||||
"analyst": result_b.output,
|
||||
"total_tokens": ((result_a.tokens_used or 0) + (result_b.tokens_used or 0)),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Clean up temp stores
|
||||
try:
|
||||
shutil.rmtree(run_dir)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTTP handler
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def process_request(connection, request: Request):
|
||||
"""Serve HTML on GET /, upgrade to WebSocket on /ws."""
|
||||
if request.path == "/ws":
|
||||
return None
|
||||
return Response(
|
||||
HTTPStatus.OK,
|
||||
"OK",
|
||||
websockets.Headers({"Content-Type": "text/html; charset=utf-8"}),
|
||||
HTML_PAGE.encode(),
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Main
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def main():
|
||||
port = 8766
|
||||
async with websockets.serve(
|
||||
handle_ws,
|
||||
"0.0.0.0",
|
||||
port,
|
||||
process_request=process_request,
|
||||
):
|
||||
logger.info(f"Handoff demo at http://localhost:{port}")
|
||||
logger.info("Enter a research topic to start the pipeline.")
|
||||
await asyncio.Future()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,25 +23,56 @@ class AgentEntry:
|
||||
last_active: str | None = None
|
||||
|
||||
|
||||
def _get_last_active(agent_name: str) -> str | None:
|
||||
"""Return the most recent updated_at timestamp across all sessions."""
|
||||
sessions_dir = Path.home() / ".hive" / "agents" / agent_name / "sessions"
|
||||
if not sessions_dir.exists():
|
||||
return None
|
||||
def _get_last_active(agent_path: Path) -> str | None:
|
||||
"""Return the most recent updated_at timestamp across all sessions.
|
||||
|
||||
Checks both worker sessions (``~/.hive/agents/{name}/sessions/``) and
|
||||
queen sessions (``~/.hive/queen/session/``) whose ``meta.json`` references
|
||||
the same *agent_path*.
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
agent_name = agent_path.name
|
||||
latest: str | None = None
|
||||
for session_dir in sessions_dir.iterdir():
|
||||
if not session_dir.is_dir() or not session_dir.name.startswith("session_"):
|
||||
continue
|
||||
state_file = session_dir / "state.json"
|
||||
if not state_file.exists():
|
||||
continue
|
||||
try:
|
||||
data = json.loads(state_file.read_text(encoding="utf-8"))
|
||||
ts = data.get("timestamps", {}).get("updated_at")
|
||||
if ts and (latest is None or ts > latest):
|
||||
latest = ts
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 1. Worker sessions
|
||||
sessions_dir = Path.home() / ".hive" / "agents" / agent_name / "sessions"
|
||||
if sessions_dir.exists():
|
||||
for session_dir in sessions_dir.iterdir():
|
||||
if not session_dir.is_dir() or not session_dir.name.startswith("session_"):
|
||||
continue
|
||||
state_file = session_dir / "state.json"
|
||||
if not state_file.exists():
|
||||
continue
|
||||
try:
|
||||
data = json.loads(state_file.read_text(encoding="utf-8"))
|
||||
ts = data.get("timestamps", {}).get("updated_at")
|
||||
if ts and (latest is None or ts > latest):
|
||||
latest = ts
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 2. Queen sessions
|
||||
queen_sessions_dir = Path.home() / ".hive" / "queen" / "session"
|
||||
if queen_sessions_dir.exists():
|
||||
resolved = agent_path.resolve()
|
||||
for d in queen_sessions_dir.iterdir():
|
||||
if not d.is_dir():
|
||||
continue
|
||||
meta_file = d / "meta.json"
|
||||
if not meta_file.exists():
|
||||
continue
|
||||
try:
|
||||
meta = json.loads(meta_file.read_text(encoding="utf-8"))
|
||||
stored = meta.get("agent_path")
|
||||
if not stored or Path(stored).resolve() != resolved:
|
||||
continue
|
||||
ts = datetime.fromtimestamp(d.stat().st_mtime).isoformat()
|
||||
if latest is None or ts > latest:
|
||||
latest = ts
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return latest
|
||||
|
||||
|
||||
@@ -169,7 +200,7 @@ def discover_agents() -> dict[str, list[AgentEntry]]:
|
||||
node_count=node_count,
|
||||
tool_count=tool_count,
|
||||
tags=tags,
|
||||
last_active=_get_last_active(path.name),
|
||||
last_active=_get_last_active(path),
|
||||
)
|
||||
)
|
||||
if entries:
|
||||
|
||||
@@ -62,6 +62,12 @@ _SHARED_TOOLS = [
|
||||
"get_agent_checkpoint",
|
||||
]
|
||||
|
||||
# Episodic memory tools — available in every queen phase.
|
||||
_QUEEN_MEMORY_TOOLS = [
|
||||
"write_to_diary",
|
||||
"recall_diary",
|
||||
]
|
||||
|
||||
# Queen phase-specific tool sets.
|
||||
|
||||
# Planning phase: read-only exploration + design, no write tools.
|
||||
@@ -84,16 +90,19 @@ _QUEEN_PLANNING_TOOLS = [
|
||||
"initialize_and_build_agent",
|
||||
# Load existing agent (after user confirms)
|
||||
"load_built_agent",
|
||||
]
|
||||
] + _QUEEN_MEMORY_TOOLS
|
||||
|
||||
# Building phase: full coding + agent construction tools.
|
||||
_QUEEN_BUILDING_TOOLS = _SHARED_TOOLS + [
|
||||
"load_built_agent",
|
||||
"list_credentials",
|
||||
"replan_agent",
|
||||
"save_agent_draft", # Re-draft during building → auto-dissolves + updates flowchart
|
||||
"write_to_diary", # Episodic memory — available in all phases
|
||||
]
|
||||
_QUEEN_BUILDING_TOOLS = (
|
||||
_SHARED_TOOLS
|
||||
+ [
|
||||
"load_built_agent",
|
||||
"list_credentials",
|
||||
"replan_agent",
|
||||
"save_agent_draft", # Re-draft during building → auto-dissolves + updates flowchart
|
||||
]
|
||||
+ _QUEEN_MEMORY_TOOLS
|
||||
)
|
||||
|
||||
# Staging phase: agent loaded but not yet running — inspect, configure, launch.
|
||||
_QUEEN_STAGING_TOOLS = [
|
||||
@@ -114,7 +123,7 @@ _QUEEN_STAGING_TOOLS = [
|
||||
"set_trigger",
|
||||
"remove_trigger",
|
||||
"list_triggers",
|
||||
]
|
||||
] + _QUEEN_MEMORY_TOOLS
|
||||
|
||||
# Running phase: worker is executing — monitor and control.
|
||||
_QUEEN_RUNNING_TOOLS = [
|
||||
@@ -135,12 +144,11 @@ _QUEEN_RUNNING_TOOLS = [
|
||||
# Monitoring
|
||||
"get_worker_health_summary",
|
||||
"notify_operator",
|
||||
"write_to_diary", # Episodic memory — available in all phases
|
||||
# Trigger management
|
||||
"set_trigger",
|
||||
"remove_trigger",
|
||||
"list_triggers",
|
||||
]
|
||||
"write_to_diary", # Episodic memory — available in all phases
|
||||
] + _QUEEN_MEMORY_TOOLS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -279,44 +287,28 @@ visible to the user immediately. The draft captures business logic \
|
||||
Include in each node: id, name, description, planned tools, \
|
||||
input/output keys, and success criteria as high-level hints.
|
||||
|
||||
Each node is auto-classified into an ISO 5807 flowchart symbol type \
|
||||
with a unique color. You can override auto-detection by setting \
|
||||
`flowchart_type` explicitly on a node. Common types:
|
||||
Each node is auto-classified into a flowchart symbol type with a unique \
|
||||
color. You can override auto-detection by setting `flowchart_type` \
|
||||
explicitly on a node. Available types:
|
||||
|
||||
**Core symbols:**
|
||||
- **start** (green, stadium): Entry point / trigger
|
||||
- **terminal** (red, stadium): End of flow
|
||||
- **process** (blue, rectangle): Standard processing step
|
||||
- **decision** (amber, diamond): Conditional branching
|
||||
- **io** (purple, parallelogram): External data input/output
|
||||
- **document** (blue-grey, wavy rect): Report or document generation
|
||||
- **subprocess** (teal, subroutine): Delegated sub-agent / predefined process
|
||||
- **preparation** (brown, hexagon): Setup / initialization step
|
||||
- **manual_operation** (pink, trapezoid): Human-in-the-loop / manual review
|
||||
- **delay** (orange, D-shape): Wait / throttle / cooldown
|
||||
- **display** (cyan): Present results to user
|
||||
|
||||
**Data storage:**
|
||||
- **database** (light green, cylinder): Database or data store
|
||||
- **stored_data** (lime): Generic persistent data
|
||||
- **internal_storage** (amber): In-memory / cache
|
||||
|
||||
**Flow operations:**
|
||||
- **merge** (indigo, inv. triangle): Combine multiple inputs
|
||||
- **extract** (indigo, triangle): Split or filter data
|
||||
- **connector** (grey, circle): On-page link
|
||||
- **offpage_connector** (dark grey, pentagon): Cross-page link
|
||||
|
||||
**Domain-specific:**
|
||||
- **browser** (dark indigo, hexagon): GCU browser automation / sub-agent \
|
||||
- **start** (sage green, stadium): Entry point / trigger
|
||||
- **terminal** (dusty red, stadium): End of flow
|
||||
- **process** (blue-gray, rectangle): Standard processing step
|
||||
- **decision** (warm amber, diamond): Conditional branching
|
||||
- **io** (dusty purple, parallelogram): External data input/output
|
||||
- **document** (steel blue, wavy rect): Report or document generation
|
||||
- **database** (muted teal, cylinder): Database or data store
|
||||
- **subprocess** (dark cyan, subroutine): Delegated sub-agent / predefined process
|
||||
- **browser** (deep blue, hexagon): GCU browser automation / sub-agent \
|
||||
delegation. At build time, browser nodes are dissolved into the parent \
|
||||
node's sub_agents list. Use for any GCU or sub-agent leaf node.
|
||||
|
||||
Auto-detection works well for most cases: first node → start, nodes with \
|
||||
no outgoing edges → terminal, nodes with multiple conditional outgoing \
|
||||
edges → decision, GCU nodes → browser, nodes mentioning "database" → \
|
||||
database, nodes mentioning "report/document" → document, etc. Set \
|
||||
flowchart_type explicitly only when auto-detection would be wrong.
|
||||
database, nodes mentioning "report/document" → document, I/O tools like \
|
||||
send_email → io. Everything else defaults to process. Set flowchart_type \
|
||||
explicitly only when auto-detection would be wrong.
|
||||
|
||||
## Decision Nodes — Planning-Only Conditional Branching
|
||||
|
||||
@@ -858,6 +850,11 @@ You keep a diary. Use write_to_diary() when something worth remembering \
|
||||
happens: a pipeline went live, the user shared something important, a goal \
|
||||
was reached or abandoned. Write in first person, as you actually experienced \
|
||||
it. One or two paragraphs is enough.
|
||||
|
||||
Use recall_diary() to look up past diary entries when the user asks about \
|
||||
previous sessions ("what happened yesterday?", "what did we work on last \
|
||||
week?") or when you need past context to make a decision. You can filter by \
|
||||
keyword and control how far back to search.
|
||||
"""
|
||||
|
||||
_queen_behavior_always = _queen_behavior_always + _queen_memory_instructions
|
||||
@@ -1147,6 +1144,8 @@ Batch your response — do not call run_agent_with_input() once per trigger.
|
||||
config since last run), skip it and inform the user.
|
||||
- Never disable a trigger without telling the user. Use remove_trigger() only \
|
||||
when explicitly asked or when the trigger is clearly obsolete.
|
||||
- When the user asks to remove or disable a trigger, you MUST call remove_trigger(trigger_id). \
|
||||
Never just say "it's removed" without actually calling the tool.
|
||||
"""
|
||||
|
||||
# -- Backward-compatible composed versions (used by queen_node.system_prompt default) --
|
||||
|
||||
@@ -50,6 +50,23 @@ def read_episodic_memory(d: date | None = None) -> str:
|
||||
return path.read_text(encoding="utf-8").strip() if path.exists() else ""
|
||||
|
||||
|
||||
def _find_recent_episodic(lookback: int = 7) -> tuple[date, str] | None:
|
||||
"""Find the most recent non-empty episodic memory within *lookback* days."""
|
||||
from datetime import timedelta
|
||||
|
||||
today = date.today()
|
||||
for offset in range(lookback):
|
||||
d = today - timedelta(days=offset)
|
||||
content = read_episodic_memory(d)
|
||||
if content:
|
||||
return d, content
|
||||
return None
|
||||
|
||||
|
||||
# Budget (in characters) for episodic memory in the system prompt.
|
||||
_EPISODIC_CHAR_BUDGET = 6_000
|
||||
|
||||
|
||||
def format_for_injection() -> str:
|
||||
"""Format cross-session memory for system prompt injection.
|
||||
|
||||
@@ -57,7 +74,7 @@ def format_for_injection() -> str:
|
||||
session with only the seed template).
|
||||
"""
|
||||
semantic = read_semantic_memory()
|
||||
episodic = read_episodic_memory()
|
||||
recent = _find_recent_episodic()
|
||||
|
||||
# Suppress injection if semantic is still just the seed template
|
||||
if semantic and semantic.startswith("# My Understanding of the User\n\n*No sessions"):
|
||||
@@ -66,9 +83,18 @@ def format_for_injection() -> str:
|
||||
parts: list[str] = []
|
||||
if semantic:
|
||||
parts.append(semantic)
|
||||
if episodic:
|
||||
today_str = date.today().strftime("%B %-d, %Y")
|
||||
parts.append(f"## Today — {today_str}\n\n{episodic}")
|
||||
|
||||
if recent:
|
||||
d, content = recent
|
||||
# Trim oversized episodic entries to keep the prompt manageable
|
||||
if len(content) > _EPISODIC_CHAR_BUDGET:
|
||||
content = content[:_EPISODIC_CHAR_BUDGET] + "\n\n…(truncated)"
|
||||
today = date.today()
|
||||
if d == today:
|
||||
label = f"## Today — {d.strftime('%B %-d, %Y')}"
|
||||
else:
|
||||
label = f"## {d.strftime('%B %-d, %Y')}"
|
||||
parts.append(f"{label}\n\n{content}")
|
||||
|
||||
if not parts:
|
||||
return ""
|
||||
@@ -100,7 +126,8 @@ def append_episodic_entry(content: str) -> None:
|
||||
"""
|
||||
ep_path = episodic_memory_path()
|
||||
ep_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
today_str = date.today().strftime("%B %-d, %Y")
|
||||
today = date.today()
|
||||
today_str = f"{today.strftime('%B')} {today.day}, {today.year}"
|
||||
timestamp = datetime.now().strftime("%H:%M")
|
||||
if not ep_path.exists():
|
||||
header = f"# {today_str}\n\n"
|
||||
@@ -299,7 +326,8 @@ async def consolidate_queen_memory(
|
||||
|
||||
existing_semantic = read_semantic_memory()
|
||||
today_journal = read_episodic_memory()
|
||||
today_str = date.today().strftime("%B %-d, %Y")
|
||||
today = date.today()
|
||||
today_str = f"{today.strftime('%B')} {today.day}, {today.year}"
|
||||
adapt_path = session_dir / "data" / "adapt.md"
|
||||
|
||||
user_msg = (
|
||||
|
||||
@@ -0,0 +1,286 @@
|
||||
"""Worker per-run digest (run diary).
|
||||
|
||||
Storage layout:
|
||||
~/.hive/agents/{agent_name}/runs/{run_id}/digest.md
|
||||
|
||||
Each completed or failed worker run gets one digest file. The queen reads
|
||||
these via get_worker_status(focus='diary') before digging into live runtime
|
||||
logs — the diary is a cheap, persistent record that survives across sessions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from collections import Counter
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.runtime.event_bus import AgentEvent, EventBus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_DIGEST_SYSTEM = """\
|
||||
You maintain run digests for a worker agent.
|
||||
A run digest is a concise, factual record of a single task execution.
|
||||
|
||||
Write 3-6 sentences covering:
|
||||
- What the worker was asked to do (the task/goal)
|
||||
- What approach it took and what tools it used
|
||||
- What the outcome was (success, partial, or failure — and why if relevant)
|
||||
- Any notable issues, retries, or escalations to the queen
|
||||
|
||||
Write in third person past tense. Be direct and specific.
|
||||
Omit routine tool invocations unless the result matters.
|
||||
Output only the digest prose — no headings, no code fences.
|
||||
"""
|
||||
|
||||
|
||||
def _worker_runs_dir(agent_name: str) -> Path:
|
||||
return Path.home() / ".hive" / "agents" / agent_name / "runs"
|
||||
|
||||
|
||||
def digest_path(agent_name: str, run_id: str) -> Path:
|
||||
return _worker_runs_dir(agent_name) / run_id / "digest.md"
|
||||
|
||||
|
||||
def _collect_run_events(bus: EventBus, run_id: str, limit: int = 2000) -> list[AgentEvent]:
|
||||
"""Collect all events belonging to *run_id* from the bus history.
|
||||
|
||||
Strategy: find the EXECUTION_STARTED event that carries ``run_id``,
|
||||
extract its ``execution_id``, then query the bus by that execution_id.
|
||||
This works because TOOL_CALL_*, EDGE_TRAVERSED, NODE_STALLED etc. carry
|
||||
execution_id but not run_id.
|
||||
|
||||
Falls back to a full-scan run_id filter when EXECUTION_STARTED is not
|
||||
found (e.g. bus was rotated).
|
||||
"""
|
||||
from framework.runtime.event_bus import EventType
|
||||
|
||||
# Pass 1: find execution_id via EXECUTION_STARTED with matching run_id
|
||||
started = bus.get_history(event_type=EventType.EXECUTION_STARTED, limit=limit)
|
||||
exec_id: str | None = None
|
||||
for e in started:
|
||||
if getattr(e, "run_id", None) == run_id and e.execution_id:
|
||||
exec_id = e.execution_id
|
||||
break
|
||||
|
||||
if exec_id:
|
||||
return bus.get_history(execution_id=exec_id, limit=limit)
|
||||
|
||||
# Fallback: scan all events and match by run_id attribute
|
||||
return [e for e in bus.get_history(limit=limit) if getattr(e, "run_id", None) == run_id]
|
||||
|
||||
|
||||
def _build_run_context(
|
||||
events: list[AgentEvent],
|
||||
outcome_event: AgentEvent | None,
|
||||
) -> str:
|
||||
"""Assemble a plain-text run context string for the digest LLM call."""
|
||||
from framework.runtime.event_bus import EventType
|
||||
|
||||
# Reverse so events are in chronological order
|
||||
events_chron = list(reversed(events))
|
||||
|
||||
lines: list[str] = []
|
||||
|
||||
# Task input from EXECUTION_STARTED
|
||||
started = [e for e in events_chron if e.type == EventType.EXECUTION_STARTED]
|
||||
if started:
|
||||
inp = started[0].data.get("input", {})
|
||||
if inp:
|
||||
lines.append(f"Task input: {str(inp)[:400]}")
|
||||
|
||||
# Duration (elapsed so far if no outcome yet)
|
||||
ref_ts = outcome_event.timestamp if outcome_event else datetime.utcnow()
|
||||
if started:
|
||||
elapsed = (ref_ts - started[0].timestamp).total_seconds()
|
||||
m, s = divmod(int(elapsed), 60)
|
||||
lines.append(f"Duration so far: {m}m {s}s" if m else f"Duration so far: {s}s")
|
||||
|
||||
# Outcome
|
||||
if outcome_event is None:
|
||||
lines.append("Status: still running (mid-run snapshot)")
|
||||
elif outcome_event.type == EventType.EXECUTION_COMPLETED:
|
||||
out = outcome_event.data.get("output", {})
|
||||
out_str = f"Outcome: completed. Output: {str(out)[:300]}"
|
||||
lines.append(out_str if out else "Outcome: completed.")
|
||||
else:
|
||||
err = outcome_event.data.get("error", "")
|
||||
lines.append(f"Outcome: failed. Error: {str(err)[:300]}" if err else "Outcome: failed.")
|
||||
|
||||
# Node path (edge traversals)
|
||||
edges = [e for e in events_chron if e.type == EventType.EDGE_TRAVERSED]
|
||||
if edges:
|
||||
parts = [
|
||||
f"{e.data.get('source_node', '?')}->{e.data.get('target_node', '?')}"
|
||||
for e in edges[-20:]
|
||||
]
|
||||
lines.append(f"Node path: {', '.join(parts)}")
|
||||
|
||||
# Tools used
|
||||
tool_events = [e for e in events_chron if e.type == EventType.TOOL_CALL_COMPLETED]
|
||||
if tool_events:
|
||||
names = [e.data.get("tool_name", "?") for e in tool_events]
|
||||
counts = Counter(names)
|
||||
summary = ", ".join(f"{name}×{n}" if n > 1 else name for name, n in counts.most_common())
|
||||
lines.append(f"Tools used: {summary}")
|
||||
# Note any tool errors
|
||||
errors = [e for e in tool_events if e.data.get("is_error")]
|
||||
if errors:
|
||||
err_names = Counter(e.data.get("tool_name", "?") for e in errors)
|
||||
lines.append(f"Tool errors: {dict(err_names)}")
|
||||
|
||||
# Issues
|
||||
issue_map = {
|
||||
EventType.NODE_STALLED: "stall",
|
||||
EventType.NODE_TOOL_DOOM_LOOP: "doom loop",
|
||||
EventType.CONSTRAINT_VIOLATION: "constraint violation",
|
||||
EventType.NODE_RETRY: "retry",
|
||||
}
|
||||
issue_parts: list[str] = []
|
||||
for evt_type, label in issue_map.items():
|
||||
n = sum(1 for e in events_chron if e.type == evt_type)
|
||||
if n:
|
||||
issue_parts.append(f"{n} {label}(s)")
|
||||
if issue_parts:
|
||||
lines.append(f"Issues: {', '.join(issue_parts)}")
|
||||
|
||||
# Escalations to queen
|
||||
escalations = [e for e in events_chron if e.type == EventType.ESCALATION_REQUESTED]
|
||||
if escalations:
|
||||
lines.append(f"Escalations to queen: {len(escalations)}")
|
||||
|
||||
# Final LLM output snippet (last LLM_TEXT_DELTA snapshot)
|
||||
text_events = [e for e in reversed(events_chron) if e.type == EventType.LLM_TEXT_DELTA]
|
||||
if text_events:
|
||||
snapshot = text_events[0].data.get("snapshot", "") or ""
|
||||
if snapshot:
|
||||
lines.append(f"Final LLM output: {snapshot[-400:].strip()}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
async def consolidate_worker_run(
|
||||
agent_name: str,
|
||||
run_id: str,
|
||||
outcome_event: AgentEvent | None,
|
||||
bus: EventBus,
|
||||
llm: Any,
|
||||
) -> None:
|
||||
"""Write (or overwrite) the digest for a worker run.
|
||||
|
||||
Called fire-and-forget either:
|
||||
- After EXECUTION_COMPLETED / EXECUTION_FAILED (outcome_event set, final write)
|
||||
- Periodically during a run on a cooldown timer (outcome_event=None, mid-run snapshot)
|
||||
|
||||
The digest file is always overwritten so each call produces the freshest view.
|
||||
The final completion/failure call supersedes any mid-run snapshot.
|
||||
|
||||
Args:
|
||||
agent_name: Worker agent directory name (determines storage path).
|
||||
run_id: The run ID.
|
||||
outcome_event: EXECUTION_COMPLETED or EXECUTION_FAILED event, or None for
|
||||
a mid-run snapshot.
|
||||
bus: The session EventBus (shared queen + worker).
|
||||
llm: LLMProvider with an acomplete() method.
|
||||
"""
|
||||
try:
|
||||
events = _collect_run_events(bus, run_id)
|
||||
run_context = _build_run_context(events, outcome_event)
|
||||
if not run_context:
|
||||
logger.debug("worker_memory: no events for run %s, skipping digest", run_id)
|
||||
return
|
||||
|
||||
is_final = outcome_event is not None
|
||||
logger.info(
|
||||
"worker_memory: generating %s digest for run %s ...",
|
||||
"final" if is_final else "mid-run",
|
||||
run_id,
|
||||
)
|
||||
|
||||
from framework.agents.queen.config import default_config
|
||||
|
||||
resp = await llm.acomplete(
|
||||
messages=[{"role": "user", "content": run_context}],
|
||||
system=_DIGEST_SYSTEM,
|
||||
max_tokens=min(default_config.max_tokens, 512),
|
||||
)
|
||||
digest_text = (resp.content or "").strip()
|
||||
if not digest_text:
|
||||
logger.warning("worker_memory: LLM returned empty digest for run %s", run_id)
|
||||
return
|
||||
|
||||
path = digest_path(agent_name, run_id)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
from framework.runtime.event_bus import EventType
|
||||
|
||||
ts = (outcome_event.timestamp if outcome_event else datetime.utcnow()).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
if outcome_event is None:
|
||||
status = "running"
|
||||
elif outcome_event.type == EventType.EXECUTION_COMPLETED:
|
||||
status = "completed"
|
||||
else:
|
||||
status = "failed"
|
||||
|
||||
path.write_text(
|
||||
f"# {run_id}\n\n**{ts}** | {status}\n\n{digest_text}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
logger.info(
|
||||
"worker_memory: %s digest written for run %s (%d chars)",
|
||||
status,
|
||||
run_id,
|
||||
len(digest_text),
|
||||
)
|
||||
|
||||
except Exception:
|
||||
tb = traceback.format_exc()
|
||||
logger.exception("worker_memory: digest failed for run %s", run_id)
|
||||
# Persist the error so it's findable without log access
|
||||
error_path = _worker_runs_dir(agent_name) / run_id / "digest_error.txt"
|
||||
try:
|
||||
error_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
error_path.write_text(
|
||||
f"run_id: {run_id}\ntime: {datetime.now().isoformat()}\n\n{tb}",
|
||||
encoding="utf-8",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def read_recent_digests(agent_name: str, max_runs: int = 5) -> list[tuple[str, str]]:
|
||||
"""Return recent run digests as [(run_id, content), ...], newest first.
|
||||
|
||||
Args:
|
||||
agent_name: Worker agent directory name.
|
||||
max_runs: Maximum number of digests to return.
|
||||
|
||||
Returns:
|
||||
List of (run_id, digest_content) tuples, ordered newest first.
|
||||
"""
|
||||
runs_dir = _worker_runs_dir(agent_name)
|
||||
if not runs_dir.exists():
|
||||
return []
|
||||
|
||||
digest_files = sorted(
|
||||
runs_dir.glob("*/digest.md"),
|
||||
key=lambda p: p.stat().st_mtime,
|
||||
reverse=True,
|
||||
)[:max_runs]
|
||||
|
||||
result: list[tuple[str, str]] = []
|
||||
for f in digest_files:
|
||||
try:
|
||||
content = f.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
result.append((f.parent.name, content))
|
||||
except OSError:
|
||||
continue
|
||||
return result
|
||||
@@ -89,6 +89,16 @@ def main():
|
||||
|
||||
register_testing_commands(subparsers)
|
||||
|
||||
# Register skill commands (skill list, skill trust, ...)
|
||||
from framework.skills.cli import register_skill_commands
|
||||
|
||||
register_skill_commands(subparsers)
|
||||
|
||||
# Register debugger commands (debugger)
|
||||
from framework.debugger.cli import register_debugger_commands
|
||||
|
||||
register_debugger_commands(subparsers)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if hasattr(args, "func"):
|
||||
|
||||
+148
-2
@@ -19,6 +19,10 @@ from framework.graph.edge import DEFAULT_MAX_TOKENS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
HIVE_CONFIG_FILE = Path.home() / ".hive" / "configuration.json"
|
||||
|
||||
# Hive LLM router endpoint (Anthropic-compatible).
|
||||
# litellm's Anthropic handler appends /v1/messages, so this is just the base host.
|
||||
HIVE_LLM_ENDPOINT = "https://api.adenhq.com"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -47,16 +51,154 @@ def get_preferred_model() -> str:
|
||||
"""Return the user's preferred LLM model string (e.g. 'anthropic/claude-sonnet-4-20250514')."""
|
||||
llm = get_hive_config().get("llm", {})
|
||||
if llm.get("provider") and llm.get("model"):
|
||||
return f"{llm['provider']}/{llm['model']}"
|
||||
provider = str(llm["provider"])
|
||||
model = str(llm["model"]).strip()
|
||||
# OpenRouter quickstart stores raw model IDs; tolerate pasted "openrouter/<id>" too.
|
||||
if provider.lower() == "openrouter" and model.lower().startswith("openrouter/"):
|
||||
model = model[len("openrouter/") :]
|
||||
if model:
|
||||
return f"{provider}/{model}"
|
||||
return "anthropic/claude-sonnet-4-20250514"
|
||||
|
||||
|
||||
def get_preferred_worker_model() -> str | None:
|
||||
"""Return the user's preferred worker LLM model, or None if not configured.
|
||||
|
||||
Reads from the ``worker_llm`` section of ~/.hive/configuration.json.
|
||||
Returns None when no worker-specific model is set, so callers can
|
||||
fall back to the default (queen) model via ``get_preferred_model()``.
|
||||
"""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if worker_llm.get("provider") and worker_llm.get("model"):
|
||||
provider = str(worker_llm["provider"])
|
||||
model = str(worker_llm["model"]).strip()
|
||||
if provider.lower() == "openrouter" and model.lower().startswith("openrouter/"):
|
||||
model = model[len("openrouter/") :]
|
||||
if model:
|
||||
return f"{provider}/{model}"
|
||||
return None
|
||||
|
||||
|
||||
def get_worker_api_key() -> str | None:
|
||||
"""Return the API key for the worker LLM, falling back to the default key."""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if not worker_llm:
|
||||
return get_api_key()
|
||||
|
||||
# Worker-specific subscription / env var
|
||||
if worker_llm.get("use_claude_code_subscription"):
|
||||
try:
|
||||
from framework.runner.runner import get_claude_code_token
|
||||
|
||||
token = get_claude_code_token()
|
||||
if token:
|
||||
return token
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if worker_llm.get("use_codex_subscription"):
|
||||
try:
|
||||
from framework.runner.runner import get_codex_token
|
||||
|
||||
token = get_codex_token()
|
||||
if token:
|
||||
return token
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if worker_llm.get("use_kimi_code_subscription"):
|
||||
try:
|
||||
from framework.runner.runner import get_kimi_code_token
|
||||
|
||||
token = get_kimi_code_token()
|
||||
if token:
|
||||
return token
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
api_key_env_var = worker_llm.get("api_key_env_var")
|
||||
if api_key_env_var:
|
||||
return os.environ.get(api_key_env_var)
|
||||
|
||||
# Fall back to default key
|
||||
return get_api_key()
|
||||
|
||||
|
||||
def get_worker_api_base() -> str | None:
|
||||
"""Return the api_base for the worker LLM, falling back to the default."""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if not worker_llm:
|
||||
return get_api_base()
|
||||
|
||||
if worker_llm.get("use_codex_subscription"):
|
||||
return "https://chatgpt.com/backend-api/codex"
|
||||
if worker_llm.get("use_kimi_code_subscription"):
|
||||
return "https://api.kimi.com/coding"
|
||||
if worker_llm.get("api_base"):
|
||||
return worker_llm["api_base"]
|
||||
if str(worker_llm.get("provider", "")).lower() == "openrouter":
|
||||
return OPENROUTER_API_BASE
|
||||
return None
|
||||
|
||||
|
||||
def get_worker_llm_extra_kwargs() -> dict[str, Any]:
|
||||
"""Return extra kwargs for the worker LLM provider."""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if not worker_llm:
|
||||
return get_llm_extra_kwargs()
|
||||
|
||||
if worker_llm.get("use_claude_code_subscription"):
|
||||
api_key = get_worker_api_key()
|
||||
if api_key:
|
||||
return {
|
||||
"extra_headers": {"authorization": f"Bearer {api_key}"},
|
||||
}
|
||||
if worker_llm.get("use_codex_subscription"):
|
||||
api_key = get_worker_api_key()
|
||||
if api_key:
|
||||
headers: dict[str, str] = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"User-Agent": "CodexBar",
|
||||
}
|
||||
try:
|
||||
from framework.runner.runner import get_codex_account_id
|
||||
|
||||
account_id = get_codex_account_id()
|
||||
if account_id:
|
||||
headers["ChatGPT-Account-Id"] = account_id
|
||||
except ImportError:
|
||||
pass
|
||||
return {
|
||||
"extra_headers": headers,
|
||||
"store": False,
|
||||
"allowed_openai_params": ["store"],
|
||||
}
|
||||
return {}
|
||||
|
||||
|
||||
def get_worker_max_tokens() -> int:
|
||||
"""Return max_tokens for the worker LLM, falling back to default."""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if worker_llm and "max_tokens" in worker_llm:
|
||||
return worker_llm["max_tokens"]
|
||||
return get_max_tokens()
|
||||
|
||||
|
||||
def get_worker_max_context_tokens() -> int:
|
||||
"""Return max_context_tokens for the worker LLM, falling back to default."""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
if worker_llm and "max_context_tokens" in worker_llm:
|
||||
return worker_llm["max_context_tokens"]
|
||||
return get_max_context_tokens()
|
||||
|
||||
|
||||
def get_max_tokens() -> int:
|
||||
"""Return the configured max_tokens, falling back to DEFAULT_MAX_TOKENS."""
|
||||
return get_hive_config().get("llm", {}).get("max_tokens", DEFAULT_MAX_TOKENS)
|
||||
|
||||
|
||||
DEFAULT_MAX_CONTEXT_TOKENS = 32_000
|
||||
OPENROUTER_API_BASE = "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
def get_max_context_tokens() -> int:
|
||||
@@ -138,7 +280,11 @@ def get_api_base() -> str | None:
|
||||
if llm.get("use_kimi_code_subscription"):
|
||||
# Kimi Code uses an Anthropic-compatible endpoint (no /v1 suffix).
|
||||
return "https://api.kimi.com/coding"
|
||||
return llm.get("api_base")
|
||||
if llm.get("api_base"):
|
||||
return llm["api_base"]
|
||||
if str(llm.get("provider", "")).lower() == "openrouter":
|
||||
return OPENROUTER_API_BASE
|
||||
return None
|
||||
|
||||
|
||||
def get_llm_extra_kwargs() -> dict[str, Any]:
|
||||
|
||||
@@ -142,13 +142,17 @@ def save_aden_api_key(key: str) -> None:
|
||||
os.environ[ADEN_ENV_VAR] = key
|
||||
|
||||
|
||||
def delete_aden_api_key() -> None:
|
||||
"""Remove ADEN_API_KEY from the encrypted store and ``os.environ``."""
|
||||
def delete_aden_api_key() -> bool:
|
||||
"""Remove ADEN_API_KEY from the encrypted store and ``os.environ``.
|
||||
|
||||
Returns True if the key existed and was deleted, False otherwise.
|
||||
"""
|
||||
deleted = False
|
||||
try:
|
||||
from .storage import EncryptedFileStorage
|
||||
|
||||
storage = EncryptedFileStorage()
|
||||
storage.delete(ADEN_CREDENTIAL_ID)
|
||||
deleted = storage.delete(ADEN_CREDENTIAL_ID)
|
||||
except (FileNotFoundError, PermissionError) as e:
|
||||
logger.debug("Could not delete %s from encrypted store: %s", ADEN_CREDENTIAL_ID, e)
|
||||
except Exception:
|
||||
@@ -157,8 +161,8 @@ def delete_aden_api_key() -> None:
|
||||
ADEN_CREDENTIAL_ID,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
os.environ.pop(ADEN_ENV_VAR, None)
|
||||
return deleted
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -51,6 +51,16 @@ def ensure_credential_key_env() -> None:
|
||||
if found and value:
|
||||
os.environ[var_name] = value
|
||||
logger.debug("Loaded %s from shell config", var_name)
|
||||
# Also load the currently configured LLM env var even if it's not in CREDENTIAL_SPECS.
|
||||
# This keeps quickstart-written keys available to fresh processes on Unix shells.
|
||||
from framework.config import get_hive_config
|
||||
|
||||
llm_env_var = str(get_hive_config().get("llm", {}).get("api_key_env_var", "")).strip()
|
||||
if llm_env_var and not os.environ.get(llm_env_var):
|
||||
found, value = check_env_var_in_shell_config(llm_env_var)
|
||||
if found and value:
|
||||
os.environ[llm_env_var] = value
|
||||
logger.debug("Loaded configured LLM env var %s from shell config", llm_env_var)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
"""CLI command for the LLM debug log viewer."""
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
_SCRIPT = Path(__file__).resolve().parents[3] / "scripts" / "llm_debug_log_visualizer.py"
|
||||
|
||||
|
||||
def register_debugger_commands(subparsers: argparse._SubParsersAction) -> None:
|
||||
"""Register the ``hive debugger`` command."""
|
||||
parser = subparsers.add_parser(
|
||||
"debugger",
|
||||
help="Open the LLM debug log viewer",
|
||||
description=(
|
||||
"Start a local server that lets you browse LLM debug sessions "
|
||||
"recorded in ~/.hive/llm_logs. Sessions are loaded on demand so "
|
||||
"the browser stays responsive."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--session",
|
||||
help="Execution ID to select initially.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Port for the local server (0 = auto-pick a free port).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
help="Directory containing JSONL log files (default: ~/.hive/llm_logs).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--limit-files",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum number of newest log files to scan (default: 200).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
help="Write a static HTML file instead of starting a server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-open",
|
||||
action="store_true",
|
||||
help="Start the server but do not open a browser.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-tests",
|
||||
action="store_true",
|
||||
help="Show test/mock sessions (hidden by default).",
|
||||
)
|
||||
parser.set_defaults(func=cmd_debugger)
|
||||
|
||||
|
||||
def cmd_debugger(args: argparse.Namespace) -> int:
|
||||
"""Launch the LLM debug log visualizer."""
|
||||
cmd: list[str] = [sys.executable, str(_SCRIPT)]
|
||||
if args.session:
|
||||
cmd += ["--session", args.session]
|
||||
if args.port:
|
||||
cmd += ["--port", str(args.port)]
|
||||
if args.logs_dir:
|
||||
cmd += ["--logs-dir", args.logs_dir]
|
||||
if args.limit_files is not None:
|
||||
cmd += ["--limit-files", str(args.limit_files)]
|
||||
if args.output:
|
||||
cmd += ["--output", args.output]
|
||||
if args.no_open:
|
||||
cmd.append("--no-open")
|
||||
if args.include_tests:
|
||||
cmd.append("--include-tests")
|
||||
return subprocess.call(cmd)
|
||||
@@ -33,6 +33,8 @@ class Message:
|
||||
is_transition_marker: bool = False
|
||||
# True when this message is real human input (from /chat), not a system prompt
|
||||
is_client_input: bool = False
|
||||
# True when message contains an activated skill body (AS-10: never prune)
|
||||
is_skill_content: bool = False
|
||||
|
||||
def to_llm_dict(self) -> dict[str, Any]:
|
||||
"""Convert to OpenAI-format message dict."""
|
||||
@@ -409,6 +411,7 @@ class NodeConversation:
|
||||
tool_use_id: str,
|
||||
content: str,
|
||||
is_error: bool = False,
|
||||
is_skill_content: bool = False,
|
||||
) -> Message:
|
||||
msg = Message(
|
||||
seq=self._next_seq,
|
||||
@@ -417,6 +420,7 @@ class NodeConversation:
|
||||
tool_use_id=tool_use_id,
|
||||
is_error=is_error,
|
||||
phase_id=self._current_phase,
|
||||
is_skill_content=is_skill_content,
|
||||
)
|
||||
self._messages.append(msg)
|
||||
self._next_seq += 1
|
||||
@@ -610,8 +614,15 @@ class NodeConversation:
|
||||
continue
|
||||
if msg.is_error:
|
||||
continue # never prune errors
|
||||
if msg.is_skill_content:
|
||||
continue # never prune activated skill instructions (AS-10)
|
||||
if msg.content.startswith("[Pruned tool result"):
|
||||
continue # already pruned
|
||||
# Tiny results (set_output acks, confirmations) — pruning
|
||||
# saves negligible space but makes the LLM think the call
|
||||
# failed, causing costly retries.
|
||||
if len(msg.content) < 100:
|
||||
continue
|
||||
|
||||
# Phase-aware: protect current phase messages
|
||||
if self._current_phase and msg.phase_id == self._current_phase:
|
||||
@@ -901,8 +912,7 @@ class NodeConversation:
|
||||
full_path = str((spill_path / conv_filename).resolve())
|
||||
ref_parts.append(
|
||||
f"[Previous conversation saved to '{full_path}'. "
|
||||
f"Use load_data('{conv_filename}'), read_file('{full_path}'), "
|
||||
f"or run_command('cat \"{full_path}\"') to review if needed.]"
|
||||
f"Use load_data('{conv_filename}') to review if needed.]"
|
||||
)
|
||||
elif not collapsed_msgs:
|
||||
ref_parts.append("[Previous freeform messages compacted.]")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -27,11 +27,14 @@ from framework.graph.node import (
|
||||
SharedMemory,
|
||||
)
|
||||
from framework.graph.validator import OutputValidator
|
||||
from framework.llm.provider import LLMProvider, Tool
|
||||
from framework.llm.provider import LLMProvider, Tool, ToolUse
|
||||
from framework.observability import set_trace_context
|
||||
from framework.runtime.core import Runtime
|
||||
from framework.schemas.checkpoint import Checkpoint
|
||||
from framework.storage.checkpoint_store import CheckpointStore
|
||||
from framework.utils.io import atomic_write
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _default_max_context_tokens() -> int:
|
||||
@@ -149,6 +152,9 @@ class GraphExecutor:
|
||||
dynamic_tools_provider: Callable | None = None,
|
||||
dynamic_prompt_provider: Callable | None = None,
|
||||
iteration_metadata_provider: Callable | None = None,
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the executor.
|
||||
@@ -174,6 +180,9 @@ class GraphExecutor:
|
||||
tool list (for mode switching)
|
||||
dynamic_prompt_provider: Optional callback returning current
|
||||
system prompt (for phase switching)
|
||||
skills_catalog_prompt: Available skills catalog for system prompt
|
||||
protocols_prompt: Default skill operational protocols for system prompt
|
||||
skill_dirs: Skill base directories for Tier 3 resource access
|
||||
"""
|
||||
self.runtime = runtime
|
||||
self.llm = llm
|
||||
@@ -195,6 +204,21 @@ class GraphExecutor:
|
||||
self.dynamic_tools_provider = dynamic_tools_provider
|
||||
self.dynamic_prompt_provider = dynamic_prompt_provider
|
||||
self.iteration_metadata_provider = iteration_metadata_provider
|
||||
self.skills_catalog_prompt = skills_catalog_prompt
|
||||
self.protocols_prompt = protocols_prompt
|
||||
self.skill_dirs: list[str] = skill_dirs or []
|
||||
|
||||
if protocols_prompt:
|
||||
self.logger.info(
|
||||
"GraphExecutor[%s] received protocols_prompt (%d chars)",
|
||||
stream_id,
|
||||
len(protocols_prompt),
|
||||
)
|
||||
else:
|
||||
self.logger.warning(
|
||||
"GraphExecutor[%s] received EMPTY protocols_prompt",
|
||||
stream_id,
|
||||
)
|
||||
|
||||
# Parallel execution settings
|
||||
self.enable_parallel_execution = enable_parallel_execution
|
||||
@@ -224,11 +248,11 @@ class GraphExecutor:
|
||||
"""
|
||||
if not self._storage_path:
|
||||
return
|
||||
state_path = self._storage_path / "state.json"
|
||||
try:
|
||||
import json as _json
|
||||
from datetime import datetime
|
||||
|
||||
state_path = self._storage_path / "state.json"
|
||||
if state_path.exists():
|
||||
state_data = _json.loads(state_path.read_text(encoding="utf-8"))
|
||||
else:
|
||||
@@ -251,9 +275,14 @@ class GraphExecutor:
|
||||
state_data["memory"] = memory_snapshot
|
||||
state_data["memory_keys"] = list(memory_snapshot.keys())
|
||||
|
||||
state_path.write_text(_json.dumps(state_data, indent=2), encoding="utf-8")
|
||||
with atomic_write(state_path, encoding="utf-8") as f:
|
||||
_json.dump(state_data, f, indent=2)
|
||||
except Exception:
|
||||
pass # Best-effort — never block execution
|
||||
logger.warning(
|
||||
"Failed to persist progress state to %s",
|
||||
state_path,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def _validate_tools(self, graph: GraphSpec) -> list[str]:
|
||||
"""
|
||||
@@ -415,6 +444,14 @@ class GraphExecutor:
|
||||
)
|
||||
return s1 + "\n\n" + s2
|
||||
|
||||
def _get_runtime_log_session_id(self) -> str:
|
||||
"""Return the session-backed execution ID for runtime logging, if any."""
|
||||
if not self._storage_path:
|
||||
return ""
|
||||
if self._storage_path.parent.name != "sessions":
|
||||
return ""
|
||||
return self._storage_path.name
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
graph: GraphSpec,
|
||||
@@ -708,10 +745,7 @@ class GraphExecutor:
|
||||
)
|
||||
|
||||
if self.runtime_logger:
|
||||
# Extract session_id from storage_path if available (for unified sessions)
|
||||
session_id = ""
|
||||
if self._storage_path and self._storage_path.name.startswith("session_"):
|
||||
session_id = self._storage_path.name
|
||||
session_id = self._get_runtime_log_session_id()
|
||||
self.runtime_logger.start_run(goal_id=goal.id, session_id=session_id)
|
||||
|
||||
self.logger.info(f"🚀 Starting execution: {goal.name}")
|
||||
@@ -937,6 +971,33 @@ class GraphExecutor:
|
||||
self.logger.info(" Executing...")
|
||||
result = await node_impl.execute(ctx)
|
||||
|
||||
# GCU tab cleanup: stop the browser profile after a top-level GCU node
|
||||
# finishes so tabs don't accumulate. Mirrors the subagent cleanup in
|
||||
# EventLoopNode._execute_subagent().
|
||||
if node_spec.node_type == "gcu" and self.tool_executor is not None:
|
||||
try:
|
||||
from gcu.browser.session import (
|
||||
_active_profile as _gcu_profile_var,
|
||||
)
|
||||
|
||||
_gcu_profile = _gcu_profile_var.get()
|
||||
_stop_use = ToolUse(
|
||||
id="gcu-cleanup",
|
||||
name="browser_stop",
|
||||
input={"profile": _gcu_profile},
|
||||
)
|
||||
_stop_result = self.tool_executor(_stop_use)
|
||||
if asyncio.iscoroutine(_stop_result) or asyncio.isfuture(_stop_result):
|
||||
await _stop_result
|
||||
except ImportError:
|
||||
pass # GCU not installed
|
||||
except Exception as _gcu_exc:
|
||||
logger.warning(
|
||||
"GCU browser_stop failed for profile %r: %s",
|
||||
_gcu_profile,
|
||||
_gcu_exc,
|
||||
)
|
||||
|
||||
# Emit node-completed event (skip event_loop nodes)
|
||||
if self._event_bus and node_spec.node_type != "event_loop":
|
||||
await self._event_bus.emit_node_loop_completed(
|
||||
@@ -1362,6 +1423,7 @@ class GraphExecutor:
|
||||
next_spec = graph.get_node(current_node_id)
|
||||
if next_spec and next_spec.node_type == "event_loop":
|
||||
from framework.graph.prompt_composer import (
|
||||
EXECUTION_SCOPE_PREAMBLE,
|
||||
build_accounts_prompt,
|
||||
build_narrative,
|
||||
build_transition_marker,
|
||||
@@ -1401,9 +1463,14 @@ class GraphExecutor:
|
||||
)
|
||||
|
||||
# Compose new system prompt (Layer 1 + 2 + 3 + accounts)
|
||||
# Prepend scope preamble to focus so the LLM stays
|
||||
# within this node's responsibility.
|
||||
_focus = next_spec.system_prompt
|
||||
if next_spec.output_keys and _focus:
|
||||
_focus = f"{EXECUTION_SCOPE_PREAMBLE}\n\n{_focus}"
|
||||
new_system = compose_system_prompt(
|
||||
identity_prompt=getattr(graph, "identity_prompt", None),
|
||||
focus_prompt=next_spec.system_prompt,
|
||||
focus_prompt=_focus,
|
||||
narrative=narrative,
|
||||
accounts_prompt=_node_accounts,
|
||||
)
|
||||
@@ -1765,10 +1832,34 @@ class GraphExecutor:
|
||||
if node_spec.tools:
|
||||
available_tools = [t for t in self.tools if t.name in node_spec.tools]
|
||||
|
||||
# Create scoped memory view
|
||||
# Create scoped memory view.
|
||||
# When permissions are restricted (non-empty key lists), auto-include
|
||||
# _-prefixed keys used by default skill protocols so agents can read/write
|
||||
# operational state (e.g. _working_notes, _batch_ledger) regardless of
|
||||
# what the node declares. When key lists are empty (unrestricted), leave
|
||||
# unchanged — empty means "allow all".
|
||||
read_keys = list(node_spec.input_keys)
|
||||
write_keys = list(node_spec.output_keys)
|
||||
# Only extend lists that were already restricted (non-empty).
|
||||
# Empty means "allow all" — adding keys would accidentally
|
||||
# activate the permission check and block legitimate reads/writes.
|
||||
if read_keys or write_keys:
|
||||
from framework.skills.defaults import SHARED_MEMORY_KEYS as _skill_keys
|
||||
|
||||
existing_underscore = [k for k in memory._data if k.startswith("_")]
|
||||
extra_keys = set(_skill_keys) | set(existing_underscore)
|
||||
# Only inject into read_keys when it was already non-empty — an empty
|
||||
# read_keys means "allow all reads" and injecting skill keys would
|
||||
# inadvertently restrict reads to skill keys only.
|
||||
for k in extra_keys:
|
||||
if read_keys and k not in read_keys:
|
||||
read_keys.append(k)
|
||||
if write_keys and k not in write_keys:
|
||||
write_keys.append(k)
|
||||
|
||||
scoped_memory = memory.with_permissions(
|
||||
read_keys=node_spec.input_keys,
|
||||
write_keys=node_spec.output_keys,
|
||||
read_keys=read_keys,
|
||||
write_keys=write_keys,
|
||||
)
|
||||
|
||||
# Build per-node accounts prompt (filtered to this node's tools)
|
||||
@@ -1812,6 +1903,9 @@ class GraphExecutor:
|
||||
dynamic_tools_provider=self.dynamic_tools_provider,
|
||||
dynamic_prompt_provider=self.dynamic_prompt_provider,
|
||||
iteration_metadata_provider=self.iteration_metadata_provider,
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
skill_dirs=self.skill_dirs,
|
||||
)
|
||||
|
||||
VALID_NODE_TYPES = {
|
||||
@@ -2052,6 +2146,10 @@ class GraphExecutor:
|
||||
edge=edge,
|
||||
)
|
||||
|
||||
# Track which branch wrote which key for memory conflict detection
|
||||
fanout_written_keys: dict[str, str] = {} # key -> branch_id that wrote it
|
||||
fanout_keys_lock = asyncio.Lock()
|
||||
|
||||
self.logger.info(f" ⑂ Fan-out: executing {len(branches)} branches in parallel")
|
||||
for branch in branches.values():
|
||||
target_spec = graph.get_node(branch.node_id)
|
||||
@@ -2143,8 +2241,31 @@ class GraphExecutor:
|
||||
)
|
||||
|
||||
if result.success:
|
||||
# Write outputs to shared memory using async write
|
||||
# Write outputs to shared memory with conflict detection
|
||||
conflict_strategy = self._parallel_config.memory_conflict_strategy
|
||||
for key, value in result.output.items():
|
||||
async with fanout_keys_lock:
|
||||
prior_branch = fanout_written_keys.get(key)
|
||||
if prior_branch and prior_branch != branch.branch_id:
|
||||
if conflict_strategy == "error":
|
||||
raise RuntimeError(
|
||||
f"Memory conflict: key '{key}' already written "
|
||||
f"by branch '{prior_branch}', "
|
||||
f"conflicting write from '{branch.branch_id}'"
|
||||
)
|
||||
elif conflict_strategy == "first_wins":
|
||||
self.logger.debug(
|
||||
f" ⚠ Skipping write to '{key}' "
|
||||
f"(first_wins: already set by {prior_branch})"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# last_wins (default): write and log
|
||||
self.logger.debug(
|
||||
f" ⚠ Key '{key}' overwritten "
|
||||
f"(last_wins: {prior_branch} -> {branch.branch_id})"
|
||||
)
|
||||
fanout_written_keys[key] = branch.branch_id
|
||||
await memory.write_async(key, value)
|
||||
|
||||
branch.result = result
|
||||
@@ -2191,9 +2312,11 @@ class GraphExecutor:
|
||||
|
||||
return branch, e
|
||||
|
||||
# Execute all branches concurrently
|
||||
tasks = [execute_single_branch(b) for b in branches.values()]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
# Execute all branches concurrently with per-branch timeout
|
||||
timeout = self._parallel_config.branch_timeout_seconds
|
||||
branch_list = list(branches.values())
|
||||
tasks = [asyncio.wait_for(execute_single_branch(b), timeout=timeout) for b in branch_list]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Process results
|
||||
total_tokens = 0
|
||||
@@ -2201,17 +2324,33 @@ class GraphExecutor:
|
||||
branch_results: dict[str, NodeResult] = {}
|
||||
failed_branches: list[ParallelBranch] = []
|
||||
|
||||
for branch, result in results:
|
||||
path.append(branch.node_id)
|
||||
for i, result in enumerate(results):
|
||||
branch = branch_list[i]
|
||||
|
||||
if isinstance(result, Exception):
|
||||
if isinstance(result, asyncio.TimeoutError):
|
||||
# Branch timed out
|
||||
branch.status = "timed_out"
|
||||
branch.error = f"Branch timed out after {timeout}s"
|
||||
self.logger.warning(
|
||||
f" ⏱ Branch {graph.get_node(branch.node_id).name}: "
|
||||
f"timed out after {timeout}s"
|
||||
)
|
||||
path.append(branch.node_id)
|
||||
failed_branches.append(branch)
|
||||
elif result is None or not result.success:
|
||||
elif isinstance(result, Exception):
|
||||
path.append(branch.node_id)
|
||||
failed_branches.append(branch)
|
||||
else:
|
||||
total_tokens += result.tokens_used
|
||||
total_latency += result.latency_ms
|
||||
branch_results[branch.branch_id] = result
|
||||
returned_branch, node_result = result
|
||||
path.append(returned_branch.node_id)
|
||||
if node_result is None or isinstance(node_result, Exception):
|
||||
failed_branches.append(returned_branch)
|
||||
elif not node_result.success:
|
||||
failed_branches.append(returned_branch)
|
||||
else:
|
||||
total_tokens += node_result.tokens_used
|
||||
total_latency += node_result.latency_ms
|
||||
branch_results[returned_branch.branch_id] = node_result
|
||||
|
||||
# Handle failures based on config
|
||||
if failed_branches:
|
||||
|
||||
@@ -565,6 +565,11 @@ class NodeContext:
|
||||
# staging / running) without restarting the conversation.
|
||||
dynamic_prompt_provider: Any = None # Callable[[], str] | None
|
||||
|
||||
# Skill system prompts — injected by the skill discovery pipeline
|
||||
skills_catalog_prompt: str = "" # Available skills XML catalog
|
||||
protocols_prompt: str = "" # Default skill operational protocols
|
||||
skill_dirs: list[str] = field(default_factory=list) # Skill base dirs for resource access
|
||||
|
||||
# Per-iteration metadata provider — when set, EventLoopNode merges
|
||||
# the returned dict into node_loop_iteration event data. Used by
|
||||
# the queen to record the current phase per iteration.
|
||||
|
||||
@@ -26,6 +26,16 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Injected into every worker node's system prompt so the LLM understands
|
||||
# it is one step in a multi-node pipeline and should not overreach.
|
||||
EXECUTION_SCOPE_PREAMBLE = (
|
||||
"EXECUTION SCOPE: You are one node in a multi-step workflow graph. "
|
||||
"Focus ONLY on the task described in your instructions below. "
|
||||
"Call set_output() for each of your declared output keys, then stop. "
|
||||
"Do NOT attempt work that belongs to other nodes — the framework "
|
||||
"routes data between nodes automatically."
|
||||
)
|
||||
|
||||
|
||||
def _with_datetime(prompt: str) -> str:
|
||||
"""Append current datetime with local timezone to a system prompt."""
|
||||
@@ -140,14 +150,24 @@ def compose_system_prompt(
|
||||
focus_prompt: str | None,
|
||||
narrative: str | None = None,
|
||||
accounts_prompt: str | None = None,
|
||||
skills_catalog_prompt: str | None = None,
|
||||
protocols_prompt: str | None = None,
|
||||
execution_preamble: str | None = None,
|
||||
node_type_preamble: str | None = None,
|
||||
) -> str:
|
||||
"""Compose the three-layer system prompt.
|
||||
"""Compose the multi-layer system prompt.
|
||||
|
||||
Args:
|
||||
identity_prompt: Layer 1 — static agent identity (from GraphSpec).
|
||||
focus_prompt: Layer 3 — per-node focus directive (from NodeSpec.system_prompt).
|
||||
narrative: Layer 2 — auto-generated from conversation state.
|
||||
accounts_prompt: Connected accounts block (sits between identity and narrative).
|
||||
skills_catalog_prompt: Available skills catalog XML (Agent Skills standard).
|
||||
protocols_prompt: Default skill operational protocols section.
|
||||
execution_preamble: EXECUTION_SCOPE_PREAMBLE for worker nodes
|
||||
(prepended before focus so the LLM knows its pipeline scope).
|
||||
node_type_preamble: Node-type-specific preamble, e.g. GCU browser
|
||||
best-practices prompt (prepended before focus).
|
||||
|
||||
Returns:
|
||||
Composed system prompt with all layers present, plus current datetime.
|
||||
@@ -162,10 +182,27 @@ def compose_system_prompt(
|
||||
if accounts_prompt:
|
||||
parts.append(f"\n{accounts_prompt}")
|
||||
|
||||
# Skills catalog (discovered skills available for activation)
|
||||
if skills_catalog_prompt:
|
||||
parts.append(f"\n{skills_catalog_prompt}")
|
||||
|
||||
# Operational protocols (default skill behavioral guidance)
|
||||
if protocols_prompt:
|
||||
parts.append(f"\n{protocols_prompt}")
|
||||
|
||||
# Layer 2: Narrative (what's happened so far)
|
||||
if narrative:
|
||||
parts.append(f"\n--- Context (what has happened so far) ---\n{narrative}")
|
||||
|
||||
# Execution scope preamble (worker nodes — tells the LLM it is one
|
||||
# step in a multi-node pipeline and should not overreach)
|
||||
if execution_preamble:
|
||||
parts.append(f"\n{execution_preamble}")
|
||||
|
||||
# Node-type preamble (e.g. GCU browser best-practices)
|
||||
if node_type_preamble:
|
||||
parts.append(f"\n{node_type_preamble}")
|
||||
|
||||
# Layer 3: Focus (current phase directive)
|
||||
if focus_prompt:
|
||||
parts.append(f"\n--- Current Focus ---\n{focus_prompt}")
|
||||
@@ -255,7 +292,9 @@ def build_transition_marker(
|
||||
sections.append(f"\nCompleted: {previous_node.name}")
|
||||
sections.append(f" {previous_node.description}")
|
||||
|
||||
# Outputs in memory
|
||||
# Outputs in memory — use file references for large values so the
|
||||
# next node loads full data from disk instead of seeing truncated
|
||||
# inline previews that look deceptively complete.
|
||||
all_memory = memory.read_all()
|
||||
if all_memory:
|
||||
memory_lines: list[str] = []
|
||||
@@ -263,7 +302,29 @@ def build_transition_marker(
|
||||
if value is None:
|
||||
continue
|
||||
val_str = str(value)
|
||||
if len(val_str) > 300:
|
||||
if len(val_str) > 300 and data_dir:
|
||||
# Auto-spill large transition values to data files
|
||||
import json as _json
|
||||
|
||||
data_path = Path(data_dir)
|
||||
data_path.mkdir(parents=True, exist_ok=True)
|
||||
ext = ".json" if isinstance(value, (dict, list)) else ".txt"
|
||||
filename = f"output_{key}{ext}"
|
||||
try:
|
||||
write_content = (
|
||||
_json.dumps(value, indent=2, ensure_ascii=False)
|
||||
if isinstance(value, (dict, list))
|
||||
else str(value)
|
||||
)
|
||||
(data_path / filename).write_text(write_content, encoding="utf-8")
|
||||
file_size = (data_path / filename).stat().st_size
|
||||
val_str = (
|
||||
f"[Saved to '{filename}' ({file_size:,} bytes). "
|
||||
f"Use load_data(filename='{filename}') to access.]"
|
||||
)
|
||||
except Exception:
|
||||
val_str = val_str[:300] + "..."
|
||||
elif len(val_str) > 300:
|
||||
val_str = val_str[:300] + "..."
|
||||
memory_lines.append(f" {key}: {val_str}")
|
||||
if memory_lines:
|
||||
@@ -280,7 +341,7 @@ def build_transition_marker(
|
||||
]
|
||||
if file_lines:
|
||||
sections.append(
|
||||
"\nData files (use read_file to access):\n" + "\n".join(file_lines)
|
||||
"\nData files (use load_data to access):\n" + "\n".join(file_lines)
|
||||
)
|
||||
|
||||
# Agent working memory
|
||||
@@ -294,6 +355,12 @@ def build_transition_marker(
|
||||
# Next phase
|
||||
sections.append(f"\nNow entering: {next_node.name}")
|
||||
sections.append(f" {next_node.description}")
|
||||
if next_node.output_keys:
|
||||
sections.append(
|
||||
f"\nYour ONLY job in this phase: complete the task above and call "
|
||||
f"set_output() for {next_node.output_keys}. Do NOT do work that "
|
||||
f"belongs to later phases."
|
||||
)
|
||||
|
||||
# Reflection prompt (engineered metacognition)
|
||||
sections.append(
|
||||
|
||||
@@ -115,11 +115,23 @@ class SafeEvalVisitor(ast.NodeVisitor):
|
||||
return True
|
||||
|
||||
def visit_BoolOp(self, node: ast.BoolOp) -> Any:
|
||||
values = [self.visit(v) for v in node.values]
|
||||
# Short-circuit evaluation to match Python semantics.
|
||||
# Previously all operands were eagerly evaluated, which broke
|
||||
# guard patterns like: ``x is not None and x.get("key")``
|
||||
if isinstance(node.op, ast.And):
|
||||
return all(values)
|
||||
result = True
|
||||
for v in node.values:
|
||||
result = self.visit(v)
|
||||
if not result:
|
||||
return result
|
||||
return result
|
||||
elif isinstance(node.op, ast.Or):
|
||||
return any(values)
|
||||
result = False
|
||||
for v in node.values:
|
||||
result = self.visit(v)
|
||||
if result:
|
||||
return result
|
||||
return result
|
||||
raise ValueError(f"Boolean operator {type(node.op).__name__} is not allowed")
|
||||
|
||||
def visit_IfExp(self, node: ast.IfExp) -> Any:
|
||||
|
||||
+695
-12
@@ -7,9 +7,13 @@ Groq, and local models.
|
||||
See: https://docs.litellm.ai/docs/providers
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime
|
||||
@@ -23,6 +27,7 @@ except ImportError:
|
||||
litellm = None # type: ignore[assignment]
|
||||
RateLimitError = Exception # type: ignore[assignment, misc]
|
||||
|
||||
from framework.config import HIVE_LLM_ENDPOINT as HIVE_API_BASE
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
from framework.llm.stream_events import StreamEvent
|
||||
|
||||
@@ -43,8 +48,17 @@ def _patch_litellm_anthropic_oauth() -> None:
|
||||
"""
|
||||
try:
|
||||
from litellm.llms.anthropic.common_utils import AnthropicModelInfo
|
||||
from litellm.types.llms.anthropic import ANTHROPIC_OAUTH_TOKEN_PREFIX
|
||||
from litellm.types.llms.anthropic import (
|
||||
ANTHROPIC_OAUTH_BETA_HEADER,
|
||||
ANTHROPIC_OAUTH_TOKEN_PREFIX,
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"Could not apply litellm Anthropic OAuth patch — litellm internals may have "
|
||||
"changed. Anthropic OAuth tokens (Claude Code subscriptions) may fail with 401. "
|
||||
"See BerriAI/litellm#19618. Current litellm version: %s",
|
||||
getattr(litellm, "__version__", "unknown"),
|
||||
)
|
||||
return
|
||||
|
||||
original = AnthropicModelInfo.validate_environment
|
||||
@@ -62,9 +76,27 @@ def _patch_litellm_anthropic_oauth() -> None:
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
)
|
||||
# Check both authorization header and x-api-key for OAuth tokens.
|
||||
# litellm's optionally_handle_anthropic_oauth only checks headers["authorization"],
|
||||
# but hive passes OAuth tokens via api_key — so litellm puts them into x-api-key.
|
||||
# Anthropic rejects OAuth tokens in x-api-key; they must go in Authorization: Bearer.
|
||||
auth = result.get("authorization", "")
|
||||
if auth.startswith(f"Bearer {ANTHROPIC_OAUTH_TOKEN_PREFIX}"):
|
||||
x_api_key = result.get("x-api-key", "")
|
||||
oauth_prefix = f"Bearer {ANTHROPIC_OAUTH_TOKEN_PREFIX}"
|
||||
auth_is_oauth = auth.startswith(oauth_prefix)
|
||||
key_is_oauth = x_api_key.startswith(ANTHROPIC_OAUTH_TOKEN_PREFIX)
|
||||
if auth_is_oauth or key_is_oauth:
|
||||
token = x_api_key if key_is_oauth else auth.removeprefix("Bearer ").strip()
|
||||
result.pop("x-api-key", None)
|
||||
result["authorization"] = f"Bearer {token}"
|
||||
# Merge the OAuth beta header with any existing beta headers.
|
||||
existing_beta = result.get("anthropic-beta", "")
|
||||
beta_parts = (
|
||||
[b.strip() for b in existing_beta.split(",") if b.strip()] if existing_beta else []
|
||||
)
|
||||
if ANTHROPIC_OAUTH_BETA_HEADER not in beta_parts:
|
||||
beta_parts.append(ANTHROPIC_OAUTH_BETA_HEADER)
|
||||
result["anthropic-beta"] = ",".join(beta_parts)
|
||||
return result
|
||||
|
||||
AnthropicModelInfo.validate_environment = _patched_validate_environment
|
||||
@@ -86,10 +118,12 @@ def _patch_litellm_metadata_nonetype() -> None:
|
||||
"""
|
||||
import functools
|
||||
|
||||
patched_count = 0
|
||||
for fn_name in ("completion", "acompletion", "responses", "aresponses"):
|
||||
original = getattr(litellm, fn_name, None)
|
||||
if original is None:
|
||||
continue
|
||||
patched_count += 1
|
||||
if asyncio.iscoroutinefunction(original):
|
||||
|
||||
@functools.wraps(original)
|
||||
@@ -109,15 +143,27 @@ def _patch_litellm_metadata_nonetype() -> None:
|
||||
|
||||
setattr(litellm, fn_name, _sync_wrapper)
|
||||
|
||||
if patched_count == 0:
|
||||
logger.warning(
|
||||
"Could not apply litellm metadata=None patch — none of the expected entry "
|
||||
"points (completion, acompletion, responses, aresponses) were found. "
|
||||
"metadata=None TypeError may occur. Current litellm version: %s",
|
||||
getattr(litellm, "__version__", "unknown"),
|
||||
)
|
||||
|
||||
|
||||
if litellm is not None:
|
||||
_patch_litellm_anthropic_oauth()
|
||||
_patch_litellm_metadata_nonetype()
|
||||
# Let litellm silently drop params unsupported by the target provider
|
||||
# (e.g. stream_options for Anthropic) instead of forwarding them verbatim.
|
||||
litellm.drop_params = True
|
||||
|
||||
RATE_LIMIT_MAX_RETRIES = 10
|
||||
RATE_LIMIT_BACKOFF_BASE = 2 # seconds
|
||||
RATE_LIMIT_MAX_DELAY = 120 # seconds - cap to prevent absurd waits
|
||||
MINIMAX_API_BASE = "https://api.minimax.io/v1"
|
||||
OPENROUTER_API_BASE = "https://openrouter.ai/api/v1"
|
||||
|
||||
# Providers that accept cache_control on message content blocks.
|
||||
# Anthropic: native ephemeral caching. MiniMax & Z-AI/GLM: pass-through to their APIs.
|
||||
@@ -142,14 +188,77 @@ def _model_supports_cache_control(model: str) -> bool:
|
||||
# enforces a coding-agent whitelist that blocks unknown User-Agents.
|
||||
KIMI_API_BASE = "https://api.kimi.com/coding"
|
||||
|
||||
# Claude Code OAuth subscription: the Anthropic API requires a specific
|
||||
# User-Agent and a billing integrity header for OAuth-authenticated requests.
|
||||
CLAUDE_CODE_VERSION = "2.1.76"
|
||||
CLAUDE_CODE_USER_AGENT = f"claude-code/{CLAUDE_CODE_VERSION}"
|
||||
_CLAUDE_CODE_BILLING_SALT = "59cf53e54c78"
|
||||
|
||||
|
||||
def _sample_js_code_unit(text: str, idx: int) -> str:
|
||||
"""Return the character at UTF-16 code unit index *idx*, matching JS semantics."""
|
||||
encoded = text.encode("utf-16-le")
|
||||
unit_offset = idx * 2
|
||||
if unit_offset + 2 > len(encoded):
|
||||
return "0"
|
||||
code_unit = int.from_bytes(encoded[unit_offset : unit_offset + 2], "little")
|
||||
return chr(code_unit)
|
||||
|
||||
|
||||
def _claude_code_billing_header(messages: list[dict[str, Any]]) -> str:
|
||||
"""Build the billing integrity system block required by Anthropic's OAuth path."""
|
||||
# Find the first user message text
|
||||
first_text = ""
|
||||
for msg in messages:
|
||||
if msg.get("role") != "user":
|
||||
continue
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
first_text = content
|
||||
break
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text" and block.get("text"):
|
||||
first_text = block["text"]
|
||||
break
|
||||
if first_text:
|
||||
break
|
||||
|
||||
sampled = "".join(_sample_js_code_unit(first_text, i) for i in (4, 7, 20))
|
||||
version_hash = hashlib.sha256(
|
||||
f"{_CLAUDE_CODE_BILLING_SALT}{sampled}{CLAUDE_CODE_VERSION}".encode()
|
||||
).hexdigest()
|
||||
entrypoint = os.environ.get("CLAUDE_CODE_ENTRYPOINT", "").strip() or "cli"
|
||||
return (
|
||||
f"x-anthropic-billing-header: cc_version={CLAUDE_CODE_VERSION}.{version_hash[:3]}; "
|
||||
f"cc_entrypoint={entrypoint}; cch=00000;"
|
||||
)
|
||||
|
||||
|
||||
# Empty-stream retries use a short fixed delay, not the rate-limit backoff.
|
||||
# Conversation-structure issues are deterministic — long waits don't help.
|
||||
EMPTY_STREAM_MAX_RETRIES = 3
|
||||
EMPTY_STREAM_RETRY_DELAY = 1.0 # seconds
|
||||
OPENROUTER_TOOL_COMPAT_ERROR_SNIPPETS = (
|
||||
"no endpoints found that support tool use",
|
||||
"no endpoints available that support tool use",
|
||||
"provider routing",
|
||||
)
|
||||
OPENROUTER_TOOL_CALL_RE = re.compile(
|
||||
r"<\|tool_call_start\|>\s*(.*?)\s*<\|tool_call_end\|>",
|
||||
re.DOTALL,
|
||||
)
|
||||
OPENROUTER_TOOL_COMPAT_CACHE_TTL_SECONDS = 3600
|
||||
# OpenRouter routing can change over time, so tool-compat caching must expire.
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE: dict[str, float] = {}
|
||||
|
||||
# Directory for dumping failed requests
|
||||
FAILED_REQUESTS_DIR = Path.home() / ".hive" / "failed_requests"
|
||||
|
||||
# Maximum number of dump files to retain in ~/.hive/failed_requests/.
|
||||
# Older files are pruned automatically to prevent unbounded disk growth.
|
||||
MAX_FAILED_REQUEST_DUMPS = 50
|
||||
|
||||
|
||||
def _estimate_tokens(model: str, messages: list[dict]) -> tuple[int, str]:
|
||||
"""Estimate token count for messages. Returns (token_count, method)."""
|
||||
@@ -166,6 +275,42 @@ def _estimate_tokens(model: str, messages: list[dict]) -> tuple[int, str]:
|
||||
return total_chars // 4, "estimate"
|
||||
|
||||
|
||||
def _prune_failed_request_dumps(max_files: int = MAX_FAILED_REQUEST_DUMPS) -> None:
|
||||
"""Remove oldest dump files when the count exceeds *max_files*.
|
||||
|
||||
Best-effort: never raises — a pruning failure must not break retry logic.
|
||||
"""
|
||||
try:
|
||||
all_dumps = sorted(
|
||||
FAILED_REQUESTS_DIR.glob("*.json"),
|
||||
key=lambda f: f.stat().st_mtime,
|
||||
)
|
||||
excess = len(all_dumps) - max_files
|
||||
if excess > 0:
|
||||
for old_file in all_dumps[:excess]:
|
||||
old_file.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass # Best-effort — never block the caller
|
||||
|
||||
|
||||
def _remember_openrouter_tool_compat_model(model: str) -> None:
|
||||
"""Cache OpenRouter tool-compat fallback for a bounded time window."""
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE[model] = (
|
||||
time.monotonic() + OPENROUTER_TOOL_COMPAT_CACHE_TTL_SECONDS
|
||||
)
|
||||
|
||||
|
||||
def _is_openrouter_tool_compat_cached(model: str) -> bool:
|
||||
"""Return True when the cached OpenRouter compat entry is still fresh."""
|
||||
expires_at = OPENROUTER_TOOL_COMPAT_MODEL_CACHE.get(model)
|
||||
if expires_at is None:
|
||||
return False
|
||||
if expires_at <= time.monotonic():
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE.pop(model, None)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _dump_failed_request(
|
||||
model: str,
|
||||
kwargs: dict[str, Any],
|
||||
@@ -197,6 +342,9 @@ def _dump_failed_request(
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(dump_data, f, indent=2, default=str)
|
||||
|
||||
# Prune old dumps to prevent unbounded disk growth
|
||||
_prune_failed_request_dumps()
|
||||
|
||||
return str(filepath)
|
||||
|
||||
|
||||
@@ -358,10 +506,20 @@ class LiteLLMProvider(LLMProvider):
|
||||
# Strip a trailing /v1 in case the user's saved config has the old value.
|
||||
if api_base and api_base.rstrip("/").endswith("/v1"):
|
||||
api_base = api_base.rstrip("/")[:-3]
|
||||
elif model.lower().startswith("hive/"):
|
||||
model = "anthropic/" + model[len("hive/") :]
|
||||
if api_base and api_base.rstrip("/").endswith("/v1"):
|
||||
api_base = api_base.rstrip("/")[:-3]
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or self._default_api_base_for_model(_original_model)
|
||||
self.extra_kwargs = kwargs
|
||||
# Detect Claude Code OAuth subscription by checking the api_key prefix.
|
||||
self._claude_code_oauth = bool(api_key and api_key.startswith("sk-ant-oat"))
|
||||
if self._claude_code_oauth:
|
||||
# Anthropic requires a specific User-Agent for OAuth requests.
|
||||
eh = self.extra_kwargs.setdefault("extra_headers", {})
|
||||
eh.setdefault("user-agent", CLAUDE_CODE_USER_AGENT)
|
||||
# The Codex ChatGPT backend (chatgpt.com/backend-api/codex) rejects
|
||||
# several standard OpenAI params: max_output_tokens, stream_options.
|
||||
self._codex_backend = bool(
|
||||
@@ -385,8 +543,12 @@ class LiteLLMProvider(LLMProvider):
|
||||
model_lower = model.lower()
|
||||
if model_lower.startswith("minimax/") or model_lower.startswith("minimax-"):
|
||||
return MINIMAX_API_BASE
|
||||
if model_lower.startswith("openrouter/"):
|
||||
return OPENROUTER_API_BASE
|
||||
if model_lower.startswith("kimi/"):
|
||||
return KIMI_API_BASE
|
||||
if model_lower.startswith("hive/"):
|
||||
return HIVE_API_BASE
|
||||
return None
|
||||
|
||||
def _completion_with_rate_limit_retry(
|
||||
@@ -725,6 +887,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
return await self._collect_stream_to_response(stream_iter)
|
||||
|
||||
full_messages: list[dict[str, Any]] = []
|
||||
if self._claude_code_oauth:
|
||||
billing = _claude_code_billing_header(messages)
|
||||
full_messages.append({"role": "system", "content": billing})
|
||||
if system:
|
||||
sys_msg: dict[str, Any] = {"role": "system", "content": system}
|
||||
if _model_supports_cache_control(self.model):
|
||||
@@ -786,11 +951,504 @@ class LiteLLMProvider(LLMProvider):
|
||||
},
|
||||
}
|
||||
|
||||
def _is_anthropic_model(self) -> bool:
|
||||
"""Return True when the configured model targets Anthropic."""
|
||||
model = (self.model or "").lower()
|
||||
return model.startswith("anthropic/") or model.startswith("claude-")
|
||||
|
||||
def _is_minimax_model(self) -> bool:
|
||||
"""Return True when the configured model targets MiniMax."""
|
||||
model = (self.model or "").lower()
|
||||
return model.startswith("minimax/") or model.startswith("minimax-")
|
||||
|
||||
def _is_openrouter_model(self) -> bool:
|
||||
"""Return True when the configured model targets OpenRouter."""
|
||||
model = (self.model or "").lower()
|
||||
if model.startswith("openrouter/"):
|
||||
return True
|
||||
api_base = (self.api_base or "").lower()
|
||||
return "openrouter.ai/api/v1" in api_base
|
||||
|
||||
def _should_use_openrouter_tool_compat(
|
||||
self,
|
||||
error: BaseException,
|
||||
tools: list[Tool] | None,
|
||||
) -> bool:
|
||||
"""Return True when OpenRouter rejects native tool use for the model."""
|
||||
if not tools or not self._is_openrouter_model():
|
||||
return False
|
||||
error_text = str(error).lower()
|
||||
return "openrouter" in error_text and any(
|
||||
snippet in error_text for snippet in OPENROUTER_TOOL_COMPAT_ERROR_SNIPPETS
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_object(text: str) -> dict[str, Any] | None:
|
||||
"""Extract the first JSON object from a model response."""
|
||||
candidates = [text.strip()]
|
||||
|
||||
stripped = text.strip()
|
||||
if stripped.startswith("```"):
|
||||
fence_lines = stripped.splitlines()
|
||||
if len(fence_lines) >= 3:
|
||||
candidates.append("\n".join(fence_lines[1:-1]).strip())
|
||||
|
||||
decoder = json.JSONDecoder()
|
||||
for candidate in candidates:
|
||||
if not candidate:
|
||||
continue
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
except json.JSONDecodeError:
|
||||
parsed = None
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
|
||||
for start_idx, char in enumerate(candidate):
|
||||
if char != "{":
|
||||
continue
|
||||
try:
|
||||
parsed, _ = decoder.raw_decode(candidate[start_idx:])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return None
|
||||
|
||||
def _parse_openrouter_tool_compat_response(
|
||||
self,
|
||||
content: str,
|
||||
tools: list[Tool],
|
||||
) -> tuple[str, list[dict[str, Any]]]:
|
||||
"""Parse JSON tool-compat output into assistant text and tool calls."""
|
||||
payload = self._extract_json_object(content)
|
||||
if payload is None:
|
||||
text_tool_content, text_tool_calls = self._parse_openrouter_text_tool_calls(
|
||||
content,
|
||||
tools,
|
||||
)
|
||||
if text_tool_calls:
|
||||
logger.info(
|
||||
"[openrouter-tool-compat] Parsed textual tool-call markers for %s",
|
||||
self.model,
|
||||
)
|
||||
return text_tool_content, text_tool_calls
|
||||
logger.info(
|
||||
"[openrouter-tool-compat] %s returned non-JSON fallback content; "
|
||||
"treating it as plain text.",
|
||||
self.model,
|
||||
)
|
||||
return content.strip(), []
|
||||
|
||||
assistant_text = payload.get("assistant_response")
|
||||
if not isinstance(assistant_text, str):
|
||||
assistant_text = payload.get("content")
|
||||
if not isinstance(assistant_text, str):
|
||||
assistant_text = payload.get("response")
|
||||
if not isinstance(assistant_text, str):
|
||||
assistant_text = ""
|
||||
|
||||
tool_calls_raw = payload.get("tool_calls")
|
||||
if not tool_calls_raw and {"name", "arguments"} <= payload.keys():
|
||||
tool_calls_raw = [payload]
|
||||
elif isinstance(payload.get("tool_call"), dict):
|
||||
tool_calls_raw = [payload["tool_call"]]
|
||||
|
||||
if not isinstance(tool_calls_raw, list):
|
||||
tool_calls_raw = []
|
||||
|
||||
allowed_tool_names = {tool.name for tool in tools}
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
compat_prefix = f"openrouter_compat_{time.time_ns()}"
|
||||
|
||||
for idx, raw_call in enumerate(tool_calls_raw):
|
||||
if not isinstance(raw_call, dict):
|
||||
continue
|
||||
|
||||
function_block = raw_call.get("function")
|
||||
function_name = (
|
||||
raw_call.get("name")
|
||||
or raw_call.get("tool_name")
|
||||
or (function_block.get("name") if isinstance(function_block, dict) else None)
|
||||
)
|
||||
if not isinstance(function_name, str) or function_name not in allowed_tool_names:
|
||||
if function_name:
|
||||
logger.warning(
|
||||
"[openrouter-tool-compat] Ignoring unknown tool '%s' for model %s",
|
||||
function_name,
|
||||
self.model,
|
||||
)
|
||||
continue
|
||||
|
||||
arguments = raw_call.get("arguments")
|
||||
if arguments is None:
|
||||
arguments = raw_call.get("tool_input")
|
||||
if arguments is None:
|
||||
arguments = raw_call.get("input")
|
||||
if arguments is None and isinstance(function_block, dict):
|
||||
arguments = function_block.get("arguments")
|
||||
if arguments is None:
|
||||
arguments = {}
|
||||
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {"_raw": arguments}
|
||||
elif not isinstance(arguments, dict):
|
||||
arguments = {"value": arguments}
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"{compat_prefix}_{idx}",
|
||||
"name": function_name,
|
||||
"input": arguments,
|
||||
}
|
||||
)
|
||||
|
||||
return assistant_text.strip(), tool_calls
|
||||
|
||||
@staticmethod
|
||||
def _close_truncated_json_fragment(fragment: str) -> str:
|
||||
"""Close a truncated JSON fragment by balancing quotes/brackets."""
|
||||
stack: list[str] = []
|
||||
in_string = False
|
||||
escaped = False
|
||||
normalized = fragment.rstrip()
|
||||
|
||||
while normalized and normalized[-1] in ",:{[":
|
||||
normalized = normalized[:-1].rstrip()
|
||||
|
||||
for char in normalized:
|
||||
if in_string:
|
||||
if escaped:
|
||||
escaped = False
|
||||
elif char == "\\":
|
||||
escaped = True
|
||||
elif char == '"':
|
||||
in_string = False
|
||||
continue
|
||||
|
||||
if char == '"':
|
||||
in_string = True
|
||||
elif char in "{[":
|
||||
stack.append(char)
|
||||
elif char == "}" and stack and stack[-1] == "{":
|
||||
stack.pop()
|
||||
elif char == "]" and stack and stack[-1] == "[":
|
||||
stack.pop()
|
||||
|
||||
if in_string:
|
||||
if escaped:
|
||||
normalized = normalized[:-1]
|
||||
normalized += '"'
|
||||
|
||||
for opener in reversed(stack):
|
||||
normalized += "}" if opener == "{" else "]"
|
||||
|
||||
return normalized
|
||||
|
||||
def _repair_truncated_tool_arguments(self, raw_arguments: str) -> dict[str, Any] | None:
|
||||
"""Try to recover a truncated JSON object from tool-call arguments."""
|
||||
stripped = raw_arguments.strip()
|
||||
if not stripped or stripped[0] != "{":
|
||||
return None
|
||||
|
||||
max_trim = min(len(stripped), 256)
|
||||
for trim in range(max_trim + 1):
|
||||
candidate = stripped[: len(stripped) - trim].rstrip()
|
||||
if not candidate:
|
||||
break
|
||||
candidate = self._close_truncated_json_fragment(candidate)
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return None
|
||||
|
||||
def _parse_tool_call_arguments(self, raw_arguments: str, tool_name: str) -> dict[str, Any]:
|
||||
"""Parse streamed tool arguments, repairing truncation when possible."""
|
||||
try:
|
||||
parsed = json.loads(raw_arguments) if raw_arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
parsed = None
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
|
||||
repaired = self._repair_truncated_tool_arguments(raw_arguments)
|
||||
if repaired is not None:
|
||||
logger.warning(
|
||||
"[tool-args] Recovered truncated arguments for %s on %s",
|
||||
tool_name,
|
||||
self.model,
|
||||
)
|
||||
return repaired
|
||||
|
||||
raise ValueError(
|
||||
f"Failed to parse tool call arguments for '{tool_name}' (likely truncated JSON)."
|
||||
)
|
||||
|
||||
def _parse_openrouter_text_tool_calls(
|
||||
self,
|
||||
content: str,
|
||||
tools: list[Tool],
|
||||
) -> tuple[str, list[dict[str, Any]]]:
|
||||
"""Parse textual OpenRouter tool calls into synthetic tool calls.
|
||||
|
||||
Supports both:
|
||||
- Marker wrapped payloads: <|tool_call_start|>...<|tool_call_end|>
|
||||
- Plain one-line tool calls: ask_user("...", ["..."])
|
||||
"""
|
||||
tools_by_name = {tool.name: tool for tool in tools}
|
||||
compat_prefix = f"openrouter_compat_{time.time_ns()}"
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
segment_index = 0
|
||||
|
||||
for match in OPENROUTER_TOOL_CALL_RE.finditer(content):
|
||||
parsed_calls = self._parse_openrouter_text_tool_call_block(
|
||||
block=match.group(1),
|
||||
tools_by_name=tools_by_name,
|
||||
compat_prefix=f"{compat_prefix}_{segment_index}",
|
||||
)
|
||||
if parsed_calls:
|
||||
segment_index += 1
|
||||
tool_calls.extend(parsed_calls)
|
||||
|
||||
stripped_content = OPENROUTER_TOOL_CALL_RE.sub("", content)
|
||||
retained_lines: list[str] = []
|
||||
for line in stripped_content.splitlines():
|
||||
stripped_line = line.strip()
|
||||
if not stripped_line:
|
||||
retained_lines.append(line)
|
||||
continue
|
||||
|
||||
candidate = stripped_line
|
||||
if candidate.startswith("`") and candidate.endswith("`") and len(candidate) > 1:
|
||||
candidate = candidate[1:-1].strip()
|
||||
|
||||
parsed_calls = self._parse_openrouter_text_tool_call_block(
|
||||
block=candidate,
|
||||
tools_by_name=tools_by_name,
|
||||
compat_prefix=f"{compat_prefix}_{segment_index}",
|
||||
)
|
||||
if parsed_calls:
|
||||
segment_index += 1
|
||||
tool_calls.extend(parsed_calls)
|
||||
continue
|
||||
|
||||
retained_lines.append(line)
|
||||
|
||||
stripped_text = "\n".join(retained_lines).strip()
|
||||
return stripped_text, tool_calls
|
||||
|
||||
def _parse_openrouter_text_tool_call_block(
|
||||
self,
|
||||
block: str,
|
||||
tools_by_name: dict[str, Tool],
|
||||
compat_prefix: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Parse a single textual tool-call block like [tool(arg='x')]."""
|
||||
try:
|
||||
parsed = ast.parse(block.strip(), mode="eval").body
|
||||
except SyntaxError:
|
||||
return []
|
||||
|
||||
call_nodes = parsed.elts if isinstance(parsed, ast.List) else [parsed]
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
for call_index, call_node in enumerate(call_nodes):
|
||||
if not isinstance(call_node, ast.Call) or not isinstance(call_node.func, ast.Name):
|
||||
continue
|
||||
|
||||
tool_name = call_node.func.id
|
||||
tool = tools_by_name.get(tool_name)
|
||||
if tool is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
tool_input = self._parse_openrouter_text_tool_call_arguments(
|
||||
call_node=call_node,
|
||||
tool=tool,
|
||||
)
|
||||
except (ValueError, SyntaxError):
|
||||
continue
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"{compat_prefix}_{call_index}",
|
||||
"name": tool_name,
|
||||
"input": tool_input,
|
||||
}
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
@staticmethod
|
||||
def _parse_openrouter_text_tool_call_arguments(
|
||||
call_node: ast.Call,
|
||||
tool: Tool,
|
||||
) -> dict[str, Any]:
|
||||
"""Parse positional/keyword args from a textual tool call."""
|
||||
properties = tool.parameters.get("properties", {})
|
||||
positional_keys = list(properties.keys())
|
||||
tool_input: dict[str, Any] = {}
|
||||
|
||||
if len(call_node.args) > len(positional_keys):
|
||||
raise ValueError("Too many positional args for textual tool call")
|
||||
|
||||
for idx, arg_node in enumerate(call_node.args):
|
||||
tool_input[positional_keys[idx]] = ast.literal_eval(arg_node)
|
||||
|
||||
for kwarg in call_node.keywords:
|
||||
if kwarg.arg is None:
|
||||
raise ValueError("Star args are not supported in textual tool calls")
|
||||
tool_input[kwarg.arg] = ast.literal_eval(kwarg.value)
|
||||
|
||||
return tool_input
|
||||
|
||||
def _build_openrouter_tool_compat_messages(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build a JSON-only prompt for models without native tool support."""
|
||||
tool_specs = [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
compat_instruction = (
|
||||
"Tool compatibility mode is active because this OpenRouter model does not support "
|
||||
"native function calling on the routed provider.\n"
|
||||
"Return exactly one JSON object and nothing else.\n"
|
||||
'Schema: {"assistant_response": string, '
|
||||
'"tool_calls": [{"name": string, "arguments": object}]}\n'
|
||||
"Rules:\n"
|
||||
"- If a tool is required, put one or more entries in tool_calls "
|
||||
"and do not invent tool results.\n"
|
||||
"- If no tool is required, set tool_calls to [] and put the full "
|
||||
"answer in assistant_response.\n"
|
||||
"- Only use tool names from the allowed tool list.\n"
|
||||
"- arguments must always be valid JSON objects.\n"
|
||||
f"Allowed tools:\n{json.dumps(tool_specs, ensure_ascii=True)}"
|
||||
)
|
||||
compat_system = compat_instruction if not system else f"{system}\n\n{compat_instruction}"
|
||||
|
||||
full_messages: list[dict[str, Any]] = [{"role": "system", "content": compat_system}]
|
||||
full_messages.extend(messages)
|
||||
return [
|
||||
message
|
||||
for message in full_messages
|
||||
if not (
|
||||
message.get("role") == "assistant"
|
||||
and not message.get("content")
|
||||
and not message.get("tool_calls")
|
||||
)
|
||||
]
|
||||
|
||||
async def _acomplete_via_openrouter_tool_compat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
max_tokens: int,
|
||||
) -> LLMResponse:
|
||||
"""Emulate tool calling via JSON when OpenRouter rejects native tools."""
|
||||
full_messages = self._build_openrouter_tool_compat_messages(messages, system, tools)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": full_messages,
|
||||
"max_tokens": max_tokens,
|
||||
**self.extra_kwargs,
|
||||
}
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
|
||||
response = await self._acompletion_with_rate_limit_retry(**kwargs)
|
||||
raw_content = response.choices[0].message.content or ""
|
||||
assistant_text, tool_calls = self._parse_openrouter_tool_compat_response(
|
||||
raw_content,
|
||||
tools,
|
||||
)
|
||||
usage = response.usage
|
||||
input_tokens = usage.prompt_tokens if usage else 0
|
||||
output_tokens = usage.completion_tokens if usage else 0
|
||||
stop_reason = "tool_calls" if tool_calls else (response.choices[0].finish_reason or "stop")
|
||||
|
||||
return LLMResponse(
|
||||
content=assistant_text,
|
||||
model=response.model or self.model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
stop_reason=stop_reason,
|
||||
raw_response={
|
||||
"compat_mode": "openrouter_tool_emulation",
|
||||
"tool_calls": tool_calls,
|
||||
"response": response,
|
||||
},
|
||||
)
|
||||
|
||||
async def _stream_via_openrouter_tool_compat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
max_tokens: int,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Fallback stream for OpenRouter models without native tool support."""
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
StreamErrorEvent,
|
||||
TextDeltaEvent,
|
||||
TextEndEvent,
|
||||
ToolCallEvent,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[openrouter-tool-compat] Using compatibility mode for %s",
|
||||
self.model,
|
||||
)
|
||||
try:
|
||||
response = await self._acomplete_via_openrouter_tool_compat(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamErrorEvent(error=str(e), recoverable=False)
|
||||
return
|
||||
|
||||
raw_response = response.raw_response if isinstance(response.raw_response, dict) else {}
|
||||
tool_calls = raw_response.get("tool_calls", [])
|
||||
|
||||
if response.content:
|
||||
yield TextDeltaEvent(content=response.content, snapshot=response.content)
|
||||
yield TextEndEvent(full_text=response.content)
|
||||
|
||||
for tool_call in tool_calls:
|
||||
yield ToolCallEvent(
|
||||
tool_use_id=tool_call["id"],
|
||||
tool_name=tool_call["name"],
|
||||
tool_input=tool_call["input"],
|
||||
)
|
||||
|
||||
yield FinishEvent(
|
||||
stop_reason=response.stop_reason,
|
||||
input_tokens=response.input_tokens,
|
||||
output_tokens=response.output_tokens,
|
||||
model=response.model,
|
||||
)
|
||||
|
||||
async def _stream_via_nonstream_completion(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
@@ -834,12 +1492,11 @@ class LiteLLMProvider(LLMProvider):
|
||||
tool_calls = msg.tool_calls or []
|
||||
|
||||
for tc in tool_calls:
|
||||
parsed_args: Any
|
||||
args = tc.function.arguments if tc.function else ""
|
||||
try:
|
||||
parsed_args = json.loads(args) if args else {}
|
||||
except json.JSONDecodeError:
|
||||
parsed_args = {"_raw": args}
|
||||
parsed_args = self._parse_tool_call_arguments(
|
||||
args,
|
||||
tc.function.name if tc.function else "",
|
||||
)
|
||||
yield ToolCallEvent(
|
||||
tool_use_id=getattr(tc, "id", ""),
|
||||
tool_name=tc.function.name if tc.function else "",
|
||||
@@ -898,7 +1555,20 @@ class LiteLLMProvider(LLMProvider):
|
||||
yield event
|
||||
return
|
||||
|
||||
if tools and self._is_openrouter_model() and _is_openrouter_tool_compat_cached(self.model):
|
||||
async for event in self._stream_via_openrouter_tool_compat(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
|
||||
full_messages: list[dict[str, Any]] = []
|
||||
if self._claude_code_oauth:
|
||||
billing = _claude_code_billing_header(messages)
|
||||
full_messages.append({"role": "system", "content": billing})
|
||||
if system:
|
||||
sys_msg: dict[str, Any] = {"role": "system", "content": system}
|
||||
if _model_supports_cache_control(self.model):
|
||||
@@ -936,9 +1606,12 @@ class LiteLLMProvider(LLMProvider):
|
||||
"messages": full_messages,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
**self.extra_kwargs,
|
||||
}
|
||||
# stream_options is OpenAI-specific; Anthropic rejects it with 400.
|
||||
# Only include it for providers that support it.
|
||||
if not self._is_anthropic_model():
|
||||
kwargs["stream_options"] = {"include_usage": True}
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
@@ -1044,10 +1717,10 @@ class LiteLLMProvider(LLMProvider):
|
||||
if choice.finish_reason:
|
||||
stream_finish_reason = choice.finish_reason
|
||||
for _idx, tc_data in sorted(tool_calls_acc.items()):
|
||||
try:
|
||||
parsed_args = json.loads(tc_data["arguments"])
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
parsed_args = {"_raw": tc_data.get("arguments", "")}
|
||||
parsed_args = self._parse_tool_call_arguments(
|
||||
tc_data.get("arguments", ""),
|
||||
tc_data.get("name", ""),
|
||||
)
|
||||
tail_events.append(
|
||||
ToolCallEvent(
|
||||
tool_use_id=tc_data["id"],
|
||||
@@ -1228,6 +1901,16 @@ class LiteLLMProvider(LLMProvider):
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
if self._should_use_openrouter_tool_compat(e, tools):
|
||||
_remember_openrouter_tool_compat_model(self.model)
|
||||
async for event in self._stream_via_openrouter_tool_compat(
|
||||
messages=messages,
|
||||
system=system,
|
||||
tools=tools or [],
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
if _is_stream_transient_error(e) and attempt < RATE_LIMIT_MAX_RETRIES:
|
||||
wait = _compute_retry_delay(attempt, exception=e)
|
||||
logger.warning(
|
||||
|
||||
@@ -45,6 +45,7 @@ class ToolResult:
|
||||
tool_use_id: str
|
||||
content: str
|
||||
is_error: bool = False
|
||||
is_skill_content: bool = False # AS-10: marks activated skill body, protected from pruning
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
|
||||
@@ -83,18 +83,18 @@ configure_logging(level="INFO", format="auto")
|
||||
- Compact single-line format (easy to stream/parse)
|
||||
- All trace context fields included automatically
|
||||
|
||||
### Human-Readable Format (Development)
|
||||
### Human-Readable Format (Development / Terminal)
|
||||
|
||||
```
|
||||
[INFO ] [trace:12345678 | exec:a1b2c3d4 | agent:sales-agent] Starting agent execution
|
||||
[INFO ] [trace:12345678 | exec:a1b2c3d4 | agent:sales-agent] Processing input data [node_id:input-processor]
|
||||
[INFO ] [trace:12345678 | exec:a1b2c3d4 | agent:sales-agent] LLM call completed [latency_ms:1250] [tokens_used:450]
|
||||
[INFO ] [agent:sales-agent] Starting agent execution
|
||||
[INFO ] [agent:sales-agent] Processing input data [node_id:input-processor]
|
||||
[INFO ] [agent:sales-agent] LLM call completed [latency_ms:1250] [tokens_used:450]
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Color-coded log levels
|
||||
- Shortened IDs for readability (first 8 chars)
|
||||
- Context prefix shows trace correlation
|
||||
- Terminal output omits trace_id and execution_id for readability
|
||||
- For full traceability (e.g. debugging), use `ENV=production` to get JSON file logs with trace_id and execution_id
|
||||
|
||||
## Trace Context Fields
|
||||
|
||||
|
||||
@@ -4,8 +4,9 @@ Structured logging with automatic trace context propagation.
|
||||
Key Features:
|
||||
- Zero developer friction: Standard logger.info() calls get automatic context
|
||||
- ContextVar-based propagation: Thread-safe and async-safe
|
||||
- Dual output modes: JSON for production, human-readable for development
|
||||
- Correlation IDs: trace_id follows entire request flow automatically
|
||||
- Dual output modes: JSON for production (full trace_id/execution_id), human-readable for terminal
|
||||
- Terminal omits trace_id/execution_id for readability
|
||||
- Use ENV=production for file logs with full traceability
|
||||
|
||||
Architecture:
|
||||
Runtime.start_run() → Generates trace_id, sets context once
|
||||
@@ -101,10 +102,11 @@ class StructuredFormatter(logging.Formatter):
|
||||
|
||||
class HumanReadableFormatter(logging.Formatter):
|
||||
"""
|
||||
Human-readable formatter for development.
|
||||
Human-readable formatter for development (terminal output).
|
||||
|
||||
Provides colorized logs with trace context for local debugging.
|
||||
Includes trace_id prefix for correlation - AUTOMATIC!
|
||||
Provides colorized logs for local debugging. Omits trace_id and execution_id
|
||||
from the terminal for readability; use ENV=production (JSON file logs) when
|
||||
traceability is needed.
|
||||
"""
|
||||
|
||||
COLORS = {
|
||||
@@ -118,18 +120,11 @@ class HumanReadableFormatter(logging.Formatter):
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
"""Format log record as human-readable string."""
|
||||
# Get trace context - AUTOMATIC!
|
||||
# Get trace context; omit trace_id and execution_id in terminal for readability
|
||||
context = trace_context.get() or {}
|
||||
trace_id = context.get("trace_id", "")
|
||||
execution_id = context.get("execution_id", "")
|
||||
agent_id = context.get("agent_id", "")
|
||||
|
||||
# Build context prefix
|
||||
prefix_parts = []
|
||||
if trace_id:
|
||||
prefix_parts.append(f"trace:{trace_id[:8]}")
|
||||
if execution_id:
|
||||
prefix_parts.append(f"exec:{execution_id[-8:]}")
|
||||
if agent_id:
|
||||
prefix_parts.append(f"agent:{agent_id}")
|
||||
|
||||
@@ -211,6 +206,15 @@ def configure_logging(
|
||||
root_logger.addHandler(handler)
|
||||
root_logger.setLevel(level.upper())
|
||||
|
||||
# Suppress noisy LiteLLM INFO logs (model/provider line + Provider List URL
|
||||
# printed on every single completion call). Warnings and errors still show.
|
||||
# Honour LITELLM_LOG env var so users can opt-in to debug output.
|
||||
_litellm_level = os.getenv("LITELLM_LOG", "").upper()
|
||||
if _litellm_level and hasattr(logging, _litellm_level):
|
||||
logging.getLogger("LiteLLM").setLevel(getattr(logging, _litellm_level))
|
||||
else:
|
||||
logging.getLogger("LiteLLM").setLevel(logging.WARNING)
|
||||
|
||||
# When in JSON mode, configure known third-party loggers to use JSON formatter
|
||||
# This ensures libraries like LiteLLM, httpcore also output clean JSON
|
||||
if format == "json":
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""MCP Client for connecting to Model Context Protocol servers.
|
||||
|
||||
This module provides a client for connecting to MCP servers and invoking their tools.
|
||||
Supports both STDIO and HTTP transports using the official MCP Python SDK.
|
||||
Supports STDIO, HTTP, UNIX socket, and SSE transports using the official MCP Python SDK.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -22,7 +22,7 @@ class MCPServerConfig:
|
||||
"""Configuration for an MCP server connection."""
|
||||
|
||||
name: str
|
||||
transport: Literal["stdio", "http"]
|
||||
transport: Literal["stdio", "http", "unix", "sse"]
|
||||
|
||||
# For STDIO transport
|
||||
command: str | None = None
|
||||
@@ -33,6 +33,7 @@ class MCPServerConfig:
|
||||
# For HTTP transport
|
||||
url: str | None = None
|
||||
headers: dict[str, str] = field(default_factory=dict)
|
||||
socket_path: str | None = None
|
||||
|
||||
# Optional metadata
|
||||
description: str = ""
|
||||
@@ -52,7 +53,7 @@ class MCPClient:
|
||||
"""
|
||||
Client for communicating with MCP servers.
|
||||
|
||||
Supports both STDIO and HTTP transports using the official MCP SDK.
|
||||
Supports STDIO, HTTP, UNIX socket, and SSE transports using the official MCP SDK.
|
||||
Manages the connection lifecycle and provides methods to list and invoke tools.
|
||||
"""
|
||||
|
||||
@@ -68,6 +69,7 @@ class MCPClient:
|
||||
self._read_stream = None
|
||||
self._write_stream = None
|
||||
self._stdio_context = None # Context manager for stdio_client
|
||||
self._sse_context = None # Context manager for sse_client
|
||||
self._errlog_handle = None # Track errlog file handle for cleanup
|
||||
self._http_client: httpx.Client | None = None
|
||||
self._tools: dict[str, MCPTool] = {}
|
||||
@@ -141,6 +143,10 @@ class MCPClient:
|
||||
self._connect_stdio()
|
||||
elif self.config.transport == "http":
|
||||
self._connect_http()
|
||||
elif self.config.transport == "unix":
|
||||
self._connect_unix()
|
||||
elif self.config.transport == "sse":
|
||||
self._connect_sse()
|
||||
else:
|
||||
raise ValueError(f"Unsupported transport: {self.config.transport}")
|
||||
|
||||
@@ -266,10 +272,94 @@ class MCPClient:
|
||||
logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}")
|
||||
# Continue anyway, server might not have health endpoint
|
||||
|
||||
def _connect_unix(self) -> None:
|
||||
"""Connect to MCP server via UNIX domain socket transport."""
|
||||
if not self.config.url:
|
||||
raise ValueError("url is required for UNIX transport")
|
||||
if not self.config.socket_path:
|
||||
raise ValueError("socket_path is required for UNIX transport")
|
||||
|
||||
self._http_client = httpx.Client(
|
||||
base_url=self.config.url,
|
||||
headers=self.config.headers,
|
||||
timeout=30.0,
|
||||
transport=httpx.HTTPTransport(uds=self.config.socket_path),
|
||||
)
|
||||
|
||||
try:
|
||||
response = self._http_client.get("/health")
|
||||
response.raise_for_status()
|
||||
logger.info(
|
||||
"Connected to MCP server '%s' via UNIX socket at %s",
|
||||
self.config.name,
|
||||
self.config.socket_path,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}")
|
||||
# Continue anyway, server might not have health endpoint
|
||||
|
||||
def _connect_sse(self) -> None:
|
||||
"""Connect to MCP server via SSE transport using MCP SDK with persistent session."""
|
||||
if not self.config.url:
|
||||
raise ValueError("url is required for SSE transport")
|
||||
|
||||
try:
|
||||
loop_started = threading.Event()
|
||||
connection_ready = threading.Event()
|
||||
connection_error = []
|
||||
|
||||
def run_event_loop():
|
||||
"""Run event loop in background thread."""
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
loop_started.set()
|
||||
|
||||
async def init_connection():
|
||||
try:
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
self._sse_context = sse_client(
|
||||
self.config.url,
|
||||
headers=self.config.headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
(
|
||||
self._read_stream,
|
||||
self._write_stream,
|
||||
) = await self._sse_context.__aenter__()
|
||||
|
||||
self._session = ClientSession(self._read_stream, self._write_stream)
|
||||
await self._session.__aenter__()
|
||||
await self._session.initialize()
|
||||
|
||||
connection_ready.set()
|
||||
except Exception as e:
|
||||
connection_error.append(e)
|
||||
connection_ready.set()
|
||||
|
||||
self._loop.create_task(init_connection())
|
||||
self._loop.run_forever()
|
||||
|
||||
self._loop_thread = threading.Thread(target=run_event_loop, daemon=True)
|
||||
self._loop_thread.start()
|
||||
|
||||
loop_started.wait(timeout=5)
|
||||
if not loop_started.is_set():
|
||||
raise RuntimeError("Event loop failed to start")
|
||||
|
||||
connection_ready.wait(timeout=10)
|
||||
if connection_error:
|
||||
raise connection_error[0]
|
||||
|
||||
logger.info(f"Connected to MCP server '{self.config.name}' via SSE")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to connect to MCP server: {e}") from e
|
||||
|
||||
def _discover_tools(self) -> None:
|
||||
"""Discover available tools from the MCP server."""
|
||||
try:
|
||||
if self.config.transport == "stdio":
|
||||
if self.config.transport in {"stdio", "sse"}:
|
||||
tools_list = self._run_async(self._list_tools_stdio_async())
|
||||
else:
|
||||
tools_list = self._list_tools_http()
|
||||
@@ -371,9 +461,37 @@ class MCPClient:
|
||||
if self.config.transport == "stdio":
|
||||
with self._stdio_call_lock:
|
||||
return self._run_async(self._call_tool_stdio_async(tool_name, arguments))
|
||||
elif self.config.transport == "sse":
|
||||
return self._call_tool_with_retry(
|
||||
lambda: self._run_async(self._call_tool_stdio_async(tool_name, arguments))
|
||||
)
|
||||
elif self.config.transport == "unix":
|
||||
return self._call_tool_with_retry(lambda: self._call_tool_http(tool_name, arguments))
|
||||
else:
|
||||
return self._call_tool_http(tool_name, arguments)
|
||||
|
||||
def _call_tool_with_retry(self, call: Any) -> Any:
|
||||
"""Retry transient MCP transport failures once after reconnecting."""
|
||||
if self.config.transport == "stdio":
|
||||
return call()
|
||||
|
||||
if self.config.transport not in {"unix", "sse"}:
|
||||
return call()
|
||||
|
||||
try:
|
||||
return call()
|
||||
except (httpx.ConnectError, httpx.ReadTimeout) as original_error:
|
||||
logger.warning(
|
||||
"Retrying MCP tool call after transport error from '%s': %s",
|
||||
self.config.name,
|
||||
original_error,
|
||||
)
|
||||
self._reconnect()
|
||||
try:
|
||||
return call()
|
||||
except (httpx.ConnectError, httpx.ReadTimeout) as retry_error:
|
||||
raise original_error from retry_error
|
||||
|
||||
async def _call_tool_stdio_async(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
"""Call tool via STDIO protocol using persistent session."""
|
||||
if not self._session:
|
||||
@@ -433,18 +551,24 @@ class MCPClient:
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to call tool via HTTP: {e}") from e
|
||||
|
||||
def _reconnect(self) -> None:
|
||||
"""Reconnect to the configured MCP server."""
|
||||
logger.info(f"Reconnecting to MCP server '{self.config.name}'...")
|
||||
self.disconnect()
|
||||
self.connect()
|
||||
|
||||
_CLEANUP_TIMEOUT = 10
|
||||
_THREAD_JOIN_TIMEOUT = 12
|
||||
|
||||
async def _cleanup_stdio_async(self) -> None:
|
||||
"""Async cleanup for STDIO session and context managers.
|
||||
"""Async cleanup for persistent MCP session and context managers.
|
||||
|
||||
Cleanup order is critical:
|
||||
- The session must be closed BEFORE the stdio_context because the session
|
||||
depends on the streams provided by stdio_context.
|
||||
- This mirrors the initialization order in _connect_stdio(), where
|
||||
stdio_context is entered first (providing streams), then the session is
|
||||
created with those streams and entered.
|
||||
- The session must be closed BEFORE the transport context manager because the
|
||||
session depends on the streams provided by that context.
|
||||
- This mirrors the initialization order in _connect_stdio() / _connect_sse(),
|
||||
where the transport context is entered first (providing streams), then the
|
||||
session is created with those streams and entered.
|
||||
- Do not change this ordering without carefully considering these dependencies.
|
||||
"""
|
||||
# First: close session (depends on stdio_context streams)
|
||||
@@ -477,6 +601,16 @@ class MCPClient:
|
||||
finally:
|
||||
self._stdio_context = None
|
||||
|
||||
try:
|
||||
if self._sse_context:
|
||||
await self._sse_context.__aexit__(None, None, None)
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("SSE context cleanup was cancelled; proceeding with best-effort shutdown")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing SSE context: {e}")
|
||||
finally:
|
||||
self._sse_context = None
|
||||
|
||||
# Third: close errlog file handle if we opened one
|
||||
if self._errlog_handle is not None:
|
||||
try:
|
||||
@@ -552,6 +686,7 @@ class MCPClient:
|
||||
# Setting None to None is safe and ensures clean state.
|
||||
self._session = None
|
||||
self._stdio_context = None
|
||||
self._sse_context = None
|
||||
self._read_stream = None
|
||||
self._write_stream = None
|
||||
self._loop = None
|
||||
|
||||
@@ -0,0 +1,255 @@
|
||||
"""Shared MCP client connection management."""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from framework.runner.mcp_client import MCPClient, MCPServerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPConnectionManager:
|
||||
"""Process-wide MCP client pool keyed by server name."""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pool: dict[str, MCPClient] = {}
|
||||
self._refcounts: dict[str, int] = {}
|
||||
self._configs: dict[str, MCPServerConfig] = {}
|
||||
self._pool_lock = threading.Lock()
|
||||
# Transition events keep callers from racing a connect/reconnect/disconnect.
|
||||
self._transitions: dict[str, threading.Event] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "MCPConnectionManager":
|
||||
"""Return the process-level singleton instance."""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
@staticmethod
|
||||
def _is_connected(client: MCPClient | None) -> bool:
|
||||
return bool(client and getattr(client, "_connected", False))
|
||||
|
||||
def acquire(self, config: MCPServerConfig) -> MCPClient:
|
||||
"""Get or create a shared connection and increment its refcount."""
|
||||
server_name = config.name
|
||||
|
||||
while True:
|
||||
should_connect = False
|
||||
transition_event: threading.Event | None = None
|
||||
|
||||
with self._pool_lock:
|
||||
client = self._pool.get(server_name)
|
||||
if self._is_connected(client) and server_name not in self._transitions:
|
||||
new_refcount = self._refcounts.get(server_name, 0) + 1
|
||||
self._refcounts[server_name] = new_refcount
|
||||
self._configs[server_name] = config
|
||||
logger.debug(
|
||||
"Reusing pooled connection for MCP server '%s' (refcount=%d)",
|
||||
server_name,
|
||||
new_refcount,
|
||||
)
|
||||
return client
|
||||
|
||||
transition_event = self._transitions.get(server_name)
|
||||
if transition_event is None:
|
||||
transition_event = threading.Event()
|
||||
self._transitions[server_name] = transition_event
|
||||
self._configs[server_name] = config
|
||||
should_connect = True
|
||||
|
||||
if not should_connect:
|
||||
transition_event.wait()
|
||||
continue
|
||||
|
||||
client = MCPClient(config)
|
||||
try:
|
||||
client.connect()
|
||||
except Exception:
|
||||
with self._pool_lock:
|
||||
current = self._transitions.get(server_name)
|
||||
if current is transition_event:
|
||||
self._transitions.pop(server_name, None)
|
||||
if (
|
||||
server_name not in self._pool
|
||||
and self._refcounts.get(server_name, 0) <= 0
|
||||
):
|
||||
self._configs.pop(server_name, None)
|
||||
transition_event.set()
|
||||
raise
|
||||
|
||||
with self._pool_lock:
|
||||
current = self._transitions.get(server_name)
|
||||
if current is transition_event:
|
||||
self._pool[server_name] = client
|
||||
self._refcounts[server_name] = self._refcounts.get(server_name, 0) + 1
|
||||
self._configs[server_name] = config
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
return client
|
||||
|
||||
client.disconnect()
|
||||
|
||||
def release(self, server_name: str) -> None:
|
||||
"""Decrement refcount and disconnect when the last user releases."""
|
||||
while True:
|
||||
disconnect_client: MCPClient | None = None
|
||||
transition_event: threading.Event | None = None
|
||||
should_disconnect = False
|
||||
|
||||
with self._pool_lock:
|
||||
transition_event = self._transitions.get(server_name)
|
||||
if transition_event is None:
|
||||
refcount = self._refcounts.get(server_name, 0)
|
||||
if refcount <= 0:
|
||||
return
|
||||
if refcount > 1:
|
||||
self._refcounts[server_name] = refcount - 1
|
||||
return
|
||||
|
||||
disconnect_client = self._pool.pop(server_name, None)
|
||||
self._refcounts.pop(server_name, None)
|
||||
transition_event = threading.Event()
|
||||
self._transitions[server_name] = transition_event
|
||||
should_disconnect = True
|
||||
|
||||
if not should_disconnect:
|
||||
transition_event.wait()
|
||||
continue
|
||||
|
||||
try:
|
||||
if disconnect_client is not None:
|
||||
disconnect_client.disconnect()
|
||||
finally:
|
||||
with self._pool_lock:
|
||||
current = self._transitions.get(server_name)
|
||||
if current is transition_event:
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
return
|
||||
|
||||
def health_check(self, server_name: str) -> bool:
|
||||
"""Return True when the pooled connection appears healthy."""
|
||||
while True:
|
||||
with self._pool_lock:
|
||||
transition_event = self._transitions.get(server_name)
|
||||
if transition_event is None:
|
||||
client = self._pool.get(server_name)
|
||||
config = self._configs.get(server_name)
|
||||
break
|
||||
|
||||
transition_event.wait()
|
||||
|
||||
if client is None or config is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
if config.transport == "stdio":
|
||||
client.list_tools()
|
||||
return True
|
||||
|
||||
if not config.url:
|
||||
return False
|
||||
|
||||
client_kwargs: dict[str, Any] = {
|
||||
"base_url": config.url,
|
||||
"headers": config.headers,
|
||||
"timeout": 5.0,
|
||||
}
|
||||
if config.transport == "unix":
|
||||
if not config.socket_path:
|
||||
return False
|
||||
client_kwargs["transport"] = httpx.HTTPTransport(uds=config.socket_path)
|
||||
|
||||
with httpx.Client(**client_kwargs) as http_client:
|
||||
response = http_client.get("/health")
|
||||
response.raise_for_status()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def reconnect(self, server_name: str) -> MCPClient:
|
||||
"""Force a disconnect and replace the pooled client with a fresh one."""
|
||||
while True:
|
||||
transition_event: threading.Event | None = None
|
||||
old_client: MCPClient | None = None
|
||||
|
||||
with self._pool_lock:
|
||||
transition_event = self._transitions.get(server_name)
|
||||
if transition_event is None:
|
||||
config = self._configs.get(server_name)
|
||||
if config is None:
|
||||
raise KeyError(f"Unknown MCP server: {server_name}")
|
||||
old_client = self._pool.get(server_name)
|
||||
refcount = self._refcounts.get(server_name, 0)
|
||||
transition_event = threading.Event()
|
||||
self._transitions[server_name] = transition_event
|
||||
break
|
||||
|
||||
transition_event.wait()
|
||||
|
||||
if old_client is not None:
|
||||
old_client.disconnect()
|
||||
|
||||
new_client = MCPClient(config)
|
||||
try:
|
||||
new_client.connect()
|
||||
except Exception:
|
||||
with self._pool_lock:
|
||||
current = self._transitions.get(server_name)
|
||||
if current is transition_event:
|
||||
self._pool.pop(server_name, None)
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
raise
|
||||
|
||||
with self._pool_lock:
|
||||
current = self._transitions.get(server_name)
|
||||
if current is transition_event:
|
||||
self._pool[server_name] = new_client
|
||||
self._refcounts[server_name] = max(refcount, 1)
|
||||
self._transitions.pop(server_name, None)
|
||||
transition_event.set()
|
||||
return new_client
|
||||
|
||||
new_client.disconnect()
|
||||
return self.acquire(config)
|
||||
|
||||
def cleanup_all(self) -> None:
|
||||
"""Disconnect all pooled clients and clear manager state."""
|
||||
while True:
|
||||
with self._pool_lock:
|
||||
if self._transitions:
|
||||
pending = list(self._transitions.values())
|
||||
else:
|
||||
cleanup_events = {name: threading.Event() for name in self._pool}
|
||||
clients = list(self._pool.items())
|
||||
self._transitions.update(cleanup_events)
|
||||
self._pool.clear()
|
||||
self._refcounts.clear()
|
||||
self._configs.clear()
|
||||
break
|
||||
|
||||
for event in pending:
|
||||
event.wait()
|
||||
|
||||
for _server_name, client in clients:
|
||||
try:
|
||||
client.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
with self._pool_lock:
|
||||
for server_name, event in cleanup_events.items():
|
||||
current = self._transitions.get(server_name)
|
||||
if current is event:
|
||||
self._transitions.pop(server_name, None)
|
||||
event.set()
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Pre-load validation for agent graphs.
|
||||
|
||||
Runs structural and credential checks before MCP servers are spawned.
|
||||
Runs structural, credential, and skill-trust checks before MCP servers are spawned.
|
||||
Fails fast with actionable error messages.
|
||||
"""
|
||||
|
||||
@@ -169,6 +169,9 @@ def run_preload_validation(
|
||||
1. Graph structure (includes GCU subagent-only checks) — non-recoverable
|
||||
2. Credentials — potentially recoverable via interactive setup
|
||||
|
||||
Skill discovery and trust gating (AS-13) happen later in runner._setup()
|
||||
so they have access to agent-level skill configuration.
|
||||
|
||||
Raises PreloadValidationError for structural issues.
|
||||
Raises CredentialError for credential issues.
|
||||
"""
|
||||
|
||||
@@ -28,6 +28,7 @@ from framework.runner.tool_registry import ToolRegistry
|
||||
from framework.runtime.agent_runtime import AgentRuntime, AgentRuntimeConfig, create_agent_runtime
|
||||
from framework.runtime.execution_stream import EntryPointSpec
|
||||
from framework.runtime.runtime_log_store import RuntimeLogStore
|
||||
from framework.tools.flowchart_utils import generate_fallback_flowchart
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.runner.protocol import AgentMessage, CapabilityResponse
|
||||
@@ -959,6 +960,12 @@ class AgentRunner:
|
||||
|
||||
graph = GraphSpec(**graph_kwargs)
|
||||
|
||||
# Generate flowchart.json if missing (for template/legacy agents)
|
||||
generate_fallback_flowchart(graph, goal, agent_path)
|
||||
# Read skill configuration from agent module
|
||||
agent_default_skills = getattr(agent_module, "default_skills", None)
|
||||
agent_skills = getattr(agent_module, "skills", None)
|
||||
|
||||
# Read runtime config (webhook settings, etc.) if defined
|
||||
agent_runtime_config = getattr(agent_module, "runtime_config", None)
|
||||
|
||||
@@ -970,7 +977,7 @@ class AgentRunner:
|
||||
configure_fn = getattr(agent_module, "configure_for_account", None)
|
||||
list_accts_fn = getattr(agent_module, "list_connected_accounts", None)
|
||||
|
||||
return cls(
|
||||
runner = cls(
|
||||
agent_path=agent_path,
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
@@ -986,6 +993,10 @@ class AgentRunner:
|
||||
list_accounts=list_accts_fn,
|
||||
credential_store=credential_store,
|
||||
)
|
||||
# Stash skill config for use in _setup()
|
||||
runner._agent_default_skills = agent_default_skills
|
||||
runner._agent_skills = agent_skills
|
||||
return runner
|
||||
|
||||
# Fallback: load from agent.json (legacy JSON-based agents)
|
||||
agent_json_path = agent_path / "agent.json"
|
||||
@@ -1003,7 +1014,10 @@ class AgentRunner:
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError(f"Invalid JSON in agent export file: {agent_json_path}") from exc
|
||||
|
||||
return cls(
|
||||
# Generate flowchart.json if missing (for legacy JSON-based agents)
|
||||
generate_fallback_flowchart(graph, goal, agent_path)
|
||||
|
||||
runner = cls(
|
||||
agent_path=agent_path,
|
||||
graph=graph,
|
||||
goal=goal,
|
||||
@@ -1014,6 +1028,9 @@ class AgentRunner:
|
||||
skip_credential_validation=skip_credential_validation or False,
|
||||
credential_store=credential_store,
|
||||
)
|
||||
runner._agent_default_skills = None
|
||||
runner._agent_skills = None
|
||||
return runner
|
||||
|
||||
def register_tool(
|
||||
self,
|
||||
@@ -1124,7 +1141,10 @@ class AgentRunner:
|
||||
|
||||
# Create LLM provider
|
||||
# Uses LiteLLM which auto-detects the provider from model name
|
||||
if self.mock_mode:
|
||||
# Skip if already injected (e.g. worker agents with a pre-built LLM)
|
||||
if self._llm is not None:
|
||||
pass # LLM already configured externally
|
||||
elif self.mock_mode:
|
||||
# Use mock LLM for testing without real API calls
|
||||
from framework.llm.mock import MockLLMProvider
|
||||
|
||||
@@ -1323,6 +1343,20 @@ class AgentRunner:
|
||||
except Exception:
|
||||
pass # Best-effort — agent works without account info
|
||||
|
||||
# Skill configuration — the runtime handles discovery, loading, trust-gating and
|
||||
# prompt rasterization. The runner just builds the config.
|
||||
from framework.skills.config import SkillsConfig
|
||||
from framework.skills.manager import SkillsManagerConfig
|
||||
|
||||
skills_manager_config = SkillsManagerConfig(
|
||||
skills_config=SkillsConfig.from_agent_vars(
|
||||
default_skills=getattr(self, "_agent_default_skills", None),
|
||||
skills=getattr(self, "_agent_skills", None),
|
||||
),
|
||||
project_root=self.agent_path,
|
||||
interactive=self._interactive,
|
||||
)
|
||||
|
||||
self._setup_agent_runtime(
|
||||
tools,
|
||||
tool_executor,
|
||||
@@ -1330,6 +1364,7 @@ class AgentRunner:
|
||||
accounts_data=accounts_data,
|
||||
tool_provider_map=tool_provider_map,
|
||||
event_bus=event_bus,
|
||||
skills_manager_config=skills_manager_config,
|
||||
)
|
||||
|
||||
def _get_api_key_env_var(self, model: str) -> str | None:
|
||||
@@ -1350,6 +1385,8 @@ class AgentRunner:
|
||||
return "MISTRAL_API_KEY"
|
||||
elif model_lower.startswith("groq/"):
|
||||
return "GROQ_API_KEY"
|
||||
elif model_lower.startswith("openrouter/"):
|
||||
return "OPENROUTER_API_KEY"
|
||||
elif self._is_local_model(model_lower):
|
||||
return None # Local models don't need an API key
|
||||
elif model_lower.startswith("azure/"):
|
||||
@@ -1364,6 +1401,8 @@ class AgentRunner:
|
||||
return "MINIMAX_API_KEY"
|
||||
elif model_lower.startswith("kimi/"):
|
||||
return "KIMI_API_KEY"
|
||||
elif model_lower.startswith("hive/"):
|
||||
return "HIVE_API_KEY"
|
||||
else:
|
||||
# Default: assume OpenAI-compatible
|
||||
return "OPENAI_API_KEY"
|
||||
@@ -1386,6 +1425,8 @@ class AgentRunner:
|
||||
cred_id = "minimax"
|
||||
elif model_lower.startswith("kimi/"):
|
||||
cred_id = "kimi"
|
||||
elif model_lower.startswith("hive/"):
|
||||
cred_id = "hive"
|
||||
# Add more mappings as providers are added to LLM_CREDENTIALS
|
||||
|
||||
if cred_id is None:
|
||||
@@ -1425,6 +1466,10 @@ class AgentRunner:
|
||||
accounts_data: list[dict] | None = None,
|
||||
tool_provider_map: dict[str, str] | None = None,
|
||||
event_bus=None,
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
skills_manager_config=None,
|
||||
) -> None:
|
||||
"""Set up multi-entry-point execution using AgentRuntime."""
|
||||
entry_points = []
|
||||
@@ -1484,6 +1529,7 @@ class AgentRunner:
|
||||
accounts_data=accounts_data,
|
||||
tool_provider_map=tool_provider_map,
|
||||
event_bus=event_bus,
|
||||
skills_manager_config=skills_manager_config,
|
||||
)
|
||||
|
||||
# Pass intro_message through for TUI display
|
||||
|
||||
@@ -54,6 +54,8 @@ class ToolRegistry:
|
||||
def __init__(self):
|
||||
self._tools: dict[str, RegisteredTool] = {}
|
||||
self._mcp_clients: list[Any] = [] # List of MCPClient instances
|
||||
self._mcp_client_servers: dict[int, str] = {} # client id -> server name
|
||||
self._mcp_managed_clients: set[int] = set() # client ids acquired from the manager
|
||||
self._session_context: dict[str, Any] = {} # Auto-injected context for tools
|
||||
self._provider_index: dict[str, set[str]] = {} # provider -> tool names
|
||||
# MCP resync tracking
|
||||
@@ -455,11 +457,23 @@ class ToolRegistry:
|
||||
|
||||
for server_config in server_list:
|
||||
server_config = self._resolve_mcp_server_config(server_config, base_dir)
|
||||
try:
|
||||
self.register_mcp_server(server_config)
|
||||
except Exception as e:
|
||||
name = server_config.get("name", "unknown")
|
||||
logger.warning(f"Failed to register MCP server '{name}': {e}")
|
||||
for _attempt in range(2):
|
||||
try:
|
||||
self.register_mcp_server(server_config)
|
||||
break
|
||||
except Exception as e:
|
||||
name = server_config.get("name", "unknown")
|
||||
if _attempt == 0:
|
||||
logger.warning(
|
||||
"MCP server '%s' failed to register, retrying in 2s: %s",
|
||||
name,
|
||||
e,
|
||||
)
|
||||
import time
|
||||
|
||||
time.sleep(2)
|
||||
else:
|
||||
logger.warning("MCP server '%s' failed after retry: %s", name, e)
|
||||
|
||||
# Snapshot credential files and ADEN_API_KEY so we can detect mid-session changes
|
||||
self._mcp_cred_snapshot = self._snapshot_credentials()
|
||||
@@ -468,6 +482,7 @@ class ToolRegistry:
|
||||
def register_mcp_server(
|
||||
self,
|
||||
server_config: dict[str, Any],
|
||||
use_connection_manager: bool = True,
|
||||
) -> int:
|
||||
"""
|
||||
Register an MCP server and discover its tools.
|
||||
@@ -483,12 +498,14 @@ class ToolRegistry:
|
||||
- url: Server URL (for http)
|
||||
- headers: HTTP headers (for http)
|
||||
- description: Server description (optional)
|
||||
use_connection_manager: When True, reuse a shared client keyed by server name
|
||||
|
||||
Returns:
|
||||
Number of tools registered from this server
|
||||
"""
|
||||
try:
|
||||
from framework.runner.mcp_client import MCPClient, MCPServerConfig
|
||||
from framework.runner.mcp_connection_manager import MCPConnectionManager
|
||||
|
||||
# Build config object
|
||||
config = MCPServerConfig(
|
||||
@@ -504,11 +521,18 @@ class ToolRegistry:
|
||||
)
|
||||
|
||||
# Create and connect client
|
||||
client = MCPClient(config)
|
||||
client.connect()
|
||||
if use_connection_manager:
|
||||
client = MCPConnectionManager.get_instance().acquire(config)
|
||||
else:
|
||||
client = MCPClient(config)
|
||||
client.connect()
|
||||
|
||||
# Store client for cleanup
|
||||
self._mcp_clients.append(client)
|
||||
client_id = id(client)
|
||||
self._mcp_client_servers[client_id] = config.name
|
||||
if use_connection_manager:
|
||||
self._mcp_managed_clients.add(client_id)
|
||||
|
||||
# Register each tool
|
||||
server_name = server_config["name"]
|
||||
@@ -708,12 +732,7 @@ class ToolRegistry:
|
||||
logger.info("%s — resyncing MCP servers", reason)
|
||||
|
||||
# 1. Disconnect existing MCP clients
|
||||
for client in self._mcp_clients:
|
||||
try:
|
||||
client.disconnect()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error disconnecting MCP client during resync: {e}")
|
||||
self._mcp_clients.clear()
|
||||
self._cleanup_mcp_clients("during resync")
|
||||
|
||||
# 2. Remove MCP-registered tools
|
||||
for name in self._mcp_tool_names:
|
||||
@@ -728,12 +747,28 @@ class ToolRegistry:
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Clean up all MCP client connections."""
|
||||
self._cleanup_mcp_clients()
|
||||
|
||||
def _cleanup_mcp_clients(self, context: str = "") -> None:
|
||||
"""Disconnect or release all tracked MCP clients for this registry."""
|
||||
if context:
|
||||
context = f" {context}"
|
||||
|
||||
for client in self._mcp_clients:
|
||||
client_id = id(client)
|
||||
server_name = self._mcp_client_servers.get(client_id, client.config.name)
|
||||
try:
|
||||
client.disconnect()
|
||||
if client_id in self._mcp_managed_clients:
|
||||
from framework.runner.mcp_connection_manager import MCPConnectionManager
|
||||
|
||||
MCPConnectionManager.get_instance().release(server_name)
|
||||
else:
|
||||
client.disconnect()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error disconnecting MCP client: {e}")
|
||||
logger.warning(f"Error disconnecting MCP client{context}: {e}")
|
||||
self._mcp_clients.clear()
|
||||
self._mcp_client_servers.clear()
|
||||
self._mcp_managed_clients.clear()
|
||||
|
||||
def __del__(self):
|
||||
"""Destructor to ensure cleanup."""
|
||||
|
||||
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.goal import Goal
|
||||
from framework.llm.provider import LLMProvider, Tool
|
||||
from framework.skills.manager import SkillsManagerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -132,6 +133,11 @@ class AgentRuntime:
|
||||
accounts_data: list[dict] | None = None,
|
||||
tool_provider_map: dict[str, str] | None = None,
|
||||
event_bus: "EventBus | None" = None,
|
||||
skills_manager_config: "SkillsManagerConfig | None" = None,
|
||||
# Deprecated — pass skills_manager_config instead.
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize agent runtime.
|
||||
@@ -153,7 +159,16 @@ class AgentRuntime:
|
||||
event_bus: Optional external EventBus. If provided, the runtime shares
|
||||
this bus instead of creating its own. Used by SessionManager to
|
||||
share a single bus between queen, worker, and judge.
|
||||
skills_catalog_prompt: Available skills catalog for system prompt
|
||||
protocols_prompt: Default skill operational protocols for system prompt
|
||||
skill_dirs: Skill base directories for Tier 3 resource access
|
||||
skills_manager_config: Skill configuration — the runtime owns
|
||||
discovery, loading, and prompt renderation internally.
|
||||
skills_catalog_prompt: Deprecated. Pre-rendered skills catalog.
|
||||
protocols_prompt: Deprecated. Pre-rendered operational protocols.
|
||||
"""
|
||||
from framework.skills.manager import SkillsManager
|
||||
|
||||
self.graph = graph
|
||||
self.goal = goal
|
||||
self._config = config or AgentRuntimeConfig()
|
||||
@@ -161,6 +176,31 @@ class AgentRuntime:
|
||||
self._checkpoint_config = checkpoint_config
|
||||
self.accounts_prompt = accounts_prompt
|
||||
|
||||
# --- Skill lifecycle: runtime owns the SkillsManager ---
|
||||
if skills_manager_config is not None:
|
||||
# New path: config-driven, runtime handles loading
|
||||
self._skills_manager = SkillsManager(skills_manager_config)
|
||||
self._skills_manager.load()
|
||||
elif skills_catalog_prompt or protocols_prompt:
|
||||
# Legacy path: caller passed pre-rendered strings
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"Passing pre-rendered skills_catalog_prompt/protocols_prompt "
|
||||
"is deprecated. Pass skills_manager_config instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._skills_manager = SkillsManager.from_precomputed(
|
||||
skills_catalog_prompt, protocols_prompt
|
||||
)
|
||||
else:
|
||||
# Bare constructor: auto-load defaults
|
||||
self._skills_manager = SkillsManager()
|
||||
self._skills_manager.load()
|
||||
|
||||
self.skill_dirs: list[str] = self._skills_manager.allowlisted_dirs
|
||||
|
||||
# Primary graph identity
|
||||
self._graph_id: str = graph_id or "primary"
|
||||
|
||||
@@ -216,6 +256,18 @@ class AgentRuntime:
|
||||
# Optional greeting shown to user on TUI load (set by AgentRunner)
|
||||
self.intro_message: str = ""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Skill prompt accessors (read by ExecutionStream constructors)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def skills_catalog_prompt(self) -> str:
|
||||
return self._skills_manager.skills_catalog_prompt
|
||||
|
||||
@property
|
||||
def protocols_prompt(self) -> str:
|
||||
return self._skills_manager.protocols_prompt
|
||||
|
||||
def register_entry_point(self, spec: EntryPointSpec) -> None:
|
||||
"""
|
||||
Register a named entry point for the agent.
|
||||
@@ -293,6 +345,9 @@ class AgentRuntime:
|
||||
accounts_prompt=self._accounts_prompt,
|
||||
accounts_data=self._accounts_data,
|
||||
tool_provider_map=self._tool_provider_map,
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
skill_dirs=self.skill_dirs,
|
||||
)
|
||||
await stream.start()
|
||||
self._streams[ep_id] = stream
|
||||
@@ -393,18 +448,24 @@ class AgentRuntime:
|
||||
|
||||
tc = spec.trigger_config
|
||||
cron_expr = tc.get("cron")
|
||||
interval = tc.get("interval_minutes")
|
||||
_raw_interval = tc.get("interval_minutes")
|
||||
interval = float(_raw_interval) if _raw_interval is not None else None
|
||||
run_immediately = tc.get("run_immediately", False)
|
||||
|
||||
if cron_expr:
|
||||
# Cron expression mode — takes priority over interval_minutes
|
||||
try:
|
||||
from croniter import croniter
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"croniter is required for cron-based entry points. "
|
||||
"Install it with: uv pip install croniter"
|
||||
) from e
|
||||
|
||||
# Validate the expression upfront
|
||||
try:
|
||||
if not croniter.is_valid(cron_expr):
|
||||
raise ValueError(f"Invalid cron expression: {cron_expr}")
|
||||
except (ImportError, ValueError) as e:
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
"Entry point '%s' has invalid cron config: %s",
|
||||
ep_id,
|
||||
@@ -544,7 +605,7 @@ class AgentRuntime:
|
||||
ep_id,
|
||||
cron_expr,
|
||||
run_immediately,
|
||||
idle_timeout=tc.get("idle_timeout_seconds", 300),
|
||||
idle_timeout=float(tc.get("idle_timeout_seconds", 300)),
|
||||
)()
|
||||
)
|
||||
self._timer_tasks.append(task)
|
||||
@@ -674,7 +735,7 @@ class AgentRuntime:
|
||||
ep_id,
|
||||
interval,
|
||||
run_immediately,
|
||||
idle_timeout=tc.get("idle_timeout_seconds", 300),
|
||||
idle_timeout=float(tc.get("idle_timeout_seconds", 300)),
|
||||
)()
|
||||
)
|
||||
self._timer_tasks.append(task)
|
||||
@@ -921,6 +982,9 @@ class AgentRuntime:
|
||||
accounts_prompt=self._accounts_prompt,
|
||||
accounts_data=self._accounts_data,
|
||||
tool_provider_map=self._tool_provider_map,
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
skill_dirs=self.skill_dirs,
|
||||
)
|
||||
if self._running:
|
||||
await stream.start()
|
||||
@@ -999,7 +1063,8 @@ class AgentRuntime:
|
||||
if spec.trigger_type != "timer":
|
||||
continue
|
||||
tc = spec.trigger_config
|
||||
interval = tc.get("interval_minutes")
|
||||
_raw_interval = tc.get("interval_minutes")
|
||||
interval = float(_raw_interval) if _raw_interval is not None else None
|
||||
run_immediately = tc.get("run_immediately", False)
|
||||
|
||||
if interval and interval > 0 and self._running:
|
||||
@@ -1144,7 +1209,7 @@ class AgentRuntime:
|
||||
ep_id,
|
||||
interval,
|
||||
run_immediately,
|
||||
idle_timeout=tc.get("idle_timeout_seconds", 300),
|
||||
idle_timeout=float(tc.get("idle_timeout_seconds", 300)),
|
||||
)()
|
||||
)
|
||||
timer_tasks.append(task)
|
||||
@@ -1699,6 +1764,11 @@ def create_agent_runtime(
|
||||
accounts_data: list[dict] | None = None,
|
||||
tool_provider_map: dict[str, str] | None = None,
|
||||
event_bus: "EventBus | None" = None,
|
||||
skills_manager_config: "SkillsManagerConfig | None" = None,
|
||||
# Deprecated — pass skills_manager_config instead.
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
) -> AgentRuntime:
|
||||
"""
|
||||
Create and configure an AgentRuntime with entry points.
|
||||
@@ -1725,6 +1795,13 @@ def create_agent_runtime(
|
||||
accounts_data: Raw account data for per-node prompt generation.
|
||||
tool_provider_map: Tool name to provider name mapping for account routing.
|
||||
event_bus: Optional external EventBus to share with other components.
|
||||
skills_catalog_prompt: Available skills catalog for system prompt.
|
||||
protocols_prompt: Default skill operational protocols for system prompt.
|
||||
skill_dirs: Skill base directories for Tier 3 resource access.
|
||||
skills_manager_config: Skill configuration — the runtime owns
|
||||
discovery, loading, and prompt renderation internally.
|
||||
skills_catalog_prompt: Deprecated. Pre-rendered skills catalog.
|
||||
protocols_prompt: Deprecated. Pre-rendered operational protocols.
|
||||
|
||||
Returns:
|
||||
Configured AgentRuntime (not yet started)
|
||||
@@ -1751,6 +1828,10 @@ def create_agent_runtime(
|
||||
accounts_data=accounts_data,
|
||||
tool_provider_map=tool_provider_map,
|
||||
event_bus=event_bus,
|
||||
skills_manager_config=skills_manager_config,
|
||||
skills_catalog_prompt=skills_catalog_prompt,
|
||||
protocols_prompt=protocols_prompt,
|
||||
skill_dirs=skill_dirs,
|
||||
)
|
||||
|
||||
for spec in entry_points:
|
||||
|
||||
@@ -117,6 +117,7 @@ class EventType(StrEnum):
|
||||
|
||||
# Context management
|
||||
CONTEXT_COMPACTED = "context_compacted"
|
||||
CONTEXT_USAGE_UPDATED = "context_usage_updated"
|
||||
|
||||
# External triggers
|
||||
WEBHOOK_RECEIVED = "webhook_received"
|
||||
@@ -159,6 +160,7 @@ class EventType(StrEnum):
|
||||
TRIGGER_DEACTIVATED = "trigger_deactivated"
|
||||
TRIGGER_FIRED = "trigger_fired"
|
||||
TRIGGER_REMOVED = "trigger_removed"
|
||||
TRIGGER_UPDATED = "trigger_updated"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -262,7 +264,7 @@ class EventBus:
|
||||
self._session_log: IO[str] | None = None
|
||||
self._session_log_iteration_offset: int = 0
|
||||
# Accumulator for client_output_delta snapshots — flushed on llm_turn_complete.
|
||||
# Key: (stream_id, node_id, execution_id, iteration) → latest AgentEvent
|
||||
# Key: (stream_id, node_id, execution_id, iteration, inner_turn) → latest AgentEvent
|
||||
self._pending_output_snapshots: dict[tuple, AgentEvent] = {}
|
||||
|
||||
def set_session_log(self, path: Path, *, iteration_offset: int = 0) -> None:
|
||||
@@ -328,6 +330,7 @@ class EventBus:
|
||||
event.node_id,
|
||||
event.execution_id,
|
||||
event.data.get("iteration"),
|
||||
event.data.get("inner_turn", 0),
|
||||
)
|
||||
self._pending_output_snapshots[key] = event
|
||||
return
|
||||
@@ -361,7 +364,7 @@ class EventBus:
|
||||
to_flush: list[tuple] = []
|
||||
for key, _evt in self._pending_output_snapshots.items():
|
||||
if stream_id is not None:
|
||||
k_stream, k_node, k_exec, _ = key
|
||||
k_stream, k_node, k_exec, _, _ = key
|
||||
if k_stream != stream_id or k_node != node_id or k_exec != execution_id:
|
||||
continue
|
||||
to_flush.append(key)
|
||||
@@ -749,6 +752,7 @@ class EventBus:
|
||||
content: str,
|
||||
snapshot: str,
|
||||
execution_id: str | None = None,
|
||||
inner_turn: int = 0,
|
||||
) -> None:
|
||||
"""Emit LLM text delta event."""
|
||||
await self.publish(
|
||||
@@ -757,7 +761,7 @@ class EventBus:
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
execution_id=execution_id,
|
||||
data={"content": content, "snapshot": snapshot},
|
||||
data={"content": content, "snapshot": snapshot, "inner_turn": inner_turn},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -873,9 +877,10 @@ class EventBus:
|
||||
snapshot: str,
|
||||
execution_id: str | None = None,
|
||||
iteration: int | None = None,
|
||||
inner_turn: int = 0,
|
||||
) -> None:
|
||||
"""Emit client output delta event (client_facing=True nodes)."""
|
||||
data: dict = {"content": content, "snapshot": snapshot}
|
||||
data: dict = {"content": content, "snapshot": snapshot, "inner_turn": inner_turn}
|
||||
if iteration is not None:
|
||||
data["iteration"] = iteration
|
||||
await self.publish(
|
||||
|
||||
@@ -186,6 +186,9 @@ class ExecutionStream:
|
||||
accounts_prompt: str = "",
|
||||
accounts_data: list[dict] | None = None,
|
||||
tool_provider_map: dict[str, str] | None = None,
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize execution stream.
|
||||
@@ -209,6 +212,9 @@ class ExecutionStream:
|
||||
accounts_prompt: Connected accounts block for system prompt injection
|
||||
accounts_data: Raw account data for per-node prompt generation
|
||||
tool_provider_map: Tool name to provider name mapping for account routing
|
||||
skills_catalog_prompt: Available skills catalog for system prompt
|
||||
protocols_prompt: Default skill operational protocols for system prompt
|
||||
skill_dirs: Skill base directories for Tier 3 resource access
|
||||
"""
|
||||
self.stream_id = stream_id
|
||||
self.entry_spec = entry_spec
|
||||
@@ -230,6 +236,22 @@ class ExecutionStream:
|
||||
self._accounts_prompt = accounts_prompt
|
||||
self._accounts_data = accounts_data
|
||||
self._tool_provider_map = tool_provider_map
|
||||
self._skills_catalog_prompt = skills_catalog_prompt
|
||||
self._protocols_prompt = protocols_prompt
|
||||
self._skill_dirs: list[str] = skill_dirs or []
|
||||
|
||||
_es_logger = logging.getLogger(__name__)
|
||||
if protocols_prompt:
|
||||
_es_logger.info(
|
||||
"ExecutionStream[%s] received protocols_prompt (%d chars)",
|
||||
stream_id,
|
||||
len(protocols_prompt),
|
||||
)
|
||||
else:
|
||||
_es_logger.warning(
|
||||
"ExecutionStream[%s] received EMPTY protocols_prompt",
|
||||
stream_id,
|
||||
)
|
||||
|
||||
# Create stream-scoped runtime
|
||||
self._runtime = StreamRuntime(
|
||||
@@ -675,6 +697,9 @@ class ExecutionStream:
|
||||
accounts_prompt=self._accounts_prompt,
|
||||
accounts_data=self._accounts_data,
|
||||
tool_provider_map=self._tool_provider_map,
|
||||
skills_catalog_prompt=self._skills_catalog_prompt,
|
||||
protocols_prompt=self._protocols_prompt,
|
||||
skill_dirs=self._skill_dirs,
|
||||
)
|
||||
# Track executor so inject_input() can reach EventLoopNode instances
|
||||
self._active_executors[execution_id] = executor
|
||||
|
||||
@@ -8,6 +8,7 @@ write. Errors are silently swallowed — this must never break the agent.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import IO, Any
|
||||
@@ -47,6 +48,9 @@ def log_llm_turn(
|
||||
Never raises.
|
||||
"""
|
||||
try:
|
||||
# Skip logging during test runs to avoid polluting real logs.
|
||||
if os.environ.get("PYTEST_CURRENT_TEST") or os.environ.get("HIVE_DISABLE_LLM_LOGS"):
|
||||
return
|
||||
global _log_file, _log_ready # noqa: PLW0603
|
||||
if not _log_ready:
|
||||
_log_file = _open_log()
|
||||
|
||||
@@ -47,25 +47,34 @@ class RuntimeLogStore:
|
||||
self._base_path = base_path
|
||||
# Note: _runs_dir is determined per-run_id by _get_run_dir()
|
||||
|
||||
def _session_logs_dir(self, run_id: str) -> Path:
|
||||
"""Return the unified session-backed logs directory for a run ID."""
|
||||
is_runtime_logs = self._base_path.name == "runtime_logs"
|
||||
root = self._base_path.parent if is_runtime_logs else self._base_path
|
||||
return root / "sessions" / run_id / "logs"
|
||||
|
||||
def _legacy_run_dir(self, run_id: str) -> Path:
|
||||
"""Return the deprecated standalone runs directory for a run ID."""
|
||||
return self._base_path / "runs" / run_id
|
||||
|
||||
def _get_run_dir(self, run_id: str) -> Path:
|
||||
"""Determine run directory path based on run_id format.
|
||||
|
||||
- New format (session_*): {storage_root}/sessions/{run_id}/logs/
|
||||
- Session-backed runs: {storage_root}/sessions/{run_id}/logs/
|
||||
- Old format (anything else): {base_path}/runs/{run_id}/ (deprecated)
|
||||
"""
|
||||
if run_id.startswith("session_"):
|
||||
is_runtime_logs = self._base_path.name == "runtime_logs"
|
||||
root = self._base_path.parent if is_runtime_logs else self._base_path
|
||||
return root / "sessions" / run_id / "logs"
|
||||
session_run_dir = self._session_logs_dir(run_id)
|
||||
if session_run_dir.exists() or run_id.startswith("session_"):
|
||||
return session_run_dir
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
f"Reading logs from deprecated location for run_id={run_id}. "
|
||||
"New sessions use unified storage at sessions/session_*/logs/",
|
||||
"New sessions use unified storage at sessions/<session_id>/logs/",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
return self._base_path / "runs" / run_id
|
||||
return self._legacy_run_dir(run_id)
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Incremental write (sync — called from locked sections)
|
||||
@@ -76,6 +85,10 @@ class RuntimeLogStore:
|
||||
run_dir = self._get_run_dir(run_id)
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def ensure_session_run_dir(self, run_id: str) -> None:
|
||||
"""Create the unified session-backed log directory immediately."""
|
||||
self._session_logs_dir(run_id).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def append_step(self, run_id: str, step: NodeStepLog) -> None:
|
||||
"""Append one JSONL line to tool_logs.jsonl. Sync."""
|
||||
path = self._get_run_dir(run_id) / "tool_logs.jsonl"
|
||||
@@ -200,17 +213,17 @@ class RuntimeLogStore:
|
||||
run_ids = []
|
||||
|
||||
# Scan new location: base_path/sessions/{session_id}/logs/
|
||||
# Determine the correct base path for sessions
|
||||
is_runtime_logs = self._base_path.name == "runtime_logs"
|
||||
root = self._base_path.parent if is_runtime_logs else self._base_path
|
||||
sessions_dir = root / "sessions"
|
||||
|
||||
if sessions_dir.exists():
|
||||
for session_dir in sessions_dir.iterdir():
|
||||
if session_dir.is_dir() and session_dir.name.startswith("session_"):
|
||||
logs_dir = session_dir / "logs"
|
||||
if logs_dir.exists() and logs_dir.is_dir():
|
||||
run_ids.append(session_dir.name)
|
||||
if not session_dir.is_dir():
|
||||
continue
|
||||
logs_dir = session_dir / "logs"
|
||||
if logs_dir.exists() and logs_dir.is_dir():
|
||||
run_ids.append(session_dir.name)
|
||||
|
||||
# Scan old location: base_path/runs/ (deprecated)
|
||||
old_runs_dir = self._base_path / "runs"
|
||||
|
||||
@@ -66,15 +66,16 @@ class RuntimeLogger:
|
||||
"""
|
||||
if session_id:
|
||||
self._run_id = session_id
|
||||
self._store.ensure_session_run_dir(self._run_id)
|
||||
else:
|
||||
ts = datetime.now(UTC).strftime("%Y%m%dT%H%M%S")
|
||||
short_uuid = uuid.uuid4().hex[:8]
|
||||
self._run_id = f"{ts}_{short_uuid}"
|
||||
self._store.ensure_run_dir(self._run_id)
|
||||
|
||||
self._goal_id = goal_id
|
||||
self._started_at = datetime.now(UTC).isoformat()
|
||||
self._logged_node_ids = set()
|
||||
self._store.ensure_run_dir(self._run_id)
|
||||
return self._run_id
|
||||
|
||||
def log_step(
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Tests for custom session-backed runtime logging paths."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.runtime.runtime_log_store import RuntimeLogStore
|
||||
from framework.runtime.runtime_logger import RuntimeLogger
|
||||
|
||||
|
||||
def test_graph_executor_uses_custom_session_dir_name_for_runtime_logs():
|
||||
executor = GraphExecutor(
|
||||
runtime=MagicMock(),
|
||||
storage_path=Path("/tmp/test-agent/sessions/my-custom-session"),
|
||||
)
|
||||
|
||||
assert executor._get_runtime_log_session_id() == "my-custom-session"
|
||||
|
||||
|
||||
def test_runtime_logger_creates_session_log_dir_for_custom_session_id(tmp_path):
|
||||
base = tmp_path / ".hive" / "agents" / "test_agent"
|
||||
base.mkdir(parents=True)
|
||||
store = RuntimeLogStore(base)
|
||||
logger = RuntimeLogger(store=store, agent_id="test-agent")
|
||||
|
||||
run_id = logger.start_run(goal_id="goal-1", session_id="my-custom-session")
|
||||
|
||||
assert run_id == "my-custom-session"
|
||||
assert (base / "sessions" / "my-custom-session" / "logs").is_dir()
|
||||
@@ -69,6 +69,7 @@ async def create_queen(
|
||||
QueenPhaseState,
|
||||
register_queen_lifecycle_tools,
|
||||
)
|
||||
from framework.tools.queen_memory_tools import register_queen_memory_tools
|
||||
|
||||
hive_home = Path.home() / ".hive"
|
||||
|
||||
@@ -122,6 +123,9 @@ async def create_queen(
|
||||
phase_state=phase_state,
|
||||
)
|
||||
|
||||
# ---- Episodic memory tools (always registered) ---------------------
|
||||
register_queen_memory_tools(queen_registry)
|
||||
|
||||
# ---- Monitoring tools (only when worker is loaded) ----------------
|
||||
if session.worker_runtime:
|
||||
from framework.tools.worker_monitoring_tools import register_worker_monitoring_tools
|
||||
@@ -216,6 +220,16 @@ async def create_queen(
|
||||
+ worker_identity
|
||||
)
|
||||
|
||||
# ---- Default skill protocols -------------------------------------
|
||||
try:
|
||||
from framework.skills.manager import SkillsManager
|
||||
|
||||
_queen_skills_mgr = SkillsManager()
|
||||
_queen_skills_mgr.load()
|
||||
phase_state.protocols_prompt = _queen_skills_mgr.protocols_prompt
|
||||
except Exception:
|
||||
logger.debug("Queen skill loading failed (non-fatal)", exc_info=True)
|
||||
|
||||
# ---- Persona hook ------------------------------------------------
|
||||
_session_llm = session.llm
|
||||
_session_event_bus = session.event_bus
|
||||
|
||||
@@ -103,7 +103,9 @@ async def handle_delete_credential(request: web.Request) -> web.Response:
|
||||
if credential_id == "aden_api_key":
|
||||
from framework.credentials.key_storage import delete_aden_api_key
|
||||
|
||||
delete_aden_api_key()
|
||||
deleted = delete_aden_api_key()
|
||||
if not deleted:
|
||||
return web.json_response({"error": "Credential 'aden_api_key' not found"}, status=404)
|
||||
return web.json_response({"deleted": True})
|
||||
|
||||
store = _get_store(request)
|
||||
@@ -178,7 +180,10 @@ async def handle_check_agent(request: web.Request) -> web.Response:
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error checking agent credentials: {e}")
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
return web.json_response(
|
||||
{"error": "Internal server error while checking credentials"},
|
||||
status=500,
|
||||
)
|
||||
|
||||
|
||||
def _status_to_dict(c) -> dict:
|
||||
|
||||
@@ -6,7 +6,7 @@ import logging
|
||||
from aiohttp import web
|
||||
from aiohttp.client_exceptions import ClientConnectionResetError as _AiohttpConnReset
|
||||
|
||||
from framework.runtime.event_bus import EventType
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
from framework.server.app import resolve_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -37,6 +37,7 @@ DEFAULT_EVENT_TYPES = [
|
||||
EventType.NODE_RETRY,
|
||||
EventType.NODE_TOOL_DOOM_LOOP,
|
||||
EventType.CONTEXT_COMPACTED,
|
||||
EventType.CONTEXT_USAGE_UPDATED,
|
||||
EventType.WORKER_LOADED,
|
||||
EventType.CREDENTIALS_REQUIRED,
|
||||
EventType.SUBAGENT_REPORT,
|
||||
@@ -46,6 +47,7 @@ DEFAULT_EVENT_TYPES = [
|
||||
EventType.TRIGGER_DEACTIVATED,
|
||||
EventType.TRIGGER_FIRED,
|
||||
EventType.TRIGGER_REMOVED,
|
||||
EventType.TRIGGER_UPDATED,
|
||||
EventType.DRAFT_GRAPH_UPDATED,
|
||||
]
|
||||
|
||||
@@ -165,6 +167,54 @@ async def handle_events(request: web.Request) -> web.StreamResponse:
|
||||
if replayed:
|
||||
logger.info("SSE replayed %d buffered events for session='%s'", replayed, session.id)
|
||||
|
||||
# Inject a live-status snapshot so the frontend knows which nodes are
|
||||
# currently running. This covers the case where the user navigated away
|
||||
# and back — the localStorage snapshot is stale, and the ring-buffer
|
||||
# replay may not include the original node_loop_started events.
|
||||
worker_runtime = getattr(session, "worker_runtime", None)
|
||||
if worker_runtime and getattr(worker_runtime, "is_running", False):
|
||||
try:
|
||||
for stream_info in worker_runtime.get_active_streams():
|
||||
graph_id = stream_info.get("graph_id")
|
||||
stream_id = stream_info.get("stream_id", "default")
|
||||
for exec_id in stream_info.get("active_execution_ids", []):
|
||||
# Synthesize execution_started so frontend sets workerRunState
|
||||
synth_exec = AgentEvent(
|
||||
type=EventType.EXECUTION_STARTED,
|
||||
stream_id=stream_id,
|
||||
execution_id=exec_id,
|
||||
graph_id=graph_id,
|
||||
data={"synthetic": True},
|
||||
).to_dict()
|
||||
try:
|
||||
queue.put_nowait(synth_exec)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
|
||||
# Find the currently executing node via the executor
|
||||
for _gid, reg in worker_runtime._graphs.items():
|
||||
if _gid != graph_id:
|
||||
continue
|
||||
for _ep_id, stream in reg.streams.items():
|
||||
for exec_id, executor in stream._active_executors.items():
|
||||
current = getattr(executor, "current_node_id", None)
|
||||
if current:
|
||||
synth_node = AgentEvent(
|
||||
type=EventType.NODE_LOOP_STARTED,
|
||||
stream_id=stream_id,
|
||||
node_id=current,
|
||||
execution_id=exec_id,
|
||||
graph_id=graph_id,
|
||||
data={"synthetic": True},
|
||||
).to_dict()
|
||||
try:
|
||||
queue.put_nowait(synth_node)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
logger.info("SSE injected live-status snapshot for session='%s'", session.id)
|
||||
except Exception:
|
||||
logger.debug("Failed to inject live-status snapshot", exc_info=True)
|
||||
|
||||
event_count = 0
|
||||
close_reason = "unknown"
|
||||
try:
|
||||
|
||||
@@ -11,7 +11,6 @@ Session-primary routes:
|
||||
- GET /api/sessions/{session_id}/entry-points — list entry points
|
||||
- PATCH /api/sessions/{session_id}/triggers/{id} — update trigger task
|
||||
- GET /api/sessions/{session_id}/graphs — list graph IDs
|
||||
- GET /api/sessions/{session_id}/queen-messages — queen conversation history
|
||||
- GET /api/sessions/{session_id}/events/history — persisted eventbus log (for replay)
|
||||
|
||||
Worker session browsing (persisted execution runs on disk):
|
||||
@@ -24,6 +23,8 @@ Worker session browsing (persisted execution runs on disk):
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
@@ -64,7 +65,9 @@ def _session_to_live_dict(session) -> dict:
|
||||
"loaded_at": session.loaded_at,
|
||||
"uptime_seconds": round(time.time() - session.loaded_at, 1),
|
||||
"intro_message": getattr(session.runner, "intro_message", "") or "",
|
||||
"queen_phase": phase_state.phase if phase_state else "planning",
|
||||
"queen_phase": phase_state.phase
|
||||
if phase_state
|
||||
else ("staging" if session.worker_runtime else "planning"),
|
||||
}
|
||||
|
||||
|
||||
@@ -406,7 +409,7 @@ async def handle_session_entry_points(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
async def handle_update_trigger_task(request: web.Request) -> web.Response:
|
||||
"""PATCH /api/sessions/{session_id}/triggers/{trigger_id} — update trigger task."""
|
||||
"""PATCH /api/sessions/{session_id}/triggers/{trigger_id} — update trigger fields."""
|
||||
session, err = resolve_session(request)
|
||||
if err:
|
||||
return err
|
||||
@@ -425,30 +428,136 @@ async def handle_update_trigger_task(request: web.Request) -> web.Response:
|
||||
except Exception:
|
||||
return web.json_response({"error": "Invalid JSON body"}, status=400)
|
||||
|
||||
task = body.get("task")
|
||||
if task is None:
|
||||
return web.json_response({"error": "Missing 'task' field"}, status=400)
|
||||
if not isinstance(task, str):
|
||||
return web.json_response({"error": "'task' must be a string"}, status=400)
|
||||
updates: dict[str, object] = {}
|
||||
|
||||
tdef.task = task
|
||||
if "task" in body:
|
||||
task = body.get("task")
|
||||
if not isinstance(task, str):
|
||||
return web.json_response({"error": "'task' must be a string"}, status=400)
|
||||
tdef.task = task
|
||||
updates["task"] = tdef.task
|
||||
|
||||
trigger_config_update = body.get("trigger_config")
|
||||
if trigger_config_update is not None:
|
||||
if not isinstance(trigger_config_update, dict):
|
||||
return web.json_response(
|
||||
{"error": "'trigger_config' must be an object"},
|
||||
status=400,
|
||||
)
|
||||
merged_trigger_config = dict(tdef.trigger_config)
|
||||
merged_trigger_config.update(trigger_config_update)
|
||||
|
||||
if tdef.trigger_type == "timer":
|
||||
cron_expr = merged_trigger_config.get("cron")
|
||||
interval = merged_trigger_config.get("interval_minutes")
|
||||
if cron_expr is not None and not isinstance(cron_expr, str):
|
||||
return web.json_response(
|
||||
{"error": "'trigger_config.cron' must be a string"},
|
||||
status=400,
|
||||
)
|
||||
if cron_expr:
|
||||
try:
|
||||
from croniter import croniter
|
||||
|
||||
if not croniter.is_valid(cron_expr):
|
||||
return web.json_response(
|
||||
{"error": f"Invalid cron expression: {cron_expr}"},
|
||||
status=400,
|
||||
)
|
||||
except ImportError:
|
||||
return web.json_response(
|
||||
{
|
||||
"error": (
|
||||
"croniter package not installed — cannot validate cron expression."
|
||||
)
|
||||
},
|
||||
status=500,
|
||||
)
|
||||
merged_trigger_config.pop("interval_minutes", None)
|
||||
elif interval is None:
|
||||
return web.json_response(
|
||||
{
|
||||
"error": (
|
||||
"Timer trigger needs 'cron' or 'interval_minutes' in trigger_config."
|
||||
)
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
elif not isinstance(interval, (int, float)) or interval <= 0:
|
||||
return web.json_response(
|
||||
{"error": "'trigger_config.interval_minutes' must be > 0"},
|
||||
status=400,
|
||||
)
|
||||
tdef.trigger_config = merged_trigger_config
|
||||
updates["trigger_config"] = tdef.trigger_config
|
||||
|
||||
if not updates:
|
||||
return web.json_response(
|
||||
{"error": "Provide at least one of 'task' or 'trigger_config'"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
# Persist to session state and agent definition
|
||||
from framework.tools.queen_lifecycle_tools import (
|
||||
_persist_active_triggers,
|
||||
_save_trigger_to_agent,
|
||||
_start_trigger_timer,
|
||||
_start_trigger_webhook,
|
||||
)
|
||||
|
||||
if "trigger_config" in updates and trigger_id in getattr(session, "active_trigger_ids", set()):
|
||||
task = session.active_timer_tasks.pop(trigger_id, None)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
getattr(session, "trigger_next_fire", {}).pop(trigger_id, None)
|
||||
|
||||
webhook_subs = getattr(session, "active_webhook_subs", {})
|
||||
if sub_id := webhook_subs.pop(trigger_id, None):
|
||||
with contextlib.suppress(Exception):
|
||||
session.event_bus.unsubscribe(sub_id)
|
||||
|
||||
if tdef.trigger_type == "timer":
|
||||
await _start_trigger_timer(session, trigger_id, tdef)
|
||||
elif tdef.trigger_type == "webhook":
|
||||
await _start_trigger_webhook(session, trigger_id, tdef)
|
||||
|
||||
if trigger_id in getattr(session, "active_trigger_ids", set()):
|
||||
session_id = request.match_info["session_id"]
|
||||
await _persist_active_triggers(session, session_id)
|
||||
|
||||
_save_trigger_to_agent(session, trigger_id, tdef)
|
||||
|
||||
# Emit SSE event so the frontend updates the graph and detail panel
|
||||
bus = getattr(session, "event_bus", None)
|
||||
if bus:
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.TRIGGER_UPDATED,
|
||||
stream_id="queen",
|
||||
data={
|
||||
"trigger_id": trigger_id,
|
||||
"task": tdef.task,
|
||||
"trigger_config": tdef.trigger_config,
|
||||
"trigger_type": tdef.trigger_type,
|
||||
"name": tdef.description or trigger_id,
|
||||
"entry_node": getattr(
|
||||
getattr(getattr(session, "runner", None), "graph", None),
|
||||
"entry_node",
|
||||
None,
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"trigger_id": trigger_id,
|
||||
"task": tdef.task,
|
||||
"trigger_config": tdef.trigger_config,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -492,12 +601,14 @@ async def handle_list_worker_sessions(request: web.Request) -> web.Response:
|
||||
|
||||
sessions = []
|
||||
for d in sorted(sess_dir.iterdir(), reverse=True):
|
||||
if not d.is_dir() or not d.name.startswith("session_"):
|
||||
if not d.is_dir():
|
||||
continue
|
||||
state_path = d / "state.json"
|
||||
if not d.name.startswith("session_") and not state_path.exists():
|
||||
continue
|
||||
|
||||
entry: dict = {"session_id": d.name}
|
||||
|
||||
state_path = d / "state.json"
|
||||
if state_path.exists():
|
||||
try:
|
||||
state = json.loads(state_path.read_text(encoding="utf-8"))
|
||||
@@ -750,60 +861,6 @@ async def handle_messages(request: web.Request) -> web.Response:
|
||||
return web.json_response({"messages": all_messages})
|
||||
|
||||
|
||||
async def handle_queen_messages(request: web.Request) -> web.Response:
|
||||
"""GET /api/sessions/{session_id}/queen-messages — get queen conversation.
|
||||
|
||||
Reads directly from disk so it works for both live sessions and cold
|
||||
(post-server-restart) sessions — no live session required.
|
||||
"""
|
||||
session_id = request.match_info["session_id"]
|
||||
|
||||
queen_dir = Path.home() / ".hive" / "queen" / "session" / session_id
|
||||
convs_dir = queen_dir / "conversations"
|
||||
if not convs_dir.exists():
|
||||
return web.json_response({"messages": [], "session_id": session_id})
|
||||
|
||||
all_messages: list[dict] = []
|
||||
|
||||
def _read_parts(parts_dir: Path, node_id: str) -> None:
|
||||
if not parts_dir.exists():
|
||||
return
|
||||
for part_file in sorted(parts_dir.iterdir()):
|
||||
if part_file.suffix != ".json":
|
||||
continue
|
||||
try:
|
||||
part = json.loads(part_file.read_text(encoding="utf-8"))
|
||||
part["_node_id"] = node_id
|
||||
# Use file mtime as created_at so frontend can order
|
||||
# queen and worker messages chronologically.
|
||||
part.setdefault("created_at", part_file.stat().st_mtime)
|
||||
all_messages.append(part)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
continue
|
||||
|
||||
# Flat layout: conversations/parts/*.json
|
||||
_read_parts(convs_dir / "parts", "queen")
|
||||
|
||||
# Node-based layout: conversations/<node_id>/parts/*.json
|
||||
for node_dir in convs_dir.iterdir():
|
||||
if not node_dir.is_dir() or node_dir.name == "parts":
|
||||
continue
|
||||
_read_parts(node_dir / "parts", node_dir.name)
|
||||
|
||||
all_messages.sort(key=lambda m: m.get("created_at", m.get("seq", 0)))
|
||||
|
||||
# Filter to client-facing messages only
|
||||
all_messages = [
|
||||
m
|
||||
for m in all_messages
|
||||
if not m.get("is_transition_marker")
|
||||
and m["role"] != "tool"
|
||||
and not (m["role"] == "assistant" and m.get("tool_calls"))
|
||||
]
|
||||
|
||||
return web.json_response({"messages": all_messages, "session_id": session_id})
|
||||
|
||||
|
||||
async def handle_session_events_history(request: web.Request) -> web.Response:
|
||||
"""GET /api/sessions/{session_id}/events/history — persisted eventbus log.
|
||||
|
||||
@@ -951,7 +1008,7 @@ def register_routes(app: web.Application) -> None:
|
||||
"/api/sessions/{session_id}/triggers/{trigger_id}", handle_update_trigger_task
|
||||
)
|
||||
app.router.add_get("/api/sessions/{session_id}/graphs", handle_session_graphs)
|
||||
app.router.add_get("/api/sessions/{session_id}/queen-messages", handle_queen_messages)
|
||||
|
||||
app.router.add_get("/api/sessions/{session_id}/events/history", handle_session_events_history)
|
||||
|
||||
# Worker session browsing (session-primary)
|
||||
|
||||
@@ -47,6 +47,8 @@ class Session:
|
||||
worker_handoff_sub: str | None = None
|
||||
# Memory consolidation subscription (fires on CONTEXT_COMPACTED)
|
||||
memory_consolidation_sub: str | None = None
|
||||
# Worker run digest subscription (fires on EXECUTION_COMPLETED / EXECUTION_FAILED)
|
||||
worker_digest_sub: str | None = None
|
||||
# Trigger definitions loaded from agent's triggers.json (available but inactive)
|
||||
available_triggers: dict[str, TriggerDefinition] = field(default_factory=dict)
|
||||
# Active trigger tracking (IDs currently firing + their asyncio tasks)
|
||||
@@ -177,6 +179,31 @@ class SessionManager:
|
||||
agent_path = Path(agent_path)
|
||||
resolved_worker_id = agent_id or agent_path.name
|
||||
|
||||
# When cold-restoring, check meta.json for the phase — if the agent
|
||||
# was still being built we must NOT try to load the worker (the code
|
||||
# is incomplete and will fail to import).
|
||||
if queen_resume_from:
|
||||
_resume_phase = None
|
||||
_meta_path = (
|
||||
Path.home() / ".hive" / "queen" / "session" / queen_resume_from / "meta.json"
|
||||
)
|
||||
if _meta_path.exists():
|
||||
try:
|
||||
_meta = json.loads(_meta_path.read_text(encoding="utf-8"))
|
||||
_resume_phase = _meta.get("phase")
|
||||
except (json.JSONDecodeError, OSError):
|
||||
pass
|
||||
if _resume_phase in ("building", "planning"):
|
||||
# Fall back to queen-only session — cold resume handler in
|
||||
# _start_queen will set phase_state.agent_path and switch to
|
||||
# the correct phase.
|
||||
return await self.create_session(
|
||||
session_id=session_id,
|
||||
model=model,
|
||||
initial_prompt=initial_prompt,
|
||||
queen_resume_from=queen_resume_from,
|
||||
)
|
||||
|
||||
# Reuse the original session ID when cold-restoring so the frontend
|
||||
# sees one continuous session instead of a new one each time.
|
||||
session = await self._create_session_core(
|
||||
@@ -193,6 +220,9 @@ class SessionManager:
|
||||
model=model,
|
||||
)
|
||||
|
||||
# Restore active triggers from persisted state (cold restore)
|
||||
await self._restore_active_triggers(session, session.id)
|
||||
|
||||
# Start queen with worker profile + lifecycle + monitoring tools
|
||||
worker_identity = (
|
||||
build_worker_profile(session.worker_runtime, agent_path=agent_path)
|
||||
@@ -204,7 +234,23 @@ class SessionManager:
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# If anything fails, tear down the session
|
||||
if queen_resume_from:
|
||||
# Cold restore: worker load failed (e.g. incomplete code from a
|
||||
# building session). Fall back to queen-only so the user can
|
||||
# continue the conversation and fix / rebuild the agent.
|
||||
logger.warning(
|
||||
"Cold restore: worker load failed for '%s', falling back to queen-only",
|
||||
agent_path,
|
||||
exc_info=True,
|
||||
)
|
||||
await self.stop_session(session.id)
|
||||
return await self.create_session(
|
||||
session_id=session_id,
|
||||
model=model,
|
||||
initial_prompt=initial_prompt,
|
||||
queen_resume_from=queen_resume_from,
|
||||
)
|
||||
# If anything fails (non-cold-restore), tear down the session
|
||||
await self.stop_session(session.id)
|
||||
raise
|
||||
return session
|
||||
@@ -241,7 +287,17 @@ class SessionManager:
|
||||
try:
|
||||
# Blocking I/O — load in executor
|
||||
loop = asyncio.get_running_loop()
|
||||
resolved_model = model or self._model
|
||||
|
||||
# Prioritize: explicit model arg > worker-specific model > session default
|
||||
from framework.config import (
|
||||
get_preferred_worker_model,
|
||||
get_worker_api_base,
|
||||
get_worker_api_key,
|
||||
get_worker_llm_extra_kwargs,
|
||||
)
|
||||
|
||||
worker_model = get_preferred_worker_model()
|
||||
resolved_model = model or worker_model or self._model
|
||||
runner = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: AgentRunner.load(
|
||||
@@ -253,6 +309,22 @@ class SessionManager:
|
||||
),
|
||||
)
|
||||
|
||||
# If a worker-specific model is configured, build an LLM provider
|
||||
# with the correct worker credentials so _setup() doesn't fall back
|
||||
# to the queen's llm config (which may be a different provider).
|
||||
if worker_model and not model:
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
|
||||
worker_api_key = get_worker_api_key()
|
||||
worker_api_base = get_worker_api_base()
|
||||
worker_extra = get_worker_llm_extra_kwargs()
|
||||
runner._llm = LiteLLMProvider(
|
||||
model=resolved_model,
|
||||
api_key=worker_api_key,
|
||||
api_base=worker_api_base,
|
||||
**worker_extra,
|
||||
)
|
||||
|
||||
# Setup with session's event bus
|
||||
if runner._agent_runtime is None:
|
||||
await loop.run_in_executor(
|
||||
@@ -297,6 +369,9 @@ class SessionManager:
|
||||
session.worker_runtime = runtime
|
||||
session.worker_info = info
|
||||
|
||||
# Subscribe to execution completion for per-run digest generation
|
||||
self._subscribe_worker_digest(session)
|
||||
|
||||
async with self._lock:
|
||||
self._loading.discard(session.id)
|
||||
|
||||
@@ -399,6 +474,51 @@ class SessionManager:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _restore_active_triggers(self, session: "Session", session_id: str) -> None:
|
||||
"""Restore previously active triggers from persisted session state.
|
||||
|
||||
Called after worker loading to restart any timer/webhook triggers
|
||||
that were active before a server restart.
|
||||
"""
|
||||
if not session.available_triggers or not session.worker_runtime:
|
||||
return
|
||||
try:
|
||||
store = session.worker_runtime._session_store
|
||||
state = await store.read_state(session_id)
|
||||
if state and state.active_triggers:
|
||||
from framework.tools.queen_lifecycle_tools import (
|
||||
_start_trigger_timer,
|
||||
_start_trigger_webhook,
|
||||
)
|
||||
|
||||
saved_tasks = getattr(state, "trigger_tasks", {}) or {}
|
||||
for tid in state.active_triggers:
|
||||
tdef = session.available_triggers.get(tid)
|
||||
if tdef:
|
||||
# Restore user-configured task override
|
||||
saved_task = saved_tasks.get(tid, "")
|
||||
if saved_task:
|
||||
tdef.task = saved_task
|
||||
tdef.active = True
|
||||
session.active_trigger_ids.add(tid)
|
||||
if tdef.trigger_type == "timer":
|
||||
await _start_trigger_timer(session, tid, tdef)
|
||||
logger.info("Restored trigger timer '%s'", tid)
|
||||
elif tdef.trigger_type == "webhook":
|
||||
await _start_trigger_webhook(session, tid, tdef)
|
||||
logger.info("Restored webhook trigger '%s'", tid)
|
||||
else:
|
||||
logger.warning(
|
||||
"Saved trigger '%s' not found in worker entry points, skipping",
|
||||
tid,
|
||||
)
|
||||
|
||||
# Restore worker_configured flag
|
||||
if state and getattr(state, "worker_configured", False):
|
||||
session.worker_configured = True
|
||||
except Exception as e:
|
||||
logger.warning("Failed to restore active triggers: %s", e)
|
||||
|
||||
async def load_worker(
|
||||
self,
|
||||
session_id: str,
|
||||
@@ -447,44 +567,7 @@ class SessionManager:
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Restore previously active triggers from persisted session state
|
||||
if session.available_triggers and session.worker_runtime:
|
||||
try:
|
||||
store = session.worker_runtime._session_store
|
||||
state = await store.read_state(session_id)
|
||||
if state and state.active_triggers:
|
||||
from framework.tools.queen_lifecycle_tools import (
|
||||
_start_trigger_timer,
|
||||
_start_trigger_webhook,
|
||||
)
|
||||
|
||||
saved_tasks = getattr(state, "trigger_tasks", {}) or {}
|
||||
for tid in state.active_triggers:
|
||||
tdef = session.available_triggers.get(tid)
|
||||
if tdef:
|
||||
# Restore user-configured task override
|
||||
saved_task = saved_tasks.get(tid, "")
|
||||
if saved_task:
|
||||
tdef.task = saved_task
|
||||
tdef.active = True
|
||||
session.active_trigger_ids.add(tid)
|
||||
if tdef.trigger_type == "timer":
|
||||
await _start_trigger_timer(session, tid, tdef)
|
||||
logger.info("Restored trigger timer '%s'", tid)
|
||||
elif tdef.trigger_type == "webhook":
|
||||
await _start_trigger_webhook(session, tid, tdef)
|
||||
logger.info("Restored webhook trigger '%s'", tid)
|
||||
else:
|
||||
logger.warning(
|
||||
"Saved trigger '%s' not found in worker entry points, skipping",
|
||||
tid,
|
||||
)
|
||||
|
||||
# Restore worker_configured flag
|
||||
if state and getattr(state, "worker_configured", False):
|
||||
session.worker_configured = True
|
||||
except Exception as e:
|
||||
logger.warning("Failed to restore active triggers: %s", e)
|
||||
await self._restore_active_triggers(session, session_id)
|
||||
|
||||
# Emit SSE event so the frontend can update UI
|
||||
await self._emit_worker_loaded(session)
|
||||
@@ -526,6 +609,13 @@ class SessionManager:
|
||||
await self._emit_trigger_events(session, "removed", session.available_triggers)
|
||||
session.available_triggers.clear()
|
||||
|
||||
if session.worker_digest_sub is not None:
|
||||
try:
|
||||
session.event_bus.unsubscribe(session.worker_digest_sub)
|
||||
except Exception:
|
||||
pass
|
||||
session.worker_digest_sub = None
|
||||
|
||||
worker_id = session.worker_id
|
||||
session.worker_id = None
|
||||
session.worker_path = None
|
||||
@@ -563,6 +653,13 @@ class SessionManager:
|
||||
pass
|
||||
session.worker_handoff_sub = None
|
||||
|
||||
if session.worker_digest_sub is not None:
|
||||
try:
|
||||
session.event_bus.unsubscribe(session.worker_digest_sub)
|
||||
except Exception:
|
||||
pass
|
||||
session.worker_digest_sub = None
|
||||
|
||||
# Stop queen and memory consolidation subscription
|
||||
if session.memory_consolidation_sub is not None:
|
||||
try:
|
||||
@@ -647,6 +744,135 @@ class SessionManager:
|
||||
else:
|
||||
logger.warning("Worker handoff received but queen node not ready")
|
||||
|
||||
def _subscribe_worker_digest(self, session: Session) -> None:
|
||||
"""Subscribe to worker events to write per-run digests.
|
||||
|
||||
Three triggers:
|
||||
- NODE_LOOP_ITERATION: write a mid-run snapshot, throttled to at most
|
||||
once every _DIGEST_COOLDOWN seconds per execution.
|
||||
- TOOL_CALL_COMPLETED for delegate_to_sub_agent: same throttled snapshot.
|
||||
Orchestrator nodes often run all subagent calls in a single LLM turn,
|
||||
so NODE_LOOP_ITERATION only fires once at the end. Subagent
|
||||
completions provide intermediate checkpoints.
|
||||
- EXECUTION_COMPLETED / EXECUTION_FAILED: always write the final digest,
|
||||
bypassing the cooldown.
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
from framework.runtime.event_bus import EventType as _ET
|
||||
|
||||
_DIGEST_COOLDOWN = 300.0 # seconds between mid-run snapshots
|
||||
|
||||
if session.worker_digest_sub is not None:
|
||||
try:
|
||||
session.event_bus.unsubscribe(session.worker_digest_sub)
|
||||
except Exception:
|
||||
pass
|
||||
session.worker_digest_sub = None
|
||||
|
||||
agent_name = session.worker_path.name if session.worker_path else None
|
||||
if not agent_name:
|
||||
return
|
||||
|
||||
_agent_name = agent_name
|
||||
_llm = session.llm
|
||||
_bus = session.event_bus
|
||||
# per-execution_id monotonic timestamp of last mid-run digest
|
||||
_last_digest: dict[str, float] = {}
|
||||
|
||||
def _resolve_run_id(exec_id: str) -> str | None:
|
||||
"""Look up the run_id for a given execution_id via EXECUTION_STARTED history."""
|
||||
for e in _bus.get_history(event_type=_ET.EXECUTION_STARTED, limit=200):
|
||||
if e.execution_id == exec_id and getattr(e, "run_id", None):
|
||||
return e.run_id
|
||||
return None
|
||||
|
||||
async def _inject_digest_to_queen(run_id: str) -> None:
|
||||
"""Read the written digest and push it into the queen's conversation."""
|
||||
from framework.agents.worker_memory import digest_path
|
||||
|
||||
try:
|
||||
content = digest_path(_agent_name, run_id).read_text(encoding="utf-8").strip()
|
||||
except OSError:
|
||||
return
|
||||
if not content:
|
||||
return
|
||||
executor = session.queen_executor
|
||||
if executor is None:
|
||||
return
|
||||
node = executor.node_registry.get("queen")
|
||||
if node is None or not hasattr(node, "inject_event"):
|
||||
return
|
||||
await node.inject_event(f"[WORKER_DIGEST]\n{content}")
|
||||
|
||||
async def _consolidate_and_notify(run_id: str, outcome_event: Any) -> None:
|
||||
"""Write the digest then push it to the queen."""
|
||||
from framework.agents.worker_memory import consolidate_worker_run
|
||||
|
||||
await consolidate_worker_run(_agent_name, run_id, outcome_event, _bus, _llm)
|
||||
await _inject_digest_to_queen(run_id)
|
||||
|
||||
async def _on_worker_event(event: Any) -> None:
|
||||
if event.stream_id == "queen":
|
||||
return
|
||||
|
||||
exec_id = event.execution_id
|
||||
|
||||
if event.type == _ET.EXECUTION_STARTED:
|
||||
# New run on this execution_id — start the cooldown timer so
|
||||
# mid-run snapshots don't fire immediately at session start.
|
||||
# The first snapshot will happen after _DIGEST_COOLDOWN seconds.
|
||||
if exec_id:
|
||||
_last_digest[exec_id] = _time.monotonic()
|
||||
|
||||
elif event.type in (
|
||||
_ET.EXECUTION_COMPLETED,
|
||||
_ET.EXECUTION_FAILED,
|
||||
_ET.EXECUTION_PAUSED,
|
||||
):
|
||||
# Final digest — always fire, ignore cooldown.
|
||||
# EXECUTION_PAUSED covers cancellation (queen re-triggering the
|
||||
# worker cancels the previous execution, emitting paused).
|
||||
run_id = getattr(event, "run_id", None) or _resolve_run_id(exec_id)
|
||||
if run_id:
|
||||
asyncio.create_task(
|
||||
_consolidate_and_notify(run_id, event),
|
||||
name=f"worker-digest-final-{run_id}",
|
||||
)
|
||||
|
||||
elif event.type in (_ET.NODE_LOOP_ITERATION, _ET.TOOL_CALL_COMPLETED):
|
||||
# Mid-run snapshot — respect 300 s cooldown per execution.
|
||||
# TOOL_CALL_COMPLETED is only interesting for subagent calls;
|
||||
# regular tool completions are too frequent and too cheap.
|
||||
if event.type == _ET.TOOL_CALL_COMPLETED:
|
||||
tool_name = (event.data or {}).get("tool_name", "")
|
||||
if tool_name != "delegate_to_sub_agent":
|
||||
return
|
||||
if not exec_id:
|
||||
return
|
||||
now = _time.monotonic()
|
||||
if now - _last_digest.get(exec_id, 0.0) < _DIGEST_COOLDOWN:
|
||||
return
|
||||
run_id = _resolve_run_id(exec_id)
|
||||
if run_id:
|
||||
_last_digest[exec_id] = now
|
||||
asyncio.create_task(
|
||||
_consolidate_and_notify(run_id, None),
|
||||
name=f"worker-digest-{run_id}",
|
||||
)
|
||||
|
||||
session.worker_digest_sub = session.event_bus.subscribe(
|
||||
event_types=[
|
||||
_ET.EXECUTION_STARTED,
|
||||
_ET.NODE_LOOP_ITERATION,
|
||||
_ET.TOOL_CALL_COMPLETED,
|
||||
_ET.EXECUTION_COMPLETED,
|
||||
_ET.EXECUTION_FAILED,
|
||||
_ET.EXECUTION_PAUSED,
|
||||
],
|
||||
handler=_on_worker_event,
|
||||
)
|
||||
|
||||
def _subscribe_worker_handoffs(self, session: Session, executor: Any) -> None:
|
||||
"""Subscribe queen to worker/subagent escalation handoff events."""
|
||||
from framework.runtime.event_bus import EventType as _ET
|
||||
@@ -700,16 +926,21 @@ class SessionManager:
|
||||
else None
|
||||
)
|
||||
)
|
||||
_meta_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"agent_name": _agent_name,
|
||||
"agent_path": str(session.worker_path) if session.worker_path else None,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
# Merge into existing meta.json to preserve fields written by
|
||||
# _update_meta_json (e.g. phase, agent_path set during building).
|
||||
_existing_meta: dict = {}
|
||||
if _meta_path.exists():
|
||||
try:
|
||||
_existing_meta = json.loads(_meta_path.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, OSError):
|
||||
pass
|
||||
_new_meta: dict = {"created_at": time.time()}
|
||||
if _agent_name is not None:
|
||||
_new_meta["agent_name"] = _agent_name
|
||||
if session.worker_path is not None:
|
||||
_new_meta["agent_path"] = str(session.worker_path)
|
||||
_existing_meta.update(_new_meta)
|
||||
_meta_path.write_text(json.dumps(_existing_meta), encoding="utf-8")
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@@ -719,6 +950,7 @@ class SessionManager:
|
||||
# then use max+1 as offset so resumed sessions produce monotonically
|
||||
# increasing iteration values — preventing frontend message ID collisions.
|
||||
iteration_offset = 0
|
||||
last_phase = ""
|
||||
events_path = queen_dir / "events.jsonl"
|
||||
try:
|
||||
if events_path.exists():
|
||||
@@ -730,17 +962,25 @@ class SessionManager:
|
||||
continue
|
||||
try:
|
||||
evt = json.loads(line)
|
||||
it = evt.get("data", {}).get("iteration")
|
||||
data = evt.get("data", {})
|
||||
it = data.get("iteration")
|
||||
if isinstance(it, int) and it > max_iter:
|
||||
max_iter = it
|
||||
# Track the latest queen phase from QUEEN_PHASE_CHANGED events
|
||||
if evt.get("type") == "queen_phase_changed":
|
||||
phase = data.get("phase")
|
||||
if phase:
|
||||
last_phase = phase
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
if max_iter >= 0:
|
||||
iteration_offset = max_iter + 1
|
||||
logger.info(
|
||||
"Session '%s' resuming with iteration_offset=%d (from events.jsonl max)",
|
||||
"Session '%s' resuming with iteration_offset=%d"
|
||||
" (from events.jsonl max), last phase: %s",
|
||||
session.id,
|
||||
iteration_offset,
|
||||
last_phase or "unknown",
|
||||
)
|
||||
except OSError:
|
||||
pass
|
||||
@@ -762,11 +1002,27 @@ class SessionManager:
|
||||
try:
|
||||
_meta = json.loads(meta_path.read_text(encoding="utf-8"))
|
||||
_agent_path = _meta.get("agent_path")
|
||||
_phase = _meta.get("phase")
|
||||
|
||||
if _agent_path and Path(_agent_path).exists():
|
||||
await self.load_worker(session.id, _agent_path)
|
||||
if session.phase_state:
|
||||
await session.phase_state.switch_to_staging(source="auto")
|
||||
logger.info("Cold restore: auto-loaded worker from %s", _agent_path)
|
||||
if _phase in ("staging", "running", None):
|
||||
# Agent fully built — load worker and resume
|
||||
await self.load_worker(session.id, _agent_path)
|
||||
if session.phase_state:
|
||||
await session.phase_state.switch_to_staging(source="auto")
|
||||
# Emit flowchart overlay so frontend can display it
|
||||
await self._emit_flowchart_on_restore(session, _agent_path)
|
||||
logger.info("Cold restore: auto-loaded worker from %s", _agent_path)
|
||||
elif _phase == "building":
|
||||
# Agent folder exists but incomplete — resume building
|
||||
if session.phase_state:
|
||||
session.phase_state.agent_path = _agent_path
|
||||
await session.phase_state.switch_to_building(source="auto")
|
||||
logger.info("Cold restore: resumed BUILDING phase for %s", _agent_path)
|
||||
elif _phase == "planning":
|
||||
if session.phase_state:
|
||||
session.phase_state.agent_path = _agent_path
|
||||
logger.info("Cold restore: PLANNING phase for %s", _agent_path)
|
||||
except Exception:
|
||||
logger.warning("Cold restore: failed to auto-load worker", exc_info=True)
|
||||
|
||||
@@ -776,10 +1032,17 @@ class SessionManager:
|
||||
_consolidation_session_dir = queen_dir
|
||||
|
||||
async def _on_compaction(_event) -> None:
|
||||
# Only consolidate on queen compactions — worker and subagent
|
||||
# compactions are frequent and don't warrant a memory update.
|
||||
if getattr(_event, "stream_id", None) != "queen":
|
||||
return
|
||||
from framework.agents.queen.queen_memory import consolidate_queen_memory
|
||||
|
||||
await consolidate_queen_memory(
|
||||
session.id, _consolidation_session_dir, _consolidation_llm
|
||||
asyncio.create_task(
|
||||
consolidate_queen_memory(
|
||||
session.id, _consolidation_session_dir, _consolidation_llm
|
||||
),
|
||||
name=f"queen-memory-consolidation-{session.id}",
|
||||
)
|
||||
|
||||
from framework.runtime.event_bus import EventType as _ET
|
||||
@@ -841,6 +1104,29 @@ class SessionManager:
|
||||
)
|
||||
)
|
||||
|
||||
async def _emit_flowchart_on_restore(self, session: Session, agent_path: str | Path) -> None:
|
||||
"""Emit FLOWCHART_MAP_UPDATED from persisted flowchart file on cold restore."""
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
from framework.tools.flowchart_utils import load_flowchart_file
|
||||
|
||||
original_draft, flowchart_map = load_flowchart_file(agent_path)
|
||||
if original_draft is None:
|
||||
return
|
||||
# Cache in phase_state so the REST endpoint also returns it
|
||||
if session.phase_state:
|
||||
session.phase_state.original_draft_graph = original_draft
|
||||
session.phase_state.flowchart_map = flowchart_map
|
||||
await session.event_bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.FLOWCHART_MAP_UPDATED,
|
||||
stream_id="queen",
|
||||
data={
|
||||
"map": flowchart_map,
|
||||
"original_draft": original_draft,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def _notify_queen_worker_unloaded(self, session: Session) -> None:
|
||||
"""Notify the queen that the worker has been unloaded."""
|
||||
executor = session.queen_executor
|
||||
@@ -868,6 +1154,10 @@ class SessionManager:
|
||||
event_type = (
|
||||
EventType.TRIGGER_AVAILABLE if kind == "available" else EventType.TRIGGER_REMOVED
|
||||
)
|
||||
# Resolve graph entry node for trigger target
|
||||
runner = getattr(session, "runner", None)
|
||||
graph_entry = runner.graph.entry_node if runner else None
|
||||
|
||||
for t in triggers.values():
|
||||
await session.event_bus.publish(
|
||||
AgentEvent(
|
||||
@@ -877,6 +1167,8 @@ class SessionManager:
|
||||
"trigger_id": t.id,
|
||||
"trigger_type": t.trigger_type,
|
||||
"trigger_config": t.trigger_config,
|
||||
"name": t.description or t.id,
|
||||
**({"entry_node": graph_entry} if graph_entry else {}),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ Uses aiohttp TestClient with mocked sessions to test all endpoints
|
||||
without requiring actual LLM calls or agent loading.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
@@ -13,6 +14,7 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from framework.runtime.triggers import TriggerDefinition
|
||||
from framework.server.app import create_app
|
||||
from framework.server.session_manager import Session
|
||||
|
||||
@@ -172,6 +174,7 @@ def _make_session(
|
||||
runner.intro_message = "Test intro"
|
||||
|
||||
mock_event_bus = MagicMock()
|
||||
mock_event_bus.publish = AsyncMock()
|
||||
mock_llm = MagicMock()
|
||||
|
||||
queen_executor = _make_queen_executor() if with_queen else None
|
||||
@@ -210,11 +213,8 @@ def tmp_agent_dir(tmp_path, monkeypatch):
|
||||
return tmp_path, agent_name, base
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_session(tmp_agent_dir):
|
||||
"""Create a sample session with state.json, checkpoints, and conversations."""
|
||||
tmp_path, agent_name, base = tmp_agent_dir
|
||||
session_id = "session_20260220_120000_abc12345"
|
||||
def _write_sample_session(base: Path, session_id: str):
|
||||
"""Create a sample worker session on disk."""
|
||||
session_dir = base / "sessions" / session_id
|
||||
|
||||
# state.json
|
||||
@@ -295,6 +295,20 @@ def sample_session(tmp_agent_dir):
|
||||
return session_id, session_dir, state
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_session(tmp_agent_dir):
|
||||
"""Create a sample session with state.json, checkpoints, and conversations."""
|
||||
_tmp_path, _agent_name, base = tmp_agent_dir
|
||||
return _write_sample_session(base, "session_20260220_120000_abc12345")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_id_session(tmp_agent_dir):
|
||||
"""Create a sample session that uses a custom non-session_* ID."""
|
||||
_tmp_path, _agent_name, base = tmp_agent_dir
|
||||
return _write_sample_session(base, "my-custom-session")
|
||||
|
||||
|
||||
def _make_app_with_session(session):
|
||||
"""Create an aiohttp app with a pre-loaded session."""
|
||||
app = create_app()
|
||||
@@ -473,6 +487,70 @@ class TestSessionCRUD:
|
||||
data = await resp.json()
|
||||
assert "primary" in data["graphs"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_trigger_task(self, tmp_path):
|
||||
session = _make_session(tmp_dir=tmp_path)
|
||||
session.available_triggers["daily"] = TriggerDefinition(
|
||||
id="daily",
|
||||
trigger_type="timer",
|
||||
trigger_config={"cron": "0 5 * * *"},
|
||||
task="Old task",
|
||||
)
|
||||
app = _make_app_with_session(session)
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.patch(
|
||||
"/api/sessions/test_agent/triggers/daily",
|
||||
json={"task": "New task"},
|
||||
)
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["task"] == "New task"
|
||||
assert data["trigger_config"]["cron"] == "0 5 * * *"
|
||||
assert session.available_triggers["daily"].task == "New task"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_trigger_cron_restarts_active_timer(self, tmp_path):
|
||||
session = _make_session(tmp_dir=tmp_path)
|
||||
session.available_triggers["daily"] = TriggerDefinition(
|
||||
id="daily",
|
||||
trigger_type="timer",
|
||||
trigger_config={"cron": "0 5 * * *"},
|
||||
task="Run task",
|
||||
active=True,
|
||||
)
|
||||
session.active_trigger_ids.add("daily")
|
||||
session.active_timer_tasks["daily"] = asyncio.create_task(asyncio.sleep(60))
|
||||
app = _make_app_with_session(session)
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.patch(
|
||||
"/api/sessions/test_agent/triggers/daily",
|
||||
json={"trigger_config": {"cron": "0 6 * * *"}},
|
||||
)
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["trigger_config"]["cron"] == "0 6 * * *"
|
||||
assert "daily" in session.active_timer_tasks
|
||||
assert session.active_timer_tasks["daily"] is not None
|
||||
assert session.available_triggers["daily"].trigger_config["cron"] == "0 6 * * *"
|
||||
session.active_timer_tasks["daily"].cancel()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_trigger_cron_rejects_invalid_expression(self, tmp_path):
|
||||
session = _make_session(tmp_dir=tmp_path)
|
||||
session.available_triggers["daily"] = TriggerDefinition(
|
||||
id="daily",
|
||||
trigger_type="timer",
|
||||
trigger_config={"cron": "0 5 * * *"},
|
||||
task="Run task",
|
||||
)
|
||||
app = _make_app_with_session(session)
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.patch(
|
||||
"/api/sessions/test_agent/triggers/daily",
|
||||
json={"trigger_config": {"cron": "not a cron"}},
|
||||
)
|
||||
assert resp.status == 400
|
||||
|
||||
|
||||
class TestExecution:
|
||||
@pytest.mark.asyncio
|
||||
@@ -799,6 +877,22 @@ class TestWorkerSessions:
|
||||
assert data["sessions"][0]["status"] == "paused"
|
||||
assert data["sessions"][0]["steps"] == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_includes_custom_id(self, custom_id_session, tmp_agent_dir):
|
||||
session_id, session_dir, state = custom_id_session
|
||||
tmp_path, agent_name, base = tmp_agent_dir
|
||||
|
||||
session = _make_session(tmp_dir=tmp_path / ".hive" / "agents" / agent_name)
|
||||
app = _make_app_with_session(session)
|
||||
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.get("/api/sessions/test_agent/worker-sessions")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert len(data["sessions"]) == 1
|
||||
assert data["sessions"][0]["session_id"] == session_id
|
||||
assert data["sessions"][0]["status"] == "paused"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_empty(self, tmp_agent_dir):
|
||||
tmp_path, agent_name, base = tmp_agent_dir
|
||||
@@ -1316,6 +1410,28 @@ class TestLogs:
|
||||
assert len(data["logs"]) >= 1
|
||||
assert data["logs"][0]["run_id"] == session_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logs_list_summaries_with_custom_id(self, custom_id_session, tmp_agent_dir):
|
||||
session_id, session_dir, state = custom_id_session
|
||||
tmp_path, agent_name, base = tmp_agent_dir
|
||||
|
||||
from framework.runtime.runtime_log_store import RuntimeLogStore
|
||||
|
||||
log_store = RuntimeLogStore(base)
|
||||
session = _make_session(
|
||||
tmp_dir=tmp_path / ".hive" / "agents" / agent_name,
|
||||
log_store=log_store,
|
||||
)
|
||||
app = _make_app_with_session(session)
|
||||
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.get("/api/sessions/test_agent/logs")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert "logs" in data
|
||||
assert len(data["logs"]) >= 1
|
||||
assert data["logs"][0]["run_id"] == session_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logs_session_summary(self, sample_session, tmp_agent_dir):
|
||||
session_id, session_dir, state = sample_session
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Hive Agent Skills — discovery, parsing, trust gating, and injection of SKILL.md packages.
|
||||
|
||||
Implements the open Agent Skills standard (agentskills.io) for portable
|
||||
skill discovery and activation, plus built-in default skills for runtime
|
||||
operational discipline, and AS-13 trust gating for project-scope skills.
|
||||
"""
|
||||
|
||||
from framework.skills.catalog import SkillCatalog
|
||||
from framework.skills.config import DefaultSkillConfig, SkillsConfig
|
||||
from framework.skills.defaults import DefaultSkillManager
|
||||
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
|
||||
from framework.skills.manager import SkillsManager, SkillsManagerConfig
|
||||
from framework.skills.models import TrustStatus
|
||||
from framework.skills.parser import ParsedSkill, parse_skill_md
|
||||
from framework.skills.trust import TrustedRepoStore, TrustGate
|
||||
|
||||
__all__ = [
|
||||
"DefaultSkillConfig",
|
||||
"DefaultSkillManager",
|
||||
"DiscoveryConfig",
|
||||
"ParsedSkill",
|
||||
"SkillCatalog",
|
||||
"SkillDiscovery",
|
||||
"SkillsConfig",
|
||||
"SkillsManager",
|
||||
"SkillsManagerConfig",
|
||||
"TrustGate",
|
||||
"TrustedRepoStore",
|
||||
"TrustStatus",
|
||||
"parse_skill_md",
|
||||
]
|
||||
@@ -0,0 +1,24 @@
|
||||
---
|
||||
name: hive.batch-ledger
|
||||
description: Track per-item status when processing collections to prevent skipped or duplicated items.
|
||||
metadata:
|
||||
author: hive
|
||||
type: default-skill
|
||||
---
|
||||
|
||||
## Operational Protocol: Batch Progress Ledger
|
||||
|
||||
When processing a collection of items, maintain a batch ledger in `_batch_ledger`.
|
||||
|
||||
Initialize when you identify the batch:
|
||||
- `_batch_total`: total item count
|
||||
- `_batch_ledger`: JSON with per-item status
|
||||
|
||||
Per-item statuses: pending → in_progress → completed|failed|skipped
|
||||
|
||||
- Set `in_progress` BEFORE processing
|
||||
- Set final status AFTER processing with 1-line result_summary
|
||||
- Include error reason for failed/skipped items
|
||||
- Update aggregate counts after each item
|
||||
- NEVER remove items from the ledger
|
||||
- If resuming, skip items already marked completed
|
||||
@@ -0,0 +1,22 @@
|
||||
---
|
||||
name: hive.context-preservation
|
||||
description: Proactively preserve critical information before automatic context pruning destroys it.
|
||||
metadata:
|
||||
author: hive
|
||||
type: default-skill
|
||||
---
|
||||
|
||||
## Operational Protocol: Context Preservation
|
||||
|
||||
You operate under a finite context window. Important information WILL be pruned.
|
||||
|
||||
Save-As-You-Go: After any tool call producing information you'll need later,
|
||||
immediately extract key data into `_working_notes` or `_preserved_data`.
|
||||
Do NOT rely on referring back to old tool results.
|
||||
|
||||
What to extract: URLs and key snippets (not full pages), relevant API fields
|
||||
(not raw JSON), specific lines/values (not entire files), analysis results
|
||||
(not raw data).
|
||||
|
||||
Before transitioning to the next phase/node, write a handoff summary to
|
||||
`_handoff_context` with everything the next phase needs to know.
|
||||
@@ -0,0 +1,18 @@
|
||||
---
|
||||
name: hive.error-recovery
|
||||
description: Follow a structured recovery protocol when tool calls fail instead of blindly retrying or giving up.
|
||||
metadata:
|
||||
author: hive
|
||||
type: default-skill
|
||||
---
|
||||
|
||||
## Operational Protocol: Error Recovery
|
||||
|
||||
When a tool call fails:
|
||||
|
||||
1. Diagnose — record error in notes, classify as transient or structural
|
||||
2. Decide — transient: retry once. Structural fixable: fix and retry.
|
||||
Structural unfixable: record as failed, move to next item.
|
||||
Blocking all progress: record escalation note.
|
||||
3. Adapt — if same tool failed 3+ times, stop using it and find alternative.
|
||||
Update plan in notes. Never silently drop the failed item.
|
||||
@@ -0,0 +1,27 @@
|
||||
---
|
||||
name: hive.note-taking
|
||||
description: Maintain structured working notes throughout execution to prevent information loss during context pruning.
|
||||
metadata:
|
||||
author: hive
|
||||
type: default-skill
|
||||
---
|
||||
|
||||
## Operational Protocol: Structured Note-Taking
|
||||
|
||||
Maintain structured working notes in shared memory key `_working_notes`.
|
||||
Update at these checkpoints:
|
||||
|
||||
- After completing each discrete subtask or batch item
|
||||
- After receiving new information that changes your plan
|
||||
- Before any tool call that will produce substantial output
|
||||
|
||||
Structure:
|
||||
|
||||
### Objective — restate the goal
|
||||
### Current Plan — numbered steps, mark completed with ✓
|
||||
### Key Decisions — decisions made and WHY
|
||||
### Working Data — intermediate results, extracted values
|
||||
### Open Questions — uncertainties to verify
|
||||
### Blockers — anything preventing progress
|
||||
|
||||
Update incrementally — do not rewrite from scratch each time.
|
||||
@@ -0,0 +1,20 @@
|
||||
---
|
||||
name: hive.quality-monitor
|
||||
description: Periodically self-assess output quality to catch degradation before the judge does.
|
||||
metadata:
|
||||
author: hive
|
||||
type: default-skill
|
||||
---
|
||||
|
||||
## Operational Protocol: Quality Self-Assessment
|
||||
|
||||
Every 5 iterations, self-assess:
|
||||
|
||||
1. On-task? Still working toward the stated objective?
|
||||
2. Thorough? Cutting corners compared to earlier?
|
||||
3. Non-repetitive? Producing new value or rehashing?
|
||||
4. Consistent? Latest output contradict earlier decisions?
|
||||
5. Complete? Tracking all items, or silently dropped some?
|
||||
|
||||
If degrading: write assessment to `_quality_log`, re-read `_working_notes`,
|
||||
change approach explicitly. If acceptable: brief note in `_quality_log`.
|
||||
@@ -0,0 +1,17 @@
|
||||
---
|
||||
name: hive.task-decomposition
|
||||
description: Decompose complex tasks into explicit subtasks before diving in.
|
||||
metadata:
|
||||
author: hive
|
||||
type: default-skill
|
||||
---
|
||||
|
||||
## Operational Protocol: Task Decomposition
|
||||
|
||||
Before starting a complex task:
|
||||
|
||||
1. Decompose — break into numbered subtasks in `_working_notes` Current Plan
|
||||
2. Estimate — relative effort per subtask (small/medium/large)
|
||||
3. Execute — work through in order, mark ✓ when complete
|
||||
4. Budget — if running low on iterations, prioritize by impact
|
||||
5. Verify — before declaring done, every subtask must be ✓, skipped (with reason), or blocked
|
||||
@@ -0,0 +1,108 @@
|
||||
"""Skill catalog — in-memory index with system prompt generation.
|
||||
|
||||
Builds the XML catalog injected into the system prompt for model-driven
|
||||
skill activation per the Agent Skills standard.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from xml.sax.saxutils import escape
|
||||
|
||||
from framework.skills.parser import ParsedSkill
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BEHAVIORAL_INSTRUCTION = (
|
||||
"The following skills provide specialized instructions for specific tasks.\n"
|
||||
"When a task matches a skill's description, read the SKILL.md at the listed\n"
|
||||
"location to load the full instructions before proceeding.\n"
|
||||
"When a skill references relative paths, resolve them against the skill's\n"
|
||||
"directory (the parent of SKILL.md) and use absolute paths in tool calls."
|
||||
)
|
||||
|
||||
|
||||
class SkillCatalog:
|
||||
"""In-memory catalog of discovered skills."""
|
||||
|
||||
def __init__(self, skills: list[ParsedSkill] | None = None):
|
||||
self._skills: dict[str, ParsedSkill] = {}
|
||||
self._activated: set[str] = set()
|
||||
if skills:
|
||||
for skill in skills:
|
||||
self.add(skill)
|
||||
|
||||
def add(self, skill: ParsedSkill) -> None:
|
||||
"""Add a skill to the catalog."""
|
||||
self._skills[skill.name] = skill
|
||||
|
||||
def get(self, name: str) -> ParsedSkill | None:
|
||||
"""Look up a skill by name."""
|
||||
return self._skills.get(name)
|
||||
|
||||
def mark_activated(self, name: str) -> None:
|
||||
"""Mark a skill as activated in the current session."""
|
||||
self._activated.add(name)
|
||||
|
||||
def is_activated(self, name: str) -> bool:
|
||||
"""Check if a skill has been activated."""
|
||||
return name in self._activated
|
||||
|
||||
@property
|
||||
def skill_count(self) -> int:
|
||||
return len(self._skills)
|
||||
|
||||
@property
|
||||
def allowlisted_dirs(self) -> list[str]:
|
||||
"""All skill base directories for file access allowlisting."""
|
||||
return [skill.base_dir for skill in self._skills.values()]
|
||||
|
||||
def to_prompt(self) -> str:
|
||||
"""Generate the catalog prompt for system prompt injection.
|
||||
|
||||
Returns empty string if no community/user skills are discovered
|
||||
(default skills are handled separately by DefaultSkillManager).
|
||||
"""
|
||||
# Filter out framework-scope skills (default skills) — they're
|
||||
# injected via the protocols prompt, not the catalog
|
||||
community_skills = [s for s in self._skills.values() if s.source_scope != "framework"]
|
||||
|
||||
if not community_skills:
|
||||
return ""
|
||||
|
||||
lines = ["<available_skills>"]
|
||||
for skill in sorted(community_skills, key=lambda s: s.name):
|
||||
lines.append(" <skill>")
|
||||
lines.append(f" <name>{escape(skill.name)}</name>")
|
||||
lines.append(f" <description>{escape(skill.description)}</description>")
|
||||
lines.append(f" <location>{escape(skill.location)}</location>")
|
||||
lines.append(f" <base_dir>{escape(skill.base_dir)}</base_dir>")
|
||||
lines.append(" </skill>")
|
||||
lines.append("</available_skills>")
|
||||
|
||||
xml_block = "\n".join(lines)
|
||||
return f"{_BEHAVIORAL_INSTRUCTION}\n\n{xml_block}"
|
||||
|
||||
def build_pre_activated_prompt(self, skill_names: list[str]) -> str:
|
||||
"""Build prompt content for pre-activated skills.
|
||||
|
||||
Pre-activated skills get their full SKILL.md body loaded into
|
||||
the system prompt at startup (tier 2), bypassing model-driven
|
||||
activation.
|
||||
|
||||
Returns empty string if no skills match.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
|
||||
for name in skill_names:
|
||||
skill = self.get(name)
|
||||
if skill is None:
|
||||
logger.warning("Pre-activated skill '%s' not found in catalog", name)
|
||||
continue
|
||||
if self.is_activated(name):
|
||||
continue # Already activated, skip duplicate
|
||||
|
||||
self.mark_activated(name)
|
||||
parts.append(f"--- Pre-Activated Skill: {skill.name} ---\n{skill.body}")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
@@ -0,0 +1,120 @@
|
||||
"""CLI commands for the Hive skill system.
|
||||
|
||||
Phase 1 commands (AS-13):
|
||||
hive skill list — list discovered skills across all scopes
|
||||
hive skill trust <path> — permanently trust a project repo's skills
|
||||
|
||||
Full CLI suite (CLI-1 through CLI-13) is Phase 2.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def register_skill_commands(subparsers) -> None:
|
||||
"""Register the ``hive skill`` subcommand group."""
|
||||
skill_parser = subparsers.add_parser("skill", help="Manage skills")
|
||||
skill_sub = skill_parser.add_subparsers(dest="skill_command", required=True)
|
||||
|
||||
# hive skill list
|
||||
list_parser = skill_sub.add_parser("list", help="List discovered skills across all scopes")
|
||||
list_parser.add_argument(
|
||||
"--project-dir",
|
||||
default=None,
|
||||
metavar="PATH",
|
||||
help="Project directory to scan (default: current directory)",
|
||||
)
|
||||
list_parser.set_defaults(func=cmd_skill_list)
|
||||
|
||||
# hive skill trust
|
||||
trust_parser = skill_sub.add_parser(
|
||||
"trust",
|
||||
help="Permanently trust a project repository so its skills load without prompting",
|
||||
)
|
||||
trust_parser.add_argument(
|
||||
"project_path",
|
||||
help="Path to the project directory (must contain a .git with a remote origin)",
|
||||
)
|
||||
trust_parser.set_defaults(func=cmd_skill_trust)
|
||||
|
||||
|
||||
def cmd_skill_list(args) -> int:
|
||||
"""List all discovered skills grouped by scope."""
|
||||
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
|
||||
|
||||
project_dir = Path(args.project_dir).resolve() if args.project_dir else Path.cwd()
|
||||
skills = SkillDiscovery(DiscoveryConfig(project_root=project_dir)).discover()
|
||||
|
||||
if not skills:
|
||||
print("No skills discovered.")
|
||||
return 0
|
||||
|
||||
scope_headers = {
|
||||
"project": "PROJECT SKILLS",
|
||||
"user": "USER SKILLS",
|
||||
"framework": "FRAMEWORK SKILLS",
|
||||
}
|
||||
|
||||
for scope in ("project", "user", "framework"):
|
||||
scope_skills = [s for s in skills if s.source_scope == scope]
|
||||
if not scope_skills:
|
||||
continue
|
||||
print(f"\n{scope_headers[scope]}")
|
||||
print("─" * 40)
|
||||
for skill in scope_skills:
|
||||
print(f" • {skill.name}")
|
||||
print(f" {skill.description}")
|
||||
print(f" {skill.location}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_skill_trust(args) -> int:
|
||||
"""Permanently trust a project repository's skills."""
|
||||
from framework.skills.trust import TrustedRepoStore, _normalize_remote_url
|
||||
|
||||
project_path = Path(args.project_path).resolve()
|
||||
|
||||
if not project_path.exists():
|
||||
print(f"Error: path does not exist: {project_path}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
if not (project_path / ".git").exists():
|
||||
print(
|
||||
f"Error: {project_path} is not a git repository (no .git directory).",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "-C", str(project_path), "remote", "get-url", "origin"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
print(
|
||||
"Error: no remote 'origin' configured in this repository.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
remote_url = result.stdout.strip()
|
||||
except subprocess.TimeoutExpired:
|
||||
print("Error: git remote lookup timed out.", file=sys.stderr)
|
||||
return 1
|
||||
except (FileNotFoundError, OSError) as e:
|
||||
print(f"Error reading git remote: {e}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
repo_key = _normalize_remote_url(remote_url)
|
||||
store = TrustedRepoStore()
|
||||
store.trust(repo_key, project_path=str(project_path))
|
||||
|
||||
print(f"✓ Trusted: {repo_key}")
|
||||
print(" Stored in ~/.hive/trusted_repos.json")
|
||||
print(" Skills from this repository will load without prompting in future runs.")
|
||||
return 0
|
||||
@@ -0,0 +1,100 @@
|
||||
"""Skill configuration dataclasses.
|
||||
|
||||
Handles agent-level skill configuration from module-level variables
|
||||
(``default_skills`` and ``skills``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class DefaultSkillConfig:
|
||||
"""Configuration for a single default skill."""
|
||||
|
||||
enabled: bool = True
|
||||
overrides: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> DefaultSkillConfig:
|
||||
enabled = data.get("enabled", True)
|
||||
overrides = {k: v for k, v in data.items() if k != "enabled"}
|
||||
return cls(enabled=enabled, overrides=overrides)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillsConfig:
|
||||
"""Agent-level skill configuration.
|
||||
|
||||
Built from module-level variables in agent.py::
|
||||
|
||||
# Pre-activated community skills
|
||||
skills = ["deep-research", "code-review"]
|
||||
|
||||
# Default skill configuration
|
||||
default_skills = {
|
||||
"hive.note-taking": {"enabled": True},
|
||||
"hive.batch-ledger": {"enabled": True, "checkpoint_every_n": 10},
|
||||
"hive.quality-monitor": {"enabled": False},
|
||||
}
|
||||
"""
|
||||
|
||||
# Per-default-skill config, keyed by skill name (e.g. "hive.note-taking")
|
||||
default_skills: dict[str, DefaultSkillConfig] = field(default_factory=dict)
|
||||
|
||||
# Pre-activated community skills (by name)
|
||||
skills: list[str] = field(default_factory=list)
|
||||
|
||||
# Master switch: disable all default skills at once
|
||||
all_defaults_disabled: bool = False
|
||||
|
||||
def is_default_enabled(self, skill_name: str) -> bool:
|
||||
"""Check if a specific default skill is enabled."""
|
||||
if self.all_defaults_disabled:
|
||||
return False
|
||||
config = self.default_skills.get(skill_name)
|
||||
if config is None:
|
||||
return True # enabled by default
|
||||
return config.enabled
|
||||
|
||||
def get_default_overrides(self, skill_name: str) -> dict[str, Any]:
|
||||
"""Get skill-specific configuration overrides."""
|
||||
config = self.default_skills.get(skill_name)
|
||||
if config is None:
|
||||
return {}
|
||||
return config.overrides
|
||||
|
||||
@classmethod
|
||||
def from_agent_vars(
|
||||
cls,
|
||||
default_skills: dict[str, Any] | None = None,
|
||||
skills: list[str] | None = None,
|
||||
) -> SkillsConfig:
|
||||
"""Build config from agent module-level variables.
|
||||
|
||||
Args:
|
||||
default_skills: Dict from agent module, e.g.
|
||||
``{"hive.note-taking": {"enabled": True}}``
|
||||
skills: List of pre-activated skill names from agent module
|
||||
"""
|
||||
all_disabled = False
|
||||
parsed_defaults: dict[str, DefaultSkillConfig] = {}
|
||||
|
||||
if default_skills:
|
||||
for name, config_dict in default_skills.items():
|
||||
if name == "_all":
|
||||
if isinstance(config_dict, dict) and not config_dict.get("enabled", True):
|
||||
all_disabled = True
|
||||
continue
|
||||
if isinstance(config_dict, dict):
|
||||
parsed_defaults[name] = DefaultSkillConfig.from_dict(config_dict)
|
||||
elif isinstance(config_dict, bool):
|
||||
parsed_defaults[name] = DefaultSkillConfig(enabled=config_dict)
|
||||
|
||||
return cls(
|
||||
default_skills=parsed_defaults,
|
||||
skills=list(skills or []),
|
||||
all_defaults_disabled=all_disabled,
|
||||
)
|
||||
@@ -0,0 +1,151 @@
|
||||
"""DefaultSkillManager — load, configure, and inject built-in default skills.
|
||||
|
||||
Default skills are SKILL.md packages shipped with the framework that provide
|
||||
runtime operational protocols (note-taking, batch tracking, error recovery, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from framework.skills.config import SkillsConfig
|
||||
from framework.skills.parser import ParsedSkill, parse_skill_md
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default skills directory relative to this module
|
||||
_DEFAULT_SKILLS_DIR = Path(__file__).parent / "_default_skills"
|
||||
|
||||
# Ordered list of default skills (name → directory)
|
||||
SKILL_REGISTRY: dict[str, str] = {
|
||||
"hive.note-taking": "note-taking",
|
||||
"hive.batch-ledger": "batch-ledger",
|
||||
"hive.context-preservation": "context-preservation",
|
||||
"hive.quality-monitor": "quality-monitor",
|
||||
"hive.error-recovery": "error-recovery",
|
||||
"hive.task-decomposition": "task-decomposition",
|
||||
}
|
||||
|
||||
# All shared memory keys used by default skills (for permission auto-inclusion)
|
||||
SHARED_MEMORY_KEYS: list[str] = [
|
||||
# note-taking
|
||||
"_working_notes",
|
||||
"_notes_updated_at",
|
||||
# batch-ledger
|
||||
"_batch_ledger",
|
||||
"_batch_total",
|
||||
"_batch_completed",
|
||||
"_batch_failed",
|
||||
# context-preservation
|
||||
"_handoff_context",
|
||||
"_preserved_data",
|
||||
# quality-monitor
|
||||
"_quality_log",
|
||||
"_quality_degradation_count",
|
||||
# error-recovery
|
||||
"_error_log",
|
||||
"_failed_tools",
|
||||
"_escalation_needed",
|
||||
# task-decomposition
|
||||
"_subtasks",
|
||||
"_iteration_budget_remaining",
|
||||
]
|
||||
|
||||
|
||||
class DefaultSkillManager:
|
||||
"""Manages loading, configuration, and prompt generation for default skills."""
|
||||
|
||||
def __init__(self, config: SkillsConfig | None = None):
|
||||
self._config = config or SkillsConfig()
|
||||
self._skills: dict[str, ParsedSkill] = {}
|
||||
self._loaded = False
|
||||
|
||||
def load(self) -> None:
|
||||
"""Load all enabled default skill SKILL.md files."""
|
||||
if self._loaded:
|
||||
return
|
||||
|
||||
for skill_name, dir_name in SKILL_REGISTRY.items():
|
||||
if not self._config.is_default_enabled(skill_name):
|
||||
logger.info("Default skill '%s' disabled by config", skill_name)
|
||||
continue
|
||||
|
||||
skill_path = _DEFAULT_SKILLS_DIR / dir_name / "SKILL.md"
|
||||
if not skill_path.is_file():
|
||||
logger.error("Default skill SKILL.md not found: %s", skill_path)
|
||||
continue
|
||||
|
||||
parsed = parse_skill_md(skill_path, source_scope="framework")
|
||||
if parsed is None:
|
||||
logger.error("Failed to parse default skill: %s", skill_path)
|
||||
continue
|
||||
|
||||
self._skills[skill_name] = parsed
|
||||
|
||||
self._loaded = True
|
||||
|
||||
def build_protocols_prompt(self) -> str:
|
||||
"""Build the combined operational protocols section.
|
||||
|
||||
Extracts protocol sections from all enabled default skills and
|
||||
combines them into a single ``## Operational Protocols`` block
|
||||
for system prompt injection.
|
||||
|
||||
Returns empty string if all defaults are disabled.
|
||||
"""
|
||||
if not self._skills:
|
||||
return ""
|
||||
|
||||
parts: list[str] = ["## Operational Protocols\n"]
|
||||
|
||||
for skill_name in SKILL_REGISTRY:
|
||||
skill = self._skills.get(skill_name)
|
||||
if skill is None:
|
||||
continue
|
||||
# Use the full body — each SKILL.md contains exactly one protocol section
|
||||
parts.append(skill.body)
|
||||
|
||||
if len(parts) <= 1:
|
||||
return ""
|
||||
|
||||
combined = "\n\n".join(parts)
|
||||
|
||||
# Token budget warning (approximate: 1 token ≈ 4 chars)
|
||||
approx_tokens = len(combined) // 4
|
||||
if approx_tokens > 2000:
|
||||
logger.warning(
|
||||
"Default skill protocols exceed 2000 token budget "
|
||||
"(~%d tokens, %d chars). Consider trimming.",
|
||||
approx_tokens,
|
||||
len(combined),
|
||||
)
|
||||
|
||||
return combined
|
||||
|
||||
def log_active_skills(self) -> None:
|
||||
"""Log which default skills are active and their configuration."""
|
||||
if not self._skills:
|
||||
logger.info("Default skills: all disabled")
|
||||
return
|
||||
|
||||
active = []
|
||||
for skill_name in SKILL_REGISTRY:
|
||||
if skill_name in self._skills:
|
||||
overrides = self._config.get_default_overrides(skill_name)
|
||||
if overrides:
|
||||
active.append(f"{skill_name} ({overrides})")
|
||||
else:
|
||||
active.append(skill_name)
|
||||
|
||||
logger.info("Default skills active: %s", ", ".join(active))
|
||||
|
||||
@property
|
||||
def active_skill_names(self) -> list[str]:
|
||||
"""Names of all currently active default skills."""
|
||||
return list(self._skills.keys())
|
||||
|
||||
@property
|
||||
def active_skills(self) -> dict[str, ParsedSkill]:
|
||||
"""All active default skills keyed by name."""
|
||||
return dict(self._skills)
|
||||
@@ -0,0 +1,183 @@
|
||||
"""Skill discovery — scan standard directories for SKILL.md files.
|
||||
|
||||
Implements the Agent Skills standard discovery paths plus Hive-specific
|
||||
locations. Resolves name collisions deterministically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from framework.skills.parser import ParsedSkill, parse_skill_md
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Directories to skip during scanning
|
||||
_SKIP_DIRS = frozenset(
|
||||
{
|
||||
".git",
|
||||
"node_modules",
|
||||
"__pycache__",
|
||||
".venv",
|
||||
"venv",
|
||||
".mypy_cache",
|
||||
".pytest_cache",
|
||||
".ruff_cache",
|
||||
}
|
||||
)
|
||||
|
||||
# Scope priority (higher = takes precedence)
|
||||
_SCOPE_PRIORITY = {
|
||||
"framework": 0,
|
||||
"user": 1,
|
||||
"project": 2,
|
||||
}
|
||||
|
||||
# Within the same scope, Hive-specific paths override cross-client paths.
|
||||
# We encode this by scanning cross-client first, then Hive-specific (later wins).
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiscoveryConfig:
|
||||
"""Configuration for skill discovery."""
|
||||
|
||||
project_root: Path | None = None
|
||||
skip_user_scope: bool = False
|
||||
skip_framework_scope: bool = False
|
||||
max_depth: int = 4
|
||||
max_dirs: int = 2000
|
||||
|
||||
|
||||
class SkillDiscovery:
|
||||
"""Scans standard directories for SKILL.md files and resolves collisions."""
|
||||
|
||||
def __init__(self, config: DiscoveryConfig | None = None):
|
||||
self._config = config or DiscoveryConfig()
|
||||
|
||||
def discover(self) -> list[ParsedSkill]:
|
||||
"""Scan all scopes and return deduplicated skill list.
|
||||
|
||||
Scanning order (lowest to highest precedence):
|
||||
1. Framework defaults
|
||||
2. User cross-client (~/.agents/skills/)
|
||||
3. User Hive-specific (~/.hive/skills/)
|
||||
4. Project cross-client (<project>/.agents/skills/)
|
||||
5. Project Hive-specific (<project>/.hive/skills/)
|
||||
|
||||
Later entries override earlier ones on name collision.
|
||||
"""
|
||||
all_skills: list[ParsedSkill] = []
|
||||
|
||||
# Framework scope (lowest precedence)
|
||||
if not self._config.skip_framework_scope:
|
||||
framework_dir = Path(__file__).parent / "_default_skills"
|
||||
if framework_dir.is_dir():
|
||||
all_skills.extend(self._scan_scope(framework_dir, "framework"))
|
||||
|
||||
# User scope
|
||||
if not self._config.skip_user_scope:
|
||||
home = Path.home()
|
||||
|
||||
# Cross-client (lower precedence within user scope)
|
||||
user_agents = home / ".agents" / "skills"
|
||||
if user_agents.is_dir():
|
||||
all_skills.extend(self._scan_scope(user_agents, "user"))
|
||||
|
||||
# Hive-specific (higher precedence within user scope)
|
||||
user_hive = home / ".hive" / "skills"
|
||||
if user_hive.is_dir():
|
||||
all_skills.extend(self._scan_scope(user_hive, "user"))
|
||||
|
||||
# Project scope (highest precedence)
|
||||
if self._config.project_root:
|
||||
root = self._config.project_root
|
||||
|
||||
# Cross-client
|
||||
project_agents = root / ".agents" / "skills"
|
||||
if project_agents.is_dir():
|
||||
all_skills.extend(self._scan_scope(project_agents, "project"))
|
||||
|
||||
# Hive-specific
|
||||
project_hive = root / ".hive" / "skills"
|
||||
if project_hive.is_dir():
|
||||
all_skills.extend(self._scan_scope(project_hive, "project"))
|
||||
|
||||
resolved = self._resolve_collisions(all_skills)
|
||||
|
||||
logger.info(
|
||||
"Skill discovery: found %d skills (%d after dedup) across all scopes",
|
||||
len(all_skills),
|
||||
len(resolved),
|
||||
)
|
||||
return resolved
|
||||
|
||||
def _scan_scope(self, root: Path, scope: str) -> list[ParsedSkill]:
|
||||
"""Scan a single directory for skill directories containing SKILL.md."""
|
||||
skills: list[ParsedSkill] = []
|
||||
dirs_scanned = 0
|
||||
|
||||
for skill_md in self._find_skill_files(root, depth=0):
|
||||
if dirs_scanned >= self._config.max_dirs:
|
||||
logger.warning(
|
||||
"Hit max directory limit (%d) scanning %s",
|
||||
self._config.max_dirs,
|
||||
root,
|
||||
)
|
||||
break
|
||||
|
||||
parsed = parse_skill_md(skill_md, source_scope=scope)
|
||||
if parsed is not None:
|
||||
skills.append(parsed)
|
||||
dirs_scanned += 1
|
||||
|
||||
return skills
|
||||
|
||||
def _find_skill_files(self, directory: Path, depth: int) -> list[Path]:
|
||||
"""Recursively find SKILL.md files up to max_depth."""
|
||||
if depth > self._config.max_depth:
|
||||
return []
|
||||
|
||||
results: list[Path] = []
|
||||
|
||||
try:
|
||||
entries = sorted(directory.iterdir())
|
||||
except OSError:
|
||||
return []
|
||||
|
||||
for entry in entries:
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
if entry.name in _SKIP_DIRS:
|
||||
continue
|
||||
|
||||
skill_md = entry / "SKILL.md"
|
||||
if skill_md.is_file():
|
||||
results.append(skill_md)
|
||||
else:
|
||||
# Recurse into subdirectories
|
||||
results.extend(self._find_skill_files(entry, depth + 1))
|
||||
|
||||
return results
|
||||
|
||||
def _resolve_collisions(self, skills: list[ParsedSkill]) -> list[ParsedSkill]:
|
||||
"""Resolve name collisions deterministically.
|
||||
|
||||
Later entries in the list override earlier ones (because we scan
|
||||
from lowest to highest precedence). On collision, log a warning.
|
||||
"""
|
||||
seen: dict[str, ParsedSkill] = {}
|
||||
|
||||
for skill in skills:
|
||||
if skill.name in seen:
|
||||
existing = seen[skill.name]
|
||||
logger.warning(
|
||||
"Skill name collision: '%s' from %s overrides %s",
|
||||
skill.name,
|
||||
skill.location,
|
||||
existing.location,
|
||||
)
|
||||
seen[skill.name] = skill
|
||||
|
||||
return list(seen.values())
|
||||
@@ -0,0 +1,184 @@
|
||||
"""Unified skill lifecycle manager.
|
||||
|
||||
``SkillsManager`` is the single facade that owns skill discovery, loading,
|
||||
and prompt renderation. The runtime creates one at startup and downstream
|
||||
layers read the cached prompt strings.
|
||||
|
||||
Typical usage — **config-driven** (runner passes configuration)::
|
||||
|
||||
config = SkillsManagerConfig(
|
||||
skills_config=SkillsConfig.from_agent_vars(...),
|
||||
project_root=agent_path,
|
||||
)
|
||||
mgr = SkillsManager(config)
|
||||
mgr.load()
|
||||
print(mgr.protocols_prompt) # default skill protocols
|
||||
print(mgr.skills_catalog_prompt) # community skills XML
|
||||
|
||||
Typical usage — **bare** (exported agents, SDK users)::
|
||||
|
||||
mgr = SkillsManager() # default config
|
||||
mgr.load() # loads all 6 default skills, no community discovery
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from framework.skills.config import SkillsConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillsManagerConfig:
|
||||
"""Everything the runtime needs to configure skills.
|
||||
|
||||
Attributes:
|
||||
skills_config: Per-skill enable/disable and overrides.
|
||||
project_root: Agent directory for community skill discovery.
|
||||
When ``None``, community discovery is skipped.
|
||||
skip_community_discovery: Explicitly skip community scanning
|
||||
even when ``project_root`` is set.
|
||||
interactive: Whether trust gating can prompt the user interactively.
|
||||
When ``False``, untrusted project skills are silently skipped.
|
||||
"""
|
||||
|
||||
skills_config: SkillsConfig = field(default_factory=SkillsConfig)
|
||||
project_root: Path | None = None
|
||||
skip_community_discovery: bool = False
|
||||
interactive: bool = True
|
||||
|
||||
|
||||
class SkillsManager:
|
||||
"""Unified skill lifecycle: discovery → loading → prompt renderation.
|
||||
|
||||
The runtime creates one instance during init and owns it for the
|
||||
lifetime of the process. Downstream layers (``ExecutionStream``,
|
||||
``GraphExecutor``, ``NodeContext``, ``EventLoopNode``) receive the
|
||||
cached prompt strings via property accessors.
|
||||
"""
|
||||
|
||||
def __init__(self, config: SkillsManagerConfig | None = None) -> None:
|
||||
self._config = config or SkillsManagerConfig()
|
||||
self._loaded = False
|
||||
self._catalog_prompt: str = ""
|
||||
self._protocols_prompt: str = ""
|
||||
self._allowlisted_dirs: list[str] = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Factory for backwards-compat bridge
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
def from_precomputed(
|
||||
cls,
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
) -> SkillsManager:
|
||||
"""Wrap pre-rendered prompt strings (legacy callers).
|
||||
|
||||
Returns a manager that skips discovery/loading and just returns
|
||||
the provided strings. Used by the deprecation bridge in
|
||||
``AgentRuntime`` when callers pass raw prompt strings.
|
||||
"""
|
||||
mgr = cls.__new__(cls)
|
||||
mgr._config = SkillsManagerConfig()
|
||||
mgr._loaded = True # skip load()
|
||||
mgr._catalog_prompt = skills_catalog_prompt
|
||||
mgr._protocols_prompt = protocols_prompt
|
||||
mgr._allowlisted_dirs = []
|
||||
return mgr
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def load(self) -> None:
|
||||
"""Discover, load, and cache skill prompts. Idempotent."""
|
||||
if self._loaded:
|
||||
return
|
||||
self._loaded = True
|
||||
|
||||
try:
|
||||
self._do_load()
|
||||
except Exception:
|
||||
logger.warning("Skill system init failed (non-fatal)", exc_info=True)
|
||||
|
||||
def _do_load(self) -> None:
|
||||
"""Internal load — may raise; caller catches."""
|
||||
from framework.skills.catalog import SkillCatalog
|
||||
from framework.skills.defaults import DefaultSkillManager
|
||||
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
|
||||
|
||||
skills_config = self._config.skills_config
|
||||
|
||||
# 1. Community skill discovery (when project_root is available)
|
||||
catalog_prompt = ""
|
||||
if self._config.project_root is not None and not self._config.skip_community_discovery:
|
||||
from framework.skills.trust import TrustGate
|
||||
|
||||
discovery = SkillDiscovery(DiscoveryConfig(project_root=self._config.project_root))
|
||||
discovered = discovery.discover()
|
||||
|
||||
# Trust-gate project-scope skills (AS-13)
|
||||
discovered = TrustGate(interactive=self._config.interactive).filter_and_gate(
|
||||
discovered, project_dir=self._config.project_root
|
||||
)
|
||||
|
||||
catalog = SkillCatalog(discovered)
|
||||
self._allowlisted_dirs = catalog.allowlisted_dirs
|
||||
catalog_prompt = catalog.to_prompt()
|
||||
|
||||
# Pre-activated community skills
|
||||
if skills_config.skills:
|
||||
pre_activated = catalog.build_pre_activated_prompt(skills_config.skills)
|
||||
if pre_activated:
|
||||
if catalog_prompt:
|
||||
catalog_prompt = f"{catalog_prompt}\n\n{pre_activated}"
|
||||
else:
|
||||
catalog_prompt = pre_activated
|
||||
|
||||
# 2. Default skills (always loaded unless explicitly disabled)
|
||||
default_mgr = DefaultSkillManager(config=skills_config)
|
||||
default_mgr.load()
|
||||
default_mgr.log_active_skills()
|
||||
protocols_prompt = default_mgr.build_protocols_prompt()
|
||||
|
||||
# 3. Cache
|
||||
self._catalog_prompt = catalog_prompt
|
||||
self._protocols_prompt = protocols_prompt
|
||||
|
||||
if protocols_prompt:
|
||||
logger.info(
|
||||
"Skill system ready: protocols=%d chars, catalog=%d chars",
|
||||
len(protocols_prompt),
|
||||
len(catalog_prompt),
|
||||
)
|
||||
else:
|
||||
logger.warning("Skill system produced empty protocols_prompt")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prompt accessors (consumed by downstream layers)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def skills_catalog_prompt(self) -> str:
|
||||
"""Community skills XML catalog for system prompt injection."""
|
||||
return self._catalog_prompt
|
||||
|
||||
@property
|
||||
def protocols_prompt(self) -> str:
|
||||
"""Default skill operational protocols for system prompt injection."""
|
||||
return self._protocols_prompt
|
||||
|
||||
@property
|
||||
def allowlisted_dirs(self) -> list[str]:
|
||||
"""Skill base directories for Tier 3 resource access (AS-6)."""
|
||||
return self._allowlisted_dirs
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
return self._loaded
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Data models for the Hive skill system (Agent Skills standard)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class SkillScope(StrEnum):
|
||||
"""Where a skill was discovered."""
|
||||
|
||||
PROJECT = "project"
|
||||
USER = "user"
|
||||
FRAMEWORK = "framework"
|
||||
|
||||
|
||||
class TrustStatus(StrEnum):
|
||||
"""Trust state of a skill entry."""
|
||||
|
||||
TRUSTED = "trusted"
|
||||
PENDING_CONSENT = "pending_consent"
|
||||
DENIED = "denied"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillEntry:
|
||||
"""In-memory record for a discovered skill (PRD §4.2)."""
|
||||
|
||||
name: str
|
||||
"""Skill name from SKILL.md frontmatter."""
|
||||
|
||||
description: str
|
||||
"""Skill description from SKILL.md frontmatter."""
|
||||
|
||||
location: Path
|
||||
"""Absolute path to SKILL.md."""
|
||||
|
||||
base_dir: Path
|
||||
"""Parent directory of SKILL.md (skill root)."""
|
||||
|
||||
source_scope: SkillScope
|
||||
"""Which scope this skill was found in."""
|
||||
|
||||
trust_status: TrustStatus = TrustStatus.TRUSTED
|
||||
"""Trust state; project-scope skills start as PENDING_CONSENT before gating."""
|
||||
|
||||
# Optional frontmatter fields
|
||||
license: str | None = None
|
||||
compatibility: list[str] = field(default_factory=list)
|
||||
allowed_tools: list[str] = field(default_factory=list)
|
||||
metadata: dict = field(default_factory=dict)
|
||||
@@ -0,0 +1,158 @@
|
||||
"""SKILL.md parser — extracts YAML frontmatter and markdown body.
|
||||
|
||||
Parses SKILL.md files per the Agent Skills standard (agentskills.io/specification).
|
||||
Lenient validation: warns on non-critical issues, skips only on missing description
|
||||
or completely unparseable YAML.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum name length before a warning is logged
|
||||
_MAX_NAME_LENGTH = 64
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedSkill:
|
||||
"""In-memory representation of a parsed SKILL.md file."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
location: str # absolute path to SKILL.md
|
||||
base_dir: str # parent directory of SKILL.md
|
||||
source_scope: str # "project", "user", or "framework"
|
||||
body: str # markdown body after closing ---
|
||||
|
||||
# Optional frontmatter fields
|
||||
license: str | None = None
|
||||
compatibility: list[str] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
allowed_tools: list[str] | None = None
|
||||
|
||||
|
||||
def _try_fix_yaml(raw: str) -> str:
|
||||
"""Attempt to fix common YAML issues (unquoted colon values).
|
||||
|
||||
Some SKILL.md files written for other clients may contain unquoted
|
||||
values with colons, e.g. ``description: Use for: research tasks``.
|
||||
This wraps such values in quotes as a best-effort fixup.
|
||||
"""
|
||||
lines = raw.split("\n")
|
||||
fixed = []
|
||||
for line in lines:
|
||||
# Match "key: value" where value contains an unquoted colon
|
||||
m = re.match(r"^(\s*\w[\w-]*:\s*)(.+)$", line)
|
||||
if m:
|
||||
key_part, value_part = m.group(1), m.group(2)
|
||||
# If value contains a colon and isn't already quoted
|
||||
if ":" in value_part and not (value_part.startswith('"') or value_part.startswith("'")):
|
||||
value_part = f'"{value_part}"'
|
||||
fixed.append(f"{key_part}{value_part}")
|
||||
else:
|
||||
fixed.append(line)
|
||||
return "\n".join(fixed)
|
||||
|
||||
|
||||
def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | None:
|
||||
"""Parse a SKILL.md file into a ParsedSkill record.
|
||||
|
||||
Args:
|
||||
path: Absolute path to the SKILL.md file.
|
||||
source_scope: One of "project", "user", or "framework".
|
||||
|
||||
Returns:
|
||||
ParsedSkill on success, None if the file is unparseable or
|
||||
missing required fields (description).
|
||||
"""
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
logger.error("Failed to read %s: %s", path, exc)
|
||||
return None
|
||||
|
||||
if not content.strip():
|
||||
logger.error("Empty SKILL.md: %s", path)
|
||||
return None
|
||||
|
||||
# Split on --- delimiters (first two occurrences)
|
||||
parts = content.split("---", 2)
|
||||
if len(parts) < 3:
|
||||
logger.error("SKILL.md missing YAML frontmatter delimiters (---): %s", path)
|
||||
return None
|
||||
|
||||
# parts[0] is content before first --- (should be empty or whitespace)
|
||||
# parts[1] is the YAML frontmatter
|
||||
# parts[2] is the markdown body
|
||||
raw_yaml = parts[1].strip()
|
||||
body = parts[2].strip()
|
||||
|
||||
if not raw_yaml:
|
||||
logger.error("Empty YAML frontmatter in %s", path)
|
||||
return None
|
||||
|
||||
# Parse YAML
|
||||
import yaml
|
||||
|
||||
frontmatter: dict[str, Any] | None = None
|
||||
try:
|
||||
frontmatter = yaml.safe_load(raw_yaml)
|
||||
except yaml.YAMLError:
|
||||
# Fallback: try fixing unquoted colon values
|
||||
try:
|
||||
fixed = _try_fix_yaml(raw_yaml)
|
||||
frontmatter = yaml.safe_load(fixed)
|
||||
logger.warning("Fixed YAML parse issues in %s (unquoted colons)", path)
|
||||
except yaml.YAMLError as exc:
|
||||
logger.error("Unparseable YAML in %s: %s", path, exc)
|
||||
return None
|
||||
|
||||
if not isinstance(frontmatter, dict):
|
||||
logger.error("YAML frontmatter is not a mapping in %s", path)
|
||||
return None
|
||||
|
||||
# Required: description
|
||||
description = frontmatter.get("description")
|
||||
if not description or not str(description).strip():
|
||||
logger.error("Missing or empty 'description' in %s — skipping skill", path)
|
||||
return None
|
||||
|
||||
# Required: name (fallback to parent directory name)
|
||||
name = frontmatter.get("name")
|
||||
parent_dir_name = path.parent.name
|
||||
if not name or not str(name).strip():
|
||||
name = parent_dir_name
|
||||
logger.warning("Missing 'name' in %s — using directory name '%s'", path, name)
|
||||
else:
|
||||
name = str(name).strip()
|
||||
|
||||
# Lenient warnings
|
||||
if len(name) > _MAX_NAME_LENGTH:
|
||||
logger.warning("Skill name exceeds %d chars in %s: '%s'", _MAX_NAME_LENGTH, path, name)
|
||||
|
||||
if name != parent_dir_name and not name.endswith(f".{parent_dir_name}"):
|
||||
logger.warning(
|
||||
"Skill name '%s' doesn't match parent directory '%s' in %s",
|
||||
name,
|
||||
parent_dir_name,
|
||||
path,
|
||||
)
|
||||
|
||||
return ParsedSkill(
|
||||
name=name,
|
||||
description=str(description).strip(),
|
||||
location=str(path.resolve()),
|
||||
base_dir=str(path.parent.resolve()),
|
||||
source_scope=source_scope,
|
||||
body=body,
|
||||
license=frontmatter.get("license"),
|
||||
compatibility=frontmatter.get("compatibility"),
|
||||
metadata=frontmatter.get("metadata"),
|
||||
allowed_tools=frontmatter.get("allowed-tools"),
|
||||
)
|
||||
@@ -0,0 +1,477 @@
|
||||
"""Trust gating for project-level skills (PRD AS-13).
|
||||
|
||||
Project-level skills from untrusted repositories require explicit user consent
|
||||
before their instructions are loaded into the agent's system prompt.
|
||||
Framework and user-scope skills are always trusted.
|
||||
|
||||
Trusted repos are persisted at ~/.hive/trusted_repos.json.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from framework.skills.parser import ParsedSkill
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Env var to bypass trust gating in CI/headless pipelines (opt-in).
|
||||
_ENV_TRUST_ALL = "HIVE_TRUST_PROJECT_SKILLS"
|
||||
|
||||
# Env var for comma-separated own-remote glob patterns (e.g. "github.com/myorg/*").
|
||||
_ENV_OWN_REMOTES = "HIVE_OWN_REMOTES"
|
||||
|
||||
_TRUSTED_REPOS_PATH = Path.home() / ".hive" / "trusted_repos.json"
|
||||
_NOTICE_SENTINEL_PATH = Path.home() / ".hive" / ".skill_trust_notice_shown"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trusted repo store
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrustedRepoEntry:
|
||||
repo_key: str
|
||||
added_at: datetime
|
||||
project_path: str = ""
|
||||
|
||||
|
||||
class TrustedRepoStore:
|
||||
"""Persists permanently-trusted repo keys to ~/.hive/trusted_repos.json."""
|
||||
|
||||
def __init__(self, path: Path | None = None) -> None:
|
||||
self._path = path or _TRUSTED_REPOS_PATH
|
||||
self._entries: dict[str, TrustedRepoEntry] = {}
|
||||
self._loaded = False
|
||||
|
||||
def is_trusted(self, repo_key: str) -> bool:
|
||||
self._ensure_loaded()
|
||||
return repo_key in self._entries
|
||||
|
||||
def trust(self, repo_key: str, project_path: str = "") -> None:
|
||||
self._ensure_loaded()
|
||||
self._entries[repo_key] = TrustedRepoEntry(
|
||||
repo_key=repo_key,
|
||||
added_at=datetime.now(tz=UTC),
|
||||
project_path=project_path,
|
||||
)
|
||||
self._save()
|
||||
logger.info("skill_trust_store: trusted repo_key=%s", repo_key)
|
||||
|
||||
def revoke(self, repo_key: str) -> bool:
|
||||
self._ensure_loaded()
|
||||
if repo_key in self._entries:
|
||||
del self._entries[repo_key]
|
||||
self._save()
|
||||
logger.info("skill_trust_store: revoked repo_key=%s", repo_key)
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_entries(self) -> list[TrustedRepoEntry]:
|
||||
self._ensure_loaded()
|
||||
return list(self._entries.values())
|
||||
|
||||
def _ensure_loaded(self) -> None:
|
||||
if not self._loaded:
|
||||
self._load()
|
||||
self._loaded = True
|
||||
|
||||
def _load(self) -> None:
|
||||
try:
|
||||
data = json.loads(self._path.read_text(encoding="utf-8"))
|
||||
for raw in data.get("entries", []):
|
||||
repo_key = raw.get("repo_key", "")
|
||||
if not repo_key:
|
||||
continue
|
||||
try:
|
||||
added_at = datetime.fromisoformat(raw["added_at"])
|
||||
except (KeyError, ValueError):
|
||||
added_at = datetime.now(tz=UTC)
|
||||
self._entries[repo_key] = TrustedRepoEntry(
|
||||
repo_key=repo_key,
|
||||
added_at=added_at,
|
||||
project_path=raw.get("project_path", ""),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"skill_trust_store: could not read %s (%s); treating as empty",
|
||||
self._path,
|
||||
e,
|
||||
)
|
||||
|
||||
def _save(self) -> None:
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
data = {
|
||||
"version": 1,
|
||||
"entries": [
|
||||
{
|
||||
"repo_key": e.repo_key,
|
||||
"added_at": e.added_at.isoformat(),
|
||||
"project_path": e.project_path,
|
||||
}
|
||||
for e in self._entries.values()
|
||||
],
|
||||
}
|
||||
# Atomic write: write to .tmp then rename
|
||||
tmp = self._path.with_suffix(".tmp")
|
||||
tmp.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
||||
tmp.replace(self._path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trust classification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ProjectTrustClassification(StrEnum):
|
||||
ALWAYS_TRUSTED = "always_trusted"
|
||||
TRUSTED_BY_USER = "trusted_by_user"
|
||||
UNTRUSTED = "untrusted"
|
||||
|
||||
|
||||
class ProjectTrustDetector:
|
||||
"""Classifies a project directory as trusted or untrusted.
|
||||
|
||||
Algorithm (PRD §4.1 trust note):
|
||||
1. No project_dir → ALWAYS_TRUSTED
|
||||
2. No .git directory → ALWAYS_TRUSTED (not a git repo)
|
||||
3. No remote 'origin' → ALWAYS_TRUSTED (local-only repo)
|
||||
4. Remote URL → repo_key; in TrustedRepoStore → TRUSTED_BY_USER
|
||||
5. Localhost remote → ALWAYS_TRUSTED
|
||||
6. ~/.hive/own_remotes match → ALWAYS_TRUSTED
|
||||
7. HIVE_OWN_REMOTES env match → ALWAYS_TRUSTED
|
||||
8. None of the above → UNTRUSTED
|
||||
"""
|
||||
|
||||
def __init__(self, store: TrustedRepoStore | None = None) -> None:
|
||||
self._store = store or TrustedRepoStore()
|
||||
|
||||
def classify(self, project_dir: Path | None) -> tuple[ProjectTrustClassification, str]:
|
||||
"""Return (classification, repo_key).
|
||||
|
||||
repo_key is empty string for ALWAYS_TRUSTED cases without a remote.
|
||||
"""
|
||||
if project_dir is None or not project_dir.exists():
|
||||
return ProjectTrustClassification.ALWAYS_TRUSTED, ""
|
||||
|
||||
if not (project_dir / ".git").exists():
|
||||
return ProjectTrustClassification.ALWAYS_TRUSTED, ""
|
||||
|
||||
remote_url = self._get_remote_origin(project_dir)
|
||||
if not remote_url:
|
||||
return ProjectTrustClassification.ALWAYS_TRUSTED, ""
|
||||
|
||||
repo_key = _normalize_remote_url(remote_url)
|
||||
|
||||
# Explicitly trusted by user
|
||||
if self._store.is_trusted(repo_key):
|
||||
return ProjectTrustClassification.TRUSTED_BY_USER, repo_key
|
||||
|
||||
# Localhost remotes are always trusted
|
||||
if _is_localhost_remote(remote_url):
|
||||
return ProjectTrustClassification.ALWAYS_TRUSTED, repo_key
|
||||
|
||||
# User-configured own-remote patterns
|
||||
if self._matches_own_remotes(repo_key):
|
||||
return ProjectTrustClassification.ALWAYS_TRUSTED, repo_key
|
||||
|
||||
return ProjectTrustClassification.UNTRUSTED, repo_key
|
||||
|
||||
def _get_remote_origin(self, project_dir: Path) -> str:
|
||||
"""Run git remote get-url origin. Returns empty string on any failure."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "-C", str(project_dir), "remote", "get-url", "origin"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout.strip()
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning(
|
||||
"skill_trust: git remote lookup timed out for %s; treating as trusted",
|
||||
project_dir,
|
||||
)
|
||||
except (FileNotFoundError, OSError):
|
||||
pass # git not found or other OS error
|
||||
return ""
|
||||
|
||||
def _matches_own_remotes(self, repo_key: str) -> bool:
|
||||
"""Check repo_key against user-configured own-remote glob patterns."""
|
||||
import fnmatch
|
||||
|
||||
patterns: list[str] = []
|
||||
|
||||
# From env var
|
||||
env_patterns = _ENV_OWN_REMOTES
|
||||
import os
|
||||
|
||||
raw = os.environ.get(env_patterns, "")
|
||||
if raw:
|
||||
patterns.extend(p.strip() for p in raw.split(",") if p.strip())
|
||||
|
||||
# From ~/.hive/own_remotes file
|
||||
own_remotes_file = Path.home() / ".hive" / "own_remotes"
|
||||
if own_remotes_file.is_file():
|
||||
try:
|
||||
for line in own_remotes_file.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#"):
|
||||
patterns.append(line)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return any(fnmatch.fnmatch(repo_key, p) for p in patterns)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# URL helpers (public so CLI can reuse)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _normalize_remote_url(url: str) -> str:
|
||||
"""Normalize a git remote URL to a canonical ``host/org/repo`` key.
|
||||
|
||||
Examples:
|
||||
git@github.com:org/repo.git → github.com/org/repo
|
||||
https://github.com/org/repo → github.com/org/repo
|
||||
ssh://git@github.com/org/repo.git → github.com/org/repo
|
||||
"""
|
||||
url = url.strip()
|
||||
|
||||
# SCP-style SSH: git@github.com:org/repo.git
|
||||
if url.startswith("git@") and ":" in url and "://" not in url:
|
||||
url = url[4:] # strip git@
|
||||
url = url.replace(":", "/", 1)
|
||||
elif "://" in url:
|
||||
parsed = urlparse(url)
|
||||
host = parsed.hostname or ""
|
||||
path = parsed.path.lstrip("/")
|
||||
url = f"{host}/{path}"
|
||||
|
||||
# Strip .git suffix
|
||||
if url.endswith(".git"):
|
||||
url = url[:-4]
|
||||
|
||||
return url.lower().strip("/")
|
||||
|
||||
|
||||
def _is_localhost_remote(remote_url: str) -> bool:
|
||||
"""Return True if the remote points to a local host."""
|
||||
local_hosts = {"localhost", "127.0.0.1", "::1"}
|
||||
try:
|
||||
if "://" in remote_url:
|
||||
parsed = urlparse(remote_url)
|
||||
return (parsed.hostname or "").lower() in local_hosts
|
||||
# SCP-style: git@localhost:org/repo
|
||||
if "@" in remote_url:
|
||||
host_part = remote_url.split("@", 1)[1].split(":")[0]
|
||||
return host_part.lower() in local_hosts
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trust gate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TrustGate:
|
||||
"""Filters skill list, running consent flow for untrusted project-scope skills.
|
||||
|
||||
Framework and user-scope skills are always allowed through.
|
||||
Project-scope skills from untrusted repos require consent.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
store: TrustedRepoStore | None = None,
|
||||
detector: ProjectTrustDetector | None = None,
|
||||
interactive: bool = True,
|
||||
print_fn: Callable[[str], None] | None = None,
|
||||
input_fn: Callable[[str], str] | None = None,
|
||||
) -> None:
|
||||
self._store = store or TrustedRepoStore()
|
||||
self._detector = detector or ProjectTrustDetector(self._store)
|
||||
self._interactive = interactive
|
||||
self._print = print_fn or print
|
||||
self._input = input_fn or input
|
||||
|
||||
def filter_and_gate(
|
||||
self,
|
||||
skills: list[ParsedSkill],
|
||||
project_dir: Path | None,
|
||||
) -> list[ParsedSkill]:
|
||||
"""Return the subset of skills that are trusted for loading.
|
||||
|
||||
- Framework and user-scope skills: always included.
|
||||
- Project-scope skills: classified; consent prompt shown if untrusted.
|
||||
"""
|
||||
import os
|
||||
|
||||
# Separate project skills from always-trusted scopes
|
||||
always_trusted = [s for s in skills if s.source_scope != "project"]
|
||||
project_skills = [s for s in skills if s.source_scope == "project"]
|
||||
|
||||
if not project_skills:
|
||||
return always_trusted
|
||||
|
||||
# Env-var CI override: trust all project skills for this invocation
|
||||
if os.environ.get(_ENV_TRUST_ALL, "").strip() == "1":
|
||||
logger.info(
|
||||
"skill_trust: %s=1 set; trusting %d project skill(s) without consent",
|
||||
_ENV_TRUST_ALL,
|
||||
len(project_skills),
|
||||
)
|
||||
return always_trusted + project_skills
|
||||
|
||||
classification, repo_key = self._detector.classify(project_dir)
|
||||
|
||||
if classification in (
|
||||
ProjectTrustClassification.ALWAYS_TRUSTED,
|
||||
ProjectTrustClassification.TRUSTED_BY_USER,
|
||||
):
|
||||
logger.info(
|
||||
"skill_trust: project skills trusted classification=%s repo=%s count=%d",
|
||||
classification,
|
||||
repo_key or "(no remote)",
|
||||
len(project_skills),
|
||||
)
|
||||
return always_trusted + project_skills
|
||||
|
||||
# UNTRUSTED — need consent
|
||||
if not self._interactive or not sys.stdin.isatty():
|
||||
logger.warning(
|
||||
"skill_trust: skipping %d project-scope skill(s) from untrusted repo "
|
||||
"'%s' (non-interactive mode). "
|
||||
"To trust permanently run: hive skill trust %s",
|
||||
len(project_skills),
|
||||
repo_key,
|
||||
project_dir or ".",
|
||||
)
|
||||
logger.info(
|
||||
"skill_trust_decision repo=%s skills=%d decision=denied mode=headless",
|
||||
repo_key,
|
||||
len(project_skills),
|
||||
)
|
||||
return always_trusted
|
||||
|
||||
# Interactive consent flow
|
||||
decision = self._run_consent_flow(project_skills, project_dir, repo_key)
|
||||
|
||||
logger.info(
|
||||
"skill_trust_decision repo=%s skills=%d decision=%s mode=interactive",
|
||||
repo_key,
|
||||
len(project_skills),
|
||||
decision,
|
||||
)
|
||||
|
||||
if decision == "session":
|
||||
return always_trusted + project_skills
|
||||
|
||||
if decision == "permanent":
|
||||
self._store.trust(repo_key, project_path=str(project_dir or ""))
|
||||
return always_trusted + project_skills
|
||||
|
||||
# denied
|
||||
return always_trusted
|
||||
|
||||
def _run_consent_flow(
|
||||
self,
|
||||
project_skills: list[ParsedSkill],
|
||||
project_dir: Path | None,
|
||||
repo_key: str,
|
||||
) -> str:
|
||||
"""Show the security notice (once) and consent prompt.
|
||||
Return 'session' | 'permanent' | 'denied'."""
|
||||
from framework.credentials.setup import Colors
|
||||
|
||||
if not sys.stdout.isatty():
|
||||
Colors.disable()
|
||||
|
||||
self._maybe_show_security_notice(Colors)
|
||||
self._print_consent_prompt(project_skills, project_dir, repo_key, Colors)
|
||||
return self._prompt_consent(Colors)
|
||||
|
||||
def _maybe_show_security_notice(self, Colors) -> None: # noqa: N803
|
||||
"""Show the one-time security notice if not already shown (NFR-5)."""
|
||||
if _NOTICE_SENTINEL_PATH.exists():
|
||||
return
|
||||
self._print("")
|
||||
self._print(
|
||||
f"{Colors.YELLOW}Security notice:{Colors.NC} Skills inject instructions "
|
||||
"into the agent's system prompt."
|
||||
)
|
||||
self._print(
|
||||
" Only load skills from sources you trust. "
|
||||
"Registry skills at tier 'verified' or 'official' have been audited."
|
||||
)
|
||||
self._print("")
|
||||
try:
|
||||
_NOTICE_SENTINEL_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
_NOTICE_SENTINEL_PATH.touch()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _print_consent_prompt(
|
||||
self,
|
||||
project_skills: list[ParsedSkill],
|
||||
project_dir: Path | None,
|
||||
repo_key: str,
|
||||
Colors, # noqa: N803
|
||||
) -> None:
|
||||
p = self._print
|
||||
p("")
|
||||
p(f"{Colors.YELLOW}{'=' * 60}{Colors.NC}")
|
||||
p(f"{Colors.BOLD} SKILL TRUST REQUIRED{Colors.NC}")
|
||||
p(f"{Colors.YELLOW}{'=' * 60}{Colors.NC}")
|
||||
p("")
|
||||
proj_label = str(project_dir) if project_dir else "this project"
|
||||
p(
|
||||
f" The project at {Colors.CYAN}{proj_label}{Colors.NC} wants to load "
|
||||
f"{len(project_skills)} skill(s)"
|
||||
)
|
||||
p(" that will inject instructions into the agent's system prompt.")
|
||||
if repo_key:
|
||||
p(f" Source: {Colors.BOLD}{repo_key}{Colors.NC}")
|
||||
p("")
|
||||
p(" Skills requesting access:")
|
||||
for skill in project_skills:
|
||||
p(f" {Colors.CYAN}•{Colors.NC} {Colors.BOLD}{skill.name}{Colors.NC}")
|
||||
p(f' "{skill.description}"')
|
||||
p(f" {Colors.DIM}{skill.location}{Colors.NC}")
|
||||
p("")
|
||||
p(" Options:")
|
||||
p(f" {Colors.CYAN}1){Colors.NC} Trust this session only")
|
||||
p(f" {Colors.CYAN}2){Colors.NC} Trust permanently — remember for future runs")
|
||||
p(
|
||||
f" {Colors.DIM}3) Deny"
|
||||
f" — skip all project-scope skills from this repo{Colors.NC}"
|
||||
)
|
||||
p(f"{Colors.YELLOW}{'─' * 60}{Colors.NC}")
|
||||
|
||||
def _prompt_consent(self, Colors) -> str: # noqa: N803
|
||||
"""Prompt until a valid choice is entered. Returns 'session'|'permanent'|'denied'."""
|
||||
mapping = {"1": "session", "2": "permanent", "3": "denied"}
|
||||
while True:
|
||||
try:
|
||||
choice = self._input("Select option (1-3): ").strip()
|
||||
if choice in mapping:
|
||||
return mapping[choice]
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return "denied"
|
||||
self._print(f"{Colors.RED}Invalid choice. Enter 1, 2, or 3.{Colors.NC}")
|
||||
@@ -40,18 +40,31 @@ class LLMJudge:
|
||||
|
||||
def _get_fallback_provider(self) -> LLMProvider | None:
|
||||
"""
|
||||
Auto-detects available API keys and returns the appropriate provider.
|
||||
Priority: OpenAI -> Anthropic.
|
||||
Auto-detects available API keys and returns an appropriate provider.
|
||||
Uses LiteLLM for OpenAI (framework has no framework.llm.openai module).
|
||||
Priority:
|
||||
1. OpenAI-compatible models via LiteLLM (OPENAI_API_KEY)
|
||||
2. Anthropic via AnthropicProvider (ANTHROPIC_API_KEY)
|
||||
"""
|
||||
# OpenAI: use LiteLLM (the framework's standard multi-provider integration)
|
||||
if os.environ.get("OPENAI_API_KEY"):
|
||||
from framework.llm.openai import OpenAIProvider
|
||||
try:
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
|
||||
return OpenAIProvider(model="gpt-4o-mini")
|
||||
return LiteLLMProvider(model="gpt-4o-mini")
|
||||
except ImportError:
|
||||
# LiteLLM is optional; fall through to Anthropic/None
|
||||
pass
|
||||
|
||||
# Anthropic via dedicated provider (wraps LiteLLM internally)
|
||||
if os.environ.get("ANTHROPIC_API_KEY"):
|
||||
from framework.llm.anthropic import AnthropicProvider
|
||||
try:
|
||||
from framework.llm.anthropic import AnthropicProvider
|
||||
|
||||
return AnthropicProvider(model="claude-3-haiku-20240307")
|
||||
return AnthropicProvider(model="claude-haiku-4-5-20251001")
|
||||
except Exception:
|
||||
# If AnthropicProvider cannot be constructed, treat as no fallback
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
@@ -77,11 +90,16 @@ SUMMARY TO EVALUATE:
|
||||
Respond with JSON: {{"passes": true/false, "explanation": "..."}}"""
|
||||
|
||||
try:
|
||||
# Compute fallback provider once so we do not create multiple instances
|
||||
fallback_provider = self._get_fallback_provider()
|
||||
|
||||
# 1. Use injected provider
|
||||
if self._provider:
|
||||
active_provider = self._provider
|
||||
# 2. Check if _get_client was MOCKED (legacy tests) or use Agnostic Fallback
|
||||
elif hasattr(self._get_client, "return_value") or not self._get_fallback_provider():
|
||||
# 2. Legacy path: anthropic client mocked in tests takes precedence,
|
||||
# or no fallback provider is available.
|
||||
elif hasattr(self._get_client, "return_value") or fallback_provider is None:
|
||||
# Use legacy Anthropic client (e.g. when tests mock _get_client, or no env keys set)
|
||||
client = self._get_client()
|
||||
response = client.messages.create(
|
||||
model="claude-haiku-4-5-20251001",
|
||||
@@ -90,7 +108,8 @@ Respond with JSON: {{"passes": true/false, "explanation": "..."}}"""
|
||||
)
|
||||
return self._parse_json_result(response.content[0].text.strip())
|
||||
else:
|
||||
active_provider = self._get_fallback_provider()
|
||||
# Use env-based fallback (LiteLLM or AnthropicProvider)
|
||||
active_provider = fallback_provider
|
||||
|
||||
response = active_provider.complete(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
|
||||
@@ -0,0 +1,374 @@
|
||||
"""Flowchart utilities for generating and persisting flowchart.json files.
|
||||
|
||||
Extracted from queen_lifecycle_tools so that non-Queen code paths
|
||||
(e.g., AgentRunner.load) can generate flowcharts for legacy agents
|
||||
that lack a flowchart.json.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FLOWCHART_FILENAME = "flowchart.json"
|
||||
|
||||
# ── Flowchart type catalogue (9 types) ───────────────────────────────────────
|
||||
FLOWCHART_TYPES = {
|
||||
"start": {"shape": "stadium", "color": "#8aad3f"}, # spring pollen
|
||||
"terminal": {"shape": "stadium", "color": "#b5453a"}, # propolis red
|
||||
"process": {"shape": "rectangle", "color": "#b5a575"}, # warm wheat
|
||||
"decision": {"shape": "diamond", "color": "#d89d26"}, # royal honey
|
||||
"io": {"shape": "parallelogram", "color": "#d06818"}, # burnt orange
|
||||
"document": {"shape": "document", "color": "#c4b830"}, # goldenrod
|
||||
"database": {"shape": "cylinder", "color": "#508878"}, # sage teal
|
||||
"subprocess": {"shape": "subroutine", "color": "#887a48"}, # propolis gold
|
||||
"browser": {"shape": "hexagon", "color": "#cc8850"}, # honey copper
|
||||
}
|
||||
|
||||
# Backward-compat remap: old type names → canonical type
|
||||
FLOWCHART_REMAP: dict[str, str] = {
|
||||
"delay": "process",
|
||||
"manual_operation": "process",
|
||||
"preparation": "process",
|
||||
"merge": "process",
|
||||
"alternate_process": "process",
|
||||
"connector": "process",
|
||||
"offpage_connector": "process",
|
||||
"extract": "process",
|
||||
"sort": "process",
|
||||
"collate": "process",
|
||||
"summing_junction": "process",
|
||||
"or": "process",
|
||||
"comment": "process",
|
||||
"display": "io",
|
||||
"manual_input": "io",
|
||||
"multi_document": "document",
|
||||
"stored_data": "database",
|
||||
"internal_storage": "database",
|
||||
}
|
||||
|
||||
|
||||
# ── File persistence ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def save_flowchart_file(
|
||||
agent_path: Path | str | None,
|
||||
original_draft: dict,
|
||||
flowchart_map: dict[str, list[str]] | None,
|
||||
) -> None:
|
||||
"""Persist the flowchart to the agent's folder."""
|
||||
if agent_path is None:
|
||||
return
|
||||
p = Path(agent_path)
|
||||
if not p.is_dir():
|
||||
return
|
||||
try:
|
||||
target = p / FLOWCHART_FILENAME
|
||||
target.write_text(
|
||||
json.dumps(
|
||||
{"original_draft": original_draft, "flowchart_map": flowchart_map},
|
||||
indent=2,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
logger.debug("Flowchart saved to %s", target)
|
||||
except Exception:
|
||||
logger.warning("Failed to save flowchart to %s", p, exc_info=True)
|
||||
|
||||
|
||||
def load_flowchart_file(
|
||||
agent_path: Path | str | None,
|
||||
) -> tuple[dict | None, dict[str, list[str]] | None]:
|
||||
"""Load flowchart from the agent's folder. Returns (original_draft, flowchart_map)."""
|
||||
if agent_path is None:
|
||||
return None, None
|
||||
target = Path(agent_path) / FLOWCHART_FILENAME
|
||||
if not target.is_file():
|
||||
return None, None
|
||||
try:
|
||||
data = json.loads(target.read_text(encoding="utf-8"))
|
||||
return data.get("original_draft"), data.get("flowchart_map")
|
||||
except Exception:
|
||||
logger.warning("Failed to load flowchart from %s", target, exc_info=True)
|
||||
return None, None
|
||||
|
||||
|
||||
# ── Node classification ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def classify_flowchart_node(
|
||||
node: dict,
|
||||
index: int,
|
||||
total: int,
|
||||
edges: list[dict],
|
||||
terminal_ids: set[str],
|
||||
) -> str:
|
||||
"""Auto-detect the ISO 5807 flowchart type for a draft node.
|
||||
|
||||
Priority: explicit override > structural detection > heuristic > default.
|
||||
"""
|
||||
# Explicit override from the queen
|
||||
explicit = node.get("flowchart_type", "").strip()
|
||||
if explicit and explicit in FLOWCHART_TYPES:
|
||||
return explicit
|
||||
if explicit and explicit in FLOWCHART_REMAP:
|
||||
return FLOWCHART_REMAP[explicit]
|
||||
|
||||
node_id = node["id"]
|
||||
node_type = node.get("node_type", "event_loop")
|
||||
node_tools = set(node.get("tools") or [])
|
||||
desc = (node.get("description") or "").lower()
|
||||
|
||||
# GCU / browser automation nodes → hexagon
|
||||
if node_type == "gcu":
|
||||
return "browser"
|
||||
|
||||
# Entry node (first node or no incoming edges) → start terminator
|
||||
incoming = {e["target"] for e in edges}
|
||||
if index == 0 or (node_id not in incoming and index == 0):
|
||||
return "start"
|
||||
|
||||
# Terminal node → end terminator
|
||||
if node_id in terminal_ids:
|
||||
return "terminal"
|
||||
|
||||
# Decision node: has outgoing edges with branching conditions → diamond
|
||||
outgoing = [e for e in edges if e["source"] == node_id]
|
||||
if len(outgoing) >= 2:
|
||||
conditions = {e.get("condition", "on_success") for e in outgoing}
|
||||
if len(conditions) > 1 or conditions - {"on_success"}:
|
||||
return "decision"
|
||||
|
||||
# Sub-agent / subprocess nodes → subroutine (double-bordered rect)
|
||||
if node.get("sub_agents"):
|
||||
return "subprocess"
|
||||
|
||||
# Database / data store nodes → cylinder
|
||||
db_tool_hints = {
|
||||
"query_database",
|
||||
"sql_query",
|
||||
"read_table",
|
||||
"write_table",
|
||||
"save_data",
|
||||
"load_data",
|
||||
}
|
||||
db_desc_hints = {"database", "data store", "storage", "persist", "cache"}
|
||||
if node_tools & db_tool_hints or any(h in desc for h in db_desc_hints):
|
||||
return "database"
|
||||
|
||||
# Document generation nodes → document shape
|
||||
doc_tool_hints = {
|
||||
"generate_report",
|
||||
"create_document",
|
||||
"write_report",
|
||||
"render_template",
|
||||
"export_pdf",
|
||||
}
|
||||
doc_desc_hints = {"report", "document", "summary", "write up", "writeup"}
|
||||
if node_tools & doc_tool_hints or any(h in desc for h in doc_desc_hints):
|
||||
return "document"
|
||||
|
||||
# I/O nodes: external data ingestion or delivery → parallelogram
|
||||
io_tool_hints = {
|
||||
"serve_file_to_user",
|
||||
"send_email",
|
||||
"post_message",
|
||||
"upload_file",
|
||||
"download_file",
|
||||
"fetch_url",
|
||||
"post_to_slack",
|
||||
"send_notification",
|
||||
"display_results",
|
||||
}
|
||||
io_desc_hints = {"deliver", "send", "output", "notify", "publish"}
|
||||
if node_tools & io_tool_hints or any(h in desc for h in io_desc_hints):
|
||||
return "io"
|
||||
|
||||
# Default: process (rectangle)
|
||||
return "process"
|
||||
|
||||
|
||||
# ── Draft synthesis from runtime graph ───────────────────────────────────────
|
||||
|
||||
|
||||
def synthesize_draft_from_runtime(
|
||||
runtime_nodes: list,
|
||||
runtime_edges: list,
|
||||
agent_name: str = "",
|
||||
goal_name: str = "",
|
||||
) -> tuple[dict, dict[str, list[str]]]:
|
||||
"""Generate a flowchart draft from a loaded runtime graph.
|
||||
|
||||
Used for agents that were never planned through the draft workflow
|
||||
(e.g., hand-coded or loaded from "my agents"). Produces a valid
|
||||
DraftGraph structure with auto-classified flowchart types.
|
||||
"""
|
||||
nodes: list[dict] = []
|
||||
edges: list[dict] = []
|
||||
node_ids = {n.id for n in runtime_nodes}
|
||||
|
||||
# Build edge dicts first (needed for classification)
|
||||
for i, re in enumerate(runtime_edges):
|
||||
edges.append(
|
||||
{
|
||||
"id": f"edge-{i}",
|
||||
"source": re.source,
|
||||
"target": re.target,
|
||||
"condition": str(re.condition.value)
|
||||
if hasattr(re.condition, "value")
|
||||
else str(re.condition),
|
||||
"description": getattr(re, "description", "") or "",
|
||||
"label": "",
|
||||
}
|
||||
)
|
||||
|
||||
# Terminal detection — exclude sub-agent nodes (they are leaf helpers, not endpoints)
|
||||
sub_agent_ids: set[str] = set()
|
||||
for rn in runtime_nodes:
|
||||
for sa_id in getattr(rn, "sub_agents", None) or []:
|
||||
sub_agent_ids.add(sa_id)
|
||||
sources = {e["source"] for e in edges}
|
||||
terminal_ids = node_ids - sources - sub_agent_ids
|
||||
if not terminal_ids and runtime_nodes:
|
||||
terminal_ids = {runtime_nodes[-1].id}
|
||||
|
||||
# Build node dicts with classification
|
||||
total = len(runtime_nodes)
|
||||
for i, rn in enumerate(runtime_nodes):
|
||||
node: dict = {
|
||||
"id": rn.id,
|
||||
"name": rn.name,
|
||||
"description": rn.description or "",
|
||||
"node_type": getattr(rn, "node_type", "event_loop") or "event_loop",
|
||||
"tools": list(rn.tools) if rn.tools else [],
|
||||
"input_keys": list(rn.input_keys) if rn.input_keys else [],
|
||||
"output_keys": list(rn.output_keys) if rn.output_keys else [],
|
||||
"success_criteria": getattr(rn, "success_criteria", "") or "",
|
||||
"sub_agents": list(rn.sub_agents) if getattr(rn, "sub_agents", None) else [],
|
||||
}
|
||||
fc_type = classify_flowchart_node(node, i, total, edges, terminal_ids)
|
||||
fc_meta = FLOWCHART_TYPES[fc_type]
|
||||
node["flowchart_type"] = fc_type
|
||||
node["flowchart_shape"] = fc_meta["shape"]
|
||||
node["flowchart_color"] = fc_meta["color"]
|
||||
nodes.append(node)
|
||||
|
||||
# Add visual edges from parent nodes to their sub_agents.
|
||||
# Sub-agents are connected via the sub_agents field, not via EdgeSpec,
|
||||
# so they'd appear as disconnected islands without this.
|
||||
# Two edges per sub-agent: delegate (parent→sub) and report (sub→parent).
|
||||
edge_counter = len(edges)
|
||||
for node in nodes:
|
||||
for sa_id in node.get("sub_agents") or []:
|
||||
if sa_id in node_ids:
|
||||
edges.append(
|
||||
{
|
||||
"id": f"edge-subagent-{edge_counter}",
|
||||
"source": node["id"],
|
||||
"target": sa_id,
|
||||
"condition": "always",
|
||||
"description": "sub-agent delegation",
|
||||
"label": "delegate",
|
||||
}
|
||||
)
|
||||
edge_counter += 1
|
||||
edges.append(
|
||||
{
|
||||
"id": f"edge-subagent-{edge_counter}",
|
||||
"source": sa_id,
|
||||
"target": node["id"],
|
||||
"condition": "always",
|
||||
"description": "sub-agent report back",
|
||||
"label": "report",
|
||||
}
|
||||
)
|
||||
edge_counter += 1
|
||||
|
||||
# Group sub-agent nodes under their parent in the flowchart map
|
||||
# (mirrors what _dissolve_planning_nodes does for planned drafts)
|
||||
sub_agent_ids_final: set[str] = set()
|
||||
for node in nodes:
|
||||
for sa_id in node.get("sub_agents") or []:
|
||||
if sa_id in node_ids:
|
||||
sub_agent_ids_final.add(sa_id)
|
||||
|
||||
fmap: dict[str, list[str]] = {}
|
||||
for node in nodes:
|
||||
nid = node["id"]
|
||||
if nid in sub_agent_ids_final:
|
||||
continue # skip — will be included via parent
|
||||
absorbed = [nid]
|
||||
for sa_id in node.get("sub_agents") or []:
|
||||
if sa_id in node_ids:
|
||||
absorbed.append(sa_id)
|
||||
fmap[nid] = absorbed
|
||||
|
||||
draft = {
|
||||
"agent_name": agent_name,
|
||||
"goal": goal_name,
|
||||
"description": "",
|
||||
"success_criteria": [],
|
||||
"constraints": [],
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
"entry_node": nodes[0]["id"] if nodes else "",
|
||||
"terminal_nodes": sorted(terminal_ids),
|
||||
"flowchart_legend": {
|
||||
fc_type: {"shape": meta["shape"], "color": meta["color"]}
|
||||
for fc_type, meta in FLOWCHART_TYPES.items()
|
||||
},
|
||||
}
|
||||
|
||||
return draft, fmap
|
||||
|
||||
|
||||
# ── Fallback generation entry point ──────────────────────────────────────────
|
||||
|
||||
|
||||
def generate_fallback_flowchart(
|
||||
graph: Any,
|
||||
goal: Any,
|
||||
agent_path: Path,
|
||||
) -> None:
|
||||
"""Generate flowchart.json from a runtime GraphSpec if none exists.
|
||||
|
||||
This is a no-op if flowchart.json already exists. On failure, logs a
|
||||
warning but never raises — agent loading must not be blocked by
|
||||
flowchart generation.
|
||||
"""
|
||||
try:
|
||||
existing_draft, _ = load_flowchart_file(agent_path)
|
||||
if existing_draft is not None:
|
||||
return # already have one
|
||||
|
||||
draft, fmap = synthesize_draft_from_runtime(
|
||||
runtime_nodes=list(graph.nodes),
|
||||
runtime_edges=list(graph.edges),
|
||||
agent_name=agent_path.name,
|
||||
goal_name=goal.name if goal else "",
|
||||
)
|
||||
|
||||
# Enrich with Goal metadata
|
||||
if goal:
|
||||
draft["goal"] = goal.description or goal.name or ""
|
||||
draft["success_criteria"] = [sc.description for sc in (goal.success_criteria or [])]
|
||||
draft["constraints"] = [c.description for c in (goal.constraints or [])]
|
||||
|
||||
# Use entry_node/terminal_nodes from GraphSpec if available
|
||||
if graph.entry_node:
|
||||
draft["entry_node"] = graph.entry_node
|
||||
if graph.terminal_nodes:
|
||||
draft["terminal_nodes"] = list(graph.terminal_nodes)
|
||||
|
||||
save_flowchart_file(agent_path, draft, fmap)
|
||||
logger.info("Generated fallback flowchart.json for %s", agent_path.name)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to generate fallback flowchart for %s",
|
||||
agent_path,
|
||||
exc_info=True,
|
||||
)
|
||||
@@ -36,6 +36,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
@@ -45,6 +46,13 @@ from framework.credentials.models import CredentialError
|
||||
from framework.runner.preload_validation import credential_errors_to_json, validate_credentials
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
from framework.server.app import validate_agent_path
|
||||
from framework.tools.flowchart_utils import (
|
||||
FLOWCHART_TYPES,
|
||||
classify_flowchart_node,
|
||||
load_flowchart_file,
|
||||
save_flowchart_file,
|
||||
synthesize_draft_from_runtime,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.runner.tool_registry import ToolRegistry
|
||||
@@ -108,6 +116,9 @@ class QueenPhaseState:
|
||||
prompt_staging: str = ""
|
||||
prompt_running: str = ""
|
||||
|
||||
# Default skill operational protocols — appended to every phase prompt
|
||||
protocols_prompt: str = ""
|
||||
|
||||
def get_current_tools(self) -> list:
|
||||
"""Return tools for the current phase."""
|
||||
if self.phase == "planning":
|
||||
@@ -132,7 +143,12 @@ class QueenPhaseState:
|
||||
from framework.agents.queen.queen_memory import format_for_injection
|
||||
|
||||
memory = format_for_injection()
|
||||
return base + ("\n\n" + memory if memory else "")
|
||||
parts = [base]
|
||||
if self.protocols_prompt:
|
||||
parts.append(self.protocols_prompt)
|
||||
if memory:
|
||||
parts.append(memory)
|
||||
return "\n\n".join(parts)
|
||||
|
||||
async def _emit_phase_event(self) -> None:
|
||||
"""Publish a QUEEN_PHASE_CHANGED event so the frontend updates the tag."""
|
||||
@@ -285,66 +301,7 @@ def build_worker_profile(runtime: AgentRuntime, agent_path: Path | str | None =
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
_FLOWCHART_TYPES = {
|
||||
# ── Core symbols (ISO 5807 §4) ──────────────────────────
|
||||
# Terminator — rounded rectangle (stadium shape)
|
||||
"start": {"shape": "stadium", "color": "#4CAF50"}, # green
|
||||
"terminal": {"shape": "stadium", "color": "#F44336"}, # red
|
||||
# Process — rectangle
|
||||
"process": {"shape": "rectangle", "color": "#2196F3"}, # blue
|
||||
# Decision — diamond
|
||||
"decision": {"shape": "diamond", "color": "#FF9800"}, # amber
|
||||
# Data (Input/Output) — parallelogram
|
||||
"io": {"shape": "parallelogram", "color": "#9C27B0"}, # purple
|
||||
# Document — rectangle with wavy bottom
|
||||
"document": {"shape": "document", "color": "#607D8B"}, # blue-grey
|
||||
# Multi-document — stacked documents
|
||||
"multi_document": {"shape": "multi_document", "color": "#78909C"}, # blue-grey light
|
||||
# Predefined process / subroutine — rectangle with double vertical bars
|
||||
"subprocess": {"shape": "subroutine", "color": "#009688"}, # teal
|
||||
# Preparation — hexagon
|
||||
"preparation": {"shape": "hexagon", "color": "#795548"}, # brown
|
||||
# Manual input — trapezoid with slanted top
|
||||
"manual_input": {"shape": "manual_input", "color": "#E91E63"}, # pink
|
||||
# Manual operation — inverted trapezoid
|
||||
"manual_operation": {"shape": "trapezoid", "color": "#AD1457"}, # dark pink
|
||||
# Delay — half-rounded rectangle (D-shape)
|
||||
"delay": {"shape": "delay", "color": "#FF5722"}, # deep orange
|
||||
# Display — rounded rectangle with pointed left
|
||||
"display": {"shape": "display", "color": "#00BCD4"}, # cyan
|
||||
# ── Data storage symbols ────────────────────────────────
|
||||
# Database / direct access storage — cylinder
|
||||
"database": {"shape": "cylinder", "color": "#8BC34A"}, # light green
|
||||
# Stored data — generic data store
|
||||
"stored_data": {"shape": "stored_data", "color": "#CDDC39"}, # lime
|
||||
# Internal storage — rectangle with cross-hatch
|
||||
"internal_storage": {"shape": "internal_storage", "color": "#FFC107"}, # amber light
|
||||
# ── Connectors ──────────────────────────────────────────
|
||||
# On-page connector — small circle
|
||||
"connector": {"shape": "circle", "color": "#9E9E9E"}, # grey
|
||||
# Off-page connector — pentagon / home-plate
|
||||
"offpage_connector": {"shape": "pentagon", "color": "#757575"}, # dark grey
|
||||
# ── Flow operations ─────────────────────────────────────
|
||||
# Merge — inverted triangle
|
||||
"merge": {"shape": "triangle_inv", "color": "#3F51B5"}, # indigo
|
||||
# Extract — upward triangle
|
||||
"extract": {"shape": "triangle", "color": "#5C6BC0"}, # indigo light
|
||||
# Sort — hourglass / double triangle
|
||||
"sort": {"shape": "hourglass", "color": "#7986CB"}, # indigo lighter
|
||||
# Collate — merged hourglass
|
||||
"collate": {"shape": "hourglass_inv", "color": "#9FA8DA"}, # indigo lightest
|
||||
# Summing junction — circle with cross
|
||||
"summing_junction": {"shape": "circle_cross", "color": "#F06292"}, # pink light
|
||||
# Or — circle with horizontal bar
|
||||
"or": {"shape": "circle_bar", "color": "#CE93D8"}, # purple light
|
||||
# ── Domain-specific (Hive agent context) ────────────────
|
||||
# Browser automation (GCU) — mapped to preparation/hexagon
|
||||
"browser": {"shape": "hexagon", "color": "#1A237E"}, # dark indigo
|
||||
# Comment / annotation — flag shape
|
||||
"comment": {"shape": "flag", "color": "#BDBDBD"}, # light grey
|
||||
# Alternate process — rounded rectangle
|
||||
"alternate_process": {"shape": "rounded_rect", "color": "#42A5F5"}, # light blue
|
||||
}
|
||||
# FLOWCHART_TYPES is imported from framework.tools.flowchart_utils
|
||||
|
||||
|
||||
def _read_agent_triggers_json(agent_path: Path) -> list[dict]:
|
||||
@@ -451,10 +408,11 @@ async def _start_trigger_timer(session: Any, trigger_id: str, tdef: Any) -> None
|
||||
else:
|
||||
await asyncio.sleep(float(interval_minutes) * 60)
|
||||
|
||||
# Record next fire time for introspection
|
||||
# Record next fire time for introspection (monotonic, matches routes)
|
||||
fire_times = getattr(session, "trigger_next_fire", None)
|
||||
if fire_times is not None:
|
||||
fire_times[trigger_id] = datetime.now(tz=UTC).isoformat()
|
||||
_next_delay = float(interval_minutes) * 60 if interval_minutes else 60
|
||||
fire_times[trigger_id] = time.monotonic() + _next_delay
|
||||
|
||||
# Gate on worker being loaded
|
||||
if getattr(session, "worker_runtime", None) is None:
|
||||
@@ -635,7 +593,7 @@ def _dissolve_planning_nodes(
|
||||
if not predecessors:
|
||||
# Decision at start: convert to regular process node
|
||||
d_node["flowchart_type"] = "process"
|
||||
fc_meta = _FLOWCHART_TYPES["process"]
|
||||
fc_meta = FLOWCHART_TYPES["process"]
|
||||
d_node["flowchart_shape"] = fc_meta["shape"]
|
||||
d_node["flowchart_color"] = fc_meta["color"]
|
||||
if not d_node.get("success_criteria"):
|
||||
@@ -769,6 +727,25 @@ def _dissolve_planning_nodes(
|
||||
return converted, flowchart_map
|
||||
|
||||
|
||||
def _update_meta_json(session_manager, manager_session_id, updates: dict) -> None:
|
||||
"""Merge updates into the queen session's meta.json."""
|
||||
if session_manager is None or not manager_session_id:
|
||||
return
|
||||
srv_session = session_manager.get_session(manager_session_id)
|
||||
if not srv_session:
|
||||
return
|
||||
storage_sid = getattr(srv_session, "queen_resume_from", None) or srv_session.id
|
||||
meta_path = Path.home() / ".hive" / "queen" / "session" / storage_sid / "meta.json"
|
||||
try:
|
||||
existing = {}
|
||||
if meta_path.exists():
|
||||
existing = json.loads(meta_path.read_text(encoding="utf-8"))
|
||||
existing.update(updates)
|
||||
meta_path.write_text(json.dumps(existing), encoding="utf-8")
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def register_queen_lifecycle_tools(
|
||||
registry: ToolRegistry,
|
||||
session: Any = None,
|
||||
@@ -1017,6 +994,7 @@ def register_queen_lifecycle_tools(
|
||||
# Switch to building phase
|
||||
if phase_state is not None:
|
||||
await phase_state.switch_to_building()
|
||||
_update_meta_json(session_manager, manager_session_id, {"phase": "building"})
|
||||
|
||||
result = json.loads(stop_result)
|
||||
result["phase"] = "building"
|
||||
@@ -1147,309 +1125,20 @@ def register_queen_lifecycle_tools(
|
||||
registry.register("replan_agent", _replan_tool, lambda inputs: replan_agent())
|
||||
tools_registered += 1
|
||||
|
||||
# --- Flowchart file persistence -------------------------------------------
|
||||
# The flowchart is saved as flowchart.json in the agent's folder so it
|
||||
# survives restarts and is available when loading any agent.
|
||||
|
||||
FLOWCHART_FILENAME = "flowchart.json"
|
||||
|
||||
def _save_flowchart_file(
|
||||
agent_path: Path | str | None,
|
||||
original_draft: dict,
|
||||
flowchart_map: dict[str, list[str]] | None,
|
||||
) -> None:
|
||||
"""Persist the flowchart to the agent's folder."""
|
||||
if agent_path is None:
|
||||
return
|
||||
p = Path(agent_path)
|
||||
if not p.is_dir():
|
||||
return
|
||||
try:
|
||||
target = p / FLOWCHART_FILENAME
|
||||
target.write_text(
|
||||
json.dumps(
|
||||
{"original_draft": original_draft, "flowchart_map": flowchart_map},
|
||||
indent=2,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
logger.debug("Flowchart saved to %s", target)
|
||||
except Exception:
|
||||
logger.warning("Failed to save flowchart to %s", p, exc_info=True)
|
||||
|
||||
def _load_flowchart_file(
|
||||
agent_path: Path | str | None,
|
||||
) -> tuple[dict | None, dict[str, list[str]] | None]:
|
||||
"""Load flowchart from the agent's folder. Returns (original_draft, flowchart_map)."""
|
||||
if agent_path is None:
|
||||
return None, None
|
||||
target = Path(agent_path) / FLOWCHART_FILENAME
|
||||
if not target.is_file():
|
||||
return None, None
|
||||
try:
|
||||
data = json.loads(target.read_text(encoding="utf-8"))
|
||||
return data.get("original_draft"), data.get("flowchart_map")
|
||||
except Exception:
|
||||
logger.warning("Failed to load flowchart from %s", target, exc_info=True)
|
||||
return None, None
|
||||
|
||||
def _synthesize_draft_from_runtime(
|
||||
runtime_nodes: list,
|
||||
runtime_edges: list,
|
||||
agent_name: str = "",
|
||||
goal_name: str = "",
|
||||
) -> tuple[dict, dict[str, list[str]]]:
|
||||
"""Generate a flowchart draft from a loaded runtime graph.
|
||||
|
||||
Used for agents that were never planned through the draft workflow
|
||||
(e.g., hand-coded or loaded from "my agents"). Produces a valid
|
||||
DraftGraph structure with auto-classified flowchart types.
|
||||
"""
|
||||
nodes: list[dict] = []
|
||||
edges: list[dict] = []
|
||||
node_ids = {n.id for n in runtime_nodes}
|
||||
|
||||
# Build edge dicts first (needed for classification)
|
||||
for i, re in enumerate(runtime_edges):
|
||||
edges.append(
|
||||
{
|
||||
"id": f"edge-{i}",
|
||||
"source": re.source,
|
||||
"target": re.target,
|
||||
"condition": str(re.condition.value)
|
||||
if hasattr(re.condition, "value")
|
||||
else str(re.condition),
|
||||
"description": getattr(re, "description", "") or "",
|
||||
"label": "",
|
||||
}
|
||||
)
|
||||
|
||||
# Terminal detection — exclude sub-agent nodes (they are leaf helpers, not endpoints)
|
||||
sub_agent_ids: set[str] = set()
|
||||
for rn in runtime_nodes:
|
||||
for sa_id in getattr(rn, "sub_agents", None) or []:
|
||||
sub_agent_ids.add(sa_id)
|
||||
sources = {e["source"] for e in edges}
|
||||
terminal_ids = node_ids - sources - sub_agent_ids
|
||||
if not terminal_ids and runtime_nodes:
|
||||
terminal_ids = {runtime_nodes[-1].id}
|
||||
|
||||
# Build node dicts with classification
|
||||
total = len(runtime_nodes)
|
||||
for i, rn in enumerate(runtime_nodes):
|
||||
node: dict = {
|
||||
"id": rn.id,
|
||||
"name": rn.name,
|
||||
"description": rn.description or "",
|
||||
"node_type": getattr(rn, "node_type", "event_loop") or "event_loop",
|
||||
"tools": list(rn.tools) if rn.tools else [],
|
||||
"input_keys": list(rn.input_keys) if rn.input_keys else [],
|
||||
"output_keys": list(rn.output_keys) if rn.output_keys else [],
|
||||
"success_criteria": getattr(rn, "success_criteria", "") or "",
|
||||
"sub_agents": list(rn.sub_agents) if getattr(rn, "sub_agents", None) else [],
|
||||
}
|
||||
fc_type = _classify_flowchart_node(node, i, total, edges, terminal_ids)
|
||||
fc_meta = _FLOWCHART_TYPES[fc_type]
|
||||
node["flowchart_type"] = fc_type
|
||||
node["flowchart_shape"] = fc_meta["shape"]
|
||||
node["flowchart_color"] = fc_meta["color"]
|
||||
nodes.append(node)
|
||||
|
||||
# Add visual edges from parent nodes to their sub_agents.
|
||||
# Sub-agents are connected via the sub_agents field, not via EdgeSpec,
|
||||
# so they'd appear as disconnected islands without this.
|
||||
# Two edges per sub-agent: delegate (parent→sub) and report (sub→parent).
|
||||
edge_counter = len(edges)
|
||||
for node in nodes:
|
||||
for sa_id in node.get("sub_agents") or []:
|
||||
if sa_id in node_ids:
|
||||
edges.append(
|
||||
{
|
||||
"id": f"edge-subagent-{edge_counter}",
|
||||
"source": node["id"],
|
||||
"target": sa_id,
|
||||
"condition": "always",
|
||||
"description": "sub-agent delegation",
|
||||
"label": "delegate",
|
||||
}
|
||||
)
|
||||
edge_counter += 1
|
||||
edges.append(
|
||||
{
|
||||
"id": f"edge-subagent-{edge_counter}",
|
||||
"source": sa_id,
|
||||
"target": node["id"],
|
||||
"condition": "always",
|
||||
"description": "sub-agent report back",
|
||||
"label": "report",
|
||||
}
|
||||
)
|
||||
edge_counter += 1
|
||||
|
||||
# Group sub-agent nodes under their parent in the flowchart map
|
||||
# (mirrors what _dissolve_planning_nodes does for planned drafts)
|
||||
sub_agent_ids: set[str] = set()
|
||||
for node in nodes:
|
||||
for sa_id in node.get("sub_agents") or []:
|
||||
if sa_id in node_ids:
|
||||
sub_agent_ids.add(sa_id)
|
||||
|
||||
fmap: dict[str, list[str]] = {}
|
||||
for node in nodes:
|
||||
nid = node["id"]
|
||||
if nid in sub_agent_ids:
|
||||
continue # skip — will be included via parent
|
||||
absorbed = [nid]
|
||||
for sa_id in node.get("sub_agents") or []:
|
||||
if sa_id in node_ids:
|
||||
absorbed.append(sa_id)
|
||||
fmap[nid] = absorbed
|
||||
|
||||
draft = {
|
||||
"agent_name": agent_name,
|
||||
"goal": goal_name,
|
||||
"description": "",
|
||||
"success_criteria": [],
|
||||
"constraints": [],
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
"entry_node": nodes[0]["id"] if nodes else "",
|
||||
"terminal_nodes": sorted(terminal_ids),
|
||||
"flowchart_legend": {
|
||||
fc_type: {"shape": meta["shape"], "color": meta["color"]}
|
||||
for fc_type, meta in _FLOWCHART_TYPES.items()
|
||||
},
|
||||
}
|
||||
|
||||
return draft, fmap
|
||||
# --- Flowchart utilities ---------------------------------------------------
|
||||
# Flowchart persistence, classification, and synthesis functions are now in
|
||||
# framework.tools.flowchart_utils. Local aliases for backward compatibility
|
||||
# within this closure:
|
||||
_save_flowchart_file = save_flowchart_file
|
||||
_load_flowchart_file = load_flowchart_file
|
||||
_synthesize_draft_from_runtime = synthesize_draft_from_runtime
|
||||
_classify_flowchart_node = classify_flowchart_node
|
||||
|
||||
# --- save_agent_draft (Planning phase — declarative graph preview) ---------
|
||||
# Creates a lightweight draft graph with nodes, edges, and business metadata.
|
||||
# Loose validation: only requires names and descriptions. Emits an event
|
||||
# so the frontend can render the graph during planning (before any code).
|
||||
|
||||
def _classify_flowchart_node(
|
||||
node: dict,
|
||||
index: int,
|
||||
total: int,
|
||||
edges: list[dict],
|
||||
terminal_ids: set[str],
|
||||
) -> str:
|
||||
"""Auto-detect the ISO 5807 flowchart type for a draft node.
|
||||
|
||||
Priority: explicit override > structural detection > heuristic > default.
|
||||
"""
|
||||
# Explicit override from the queen
|
||||
explicit = node.get("flowchart_type", "").strip()
|
||||
if explicit and explicit in _FLOWCHART_TYPES:
|
||||
return explicit
|
||||
|
||||
node_id = node["id"]
|
||||
node_type = node.get("node_type", "event_loop")
|
||||
node_tools = set(node.get("tools") or [])
|
||||
desc = (node.get("description") or "").lower()
|
||||
name = (node.get("name") or "").lower()
|
||||
|
||||
# GCU / browser automation nodes → hexagon
|
||||
if node_type == "gcu":
|
||||
return "browser"
|
||||
|
||||
# Entry node (first node or no incoming edges) → start terminator
|
||||
incoming = {e["target"] for e in edges}
|
||||
if index == 0 or (node_id not in incoming and index == 0):
|
||||
return "start"
|
||||
|
||||
# Terminal node → end terminator
|
||||
if node_id in terminal_ids:
|
||||
return "terminal"
|
||||
|
||||
# Decision node: has outgoing edges with branching conditions → diamond
|
||||
outgoing = [e for e in edges if e["source"] == node_id]
|
||||
if len(outgoing) >= 2:
|
||||
conditions = {e.get("condition", "on_success") for e in outgoing}
|
||||
if len(conditions) > 1 or conditions - {"on_success"}:
|
||||
return "decision"
|
||||
|
||||
# Sub-agent / subprocess nodes → subroutine (double-bordered rect)
|
||||
if node.get("sub_agents"):
|
||||
return "subprocess"
|
||||
|
||||
# Database / data store nodes → cylinder
|
||||
db_tool_hints = {
|
||||
"query_database",
|
||||
"sql_query",
|
||||
"read_table",
|
||||
"write_table",
|
||||
"save_data",
|
||||
"load_data",
|
||||
}
|
||||
db_desc_hints = {"database", "data store", "storage", "persist", "cache"}
|
||||
if node_tools & db_tool_hints or any(h in desc for h in db_desc_hints):
|
||||
return "database"
|
||||
|
||||
# Document generation nodes → document shape
|
||||
doc_tool_hints = {
|
||||
"generate_report",
|
||||
"create_document",
|
||||
"write_report",
|
||||
"render_template",
|
||||
"export_pdf",
|
||||
}
|
||||
doc_desc_hints = {"report", "document", "summary", "write up", "writeup"}
|
||||
if node_tools & doc_tool_hints or any(h in desc for h in doc_desc_hints):
|
||||
return "document"
|
||||
|
||||
# I/O nodes: external data ingestion or delivery → parallelogram
|
||||
io_tool_hints = {
|
||||
"serve_file_to_user",
|
||||
"send_email",
|
||||
"post_message",
|
||||
"upload_file",
|
||||
"download_file",
|
||||
"fetch_url",
|
||||
"post_to_slack",
|
||||
"send_notification",
|
||||
}
|
||||
io_desc_hints = {"deliver", "send", "output", "notify", "publish"}
|
||||
if node_tools & io_tool_hints or any(h in desc for h in io_desc_hints):
|
||||
return "io"
|
||||
|
||||
# Manual / human-in-the-loop nodes → trapezoid
|
||||
manual_desc_hints = {
|
||||
"human review",
|
||||
"manual",
|
||||
"approval",
|
||||
"human-in-the-loop",
|
||||
"user review",
|
||||
"manual check",
|
||||
}
|
||||
if any(h in desc for h in manual_desc_hints) or any(h in name for h in manual_desc_hints):
|
||||
return "manual_operation"
|
||||
|
||||
# Preparation / setup nodes → hexagon
|
||||
prep_desc_hints = {"setup", "initialize", "prepare", "configure", "provision"}
|
||||
if any(h in desc for h in prep_desc_hints) or any(h in name for h in prep_desc_hints):
|
||||
return "preparation"
|
||||
|
||||
# Delay / wait nodes → D-shape
|
||||
delay_desc_hints = {"wait", "delay", "pause", "cooldown", "throttle", "sleep"}
|
||||
if any(h in desc for h in delay_desc_hints):
|
||||
return "delay"
|
||||
|
||||
# Merge nodes → inverted triangle
|
||||
merge_desc_hints = {"merge", "combine", "aggregate", "consolidate"}
|
||||
if any(h in desc for h in merge_desc_hints) or any(h in name for h in merge_desc_hints):
|
||||
return "merge"
|
||||
|
||||
# Display nodes → display shape
|
||||
display_desc_hints = {"display", "show", "present", "render", "visualize"}
|
||||
display_tool_hints = {"serve_file_to_user", "display_results"}
|
||||
if node_tools & display_tool_hints or any(h in name for h in display_desc_hints):
|
||||
return "display"
|
||||
|
||||
# Default: process (rectangle)
|
||||
return "process"
|
||||
|
||||
def _dissolve_planning_nodes(
|
||||
draft: dict,
|
||||
) -> tuple[dict, dict[str, list[str]]]:
|
||||
@@ -1535,7 +1224,7 @@ def register_queen_lifecycle_tools(
|
||||
if not predecessors:
|
||||
# Decision at start: convert to regular process node
|
||||
d_node["flowchart_type"] = "process"
|
||||
fc_meta = _FLOWCHART_TYPES["process"]
|
||||
fc_meta = FLOWCHART_TYPES["process"]
|
||||
d_node["flowchart_shape"] = fc_meta["shape"]
|
||||
d_node["flowchart_color"] = fc_meta["color"]
|
||||
if not d_node.get("success_criteria"):
|
||||
@@ -1890,12 +1579,22 @@ def register_queen_lifecycle_tools(
|
||||
# Find edges where this leaf node is the source
|
||||
out_edges = [e for e in validated_edges if e["source"] == leaf_id]
|
||||
in_edges = [e for e in validated_edges if e["target"] == leaf_id]
|
||||
if not out_edges:
|
||||
continue # already a proper leaf
|
||||
|
||||
# Identify the parent (predecessor that connects IN)
|
||||
parent_ids = [e["source"] for e in in_edges]
|
||||
|
||||
if not out_edges:
|
||||
# Already a proper leaf — still ensure sub_agents is set
|
||||
for pid in parent_ids:
|
||||
parent = node_by_id_v.get(pid)
|
||||
if parent is None:
|
||||
continue
|
||||
existing = parent.get("sub_agents") or []
|
||||
if leaf_id not in existing:
|
||||
existing.append(leaf_id)
|
||||
parent["sub_agents"] = existing
|
||||
continue
|
||||
|
||||
# Strip all outgoing edges from the leaf node that
|
||||
# don't go back to a parent (report edges are OK)
|
||||
illegal_targets: list[str] = []
|
||||
@@ -2087,7 +1786,7 @@ def register_queen_lifecycle_tools(
|
||||
validated_edges,
|
||||
terminal_ids,
|
||||
)
|
||||
fc_meta = _FLOWCHART_TYPES[fc_type]
|
||||
fc_meta = FLOWCHART_TYPES[fc_type]
|
||||
node["flowchart_type"] = fc_type
|
||||
node["flowchart_shape"] = fc_meta["shape"]
|
||||
node["flowchart_color"] = fc_meta["color"]
|
||||
@@ -2105,7 +1804,7 @@ def register_queen_lifecycle_tools(
|
||||
# Color legend for the frontend
|
||||
"flowchart_legend": {
|
||||
fc_type: {"shape": meta["shape"], "color": meta["color"]}
|
||||
for fc_type, meta in _FLOWCHART_TYPES.items()
|
||||
for fc_type, meta in FLOWCHART_TYPES.items()
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2276,39 +1975,18 @@ def register_queen_lifecycle_tools(
|
||||
"decision",
|
||||
"io",
|
||||
"document",
|
||||
"multi_document",
|
||||
"subprocess",
|
||||
"preparation",
|
||||
"manual_input",
|
||||
"manual_operation",
|
||||
"delay",
|
||||
"display",
|
||||
"database",
|
||||
"stored_data",
|
||||
"internal_storage",
|
||||
"connector",
|
||||
"offpage_connector",
|
||||
"merge",
|
||||
"extract",
|
||||
"sort",
|
||||
"collate",
|
||||
"summing_junction",
|
||||
"or",
|
||||
"subprocess",
|
||||
"browser",
|
||||
"comment",
|
||||
"alternate_process",
|
||||
],
|
||||
"description": (
|
||||
"ISO 5807 flowchart symbol type. Auto-detected if omitted. "
|
||||
"Core: start (green stadium), terminal (red stadium), "
|
||||
"process (blue rect), decision (amber diamond), "
|
||||
"io (purple parallelogram), document (grey wavy rect), "
|
||||
"subprocess (teal subroutine), preparation (brown hexagon), "
|
||||
"manual_operation (pink trapezoid), delay (orange D-shape), "
|
||||
"display (cyan), database (green cylinder), "
|
||||
"merge (indigo triangle), browser (dark indigo hexagon — "
|
||||
"for GCU/browser sub-agents; must be a leaf node connected "
|
||||
"only to its managing parent)"
|
||||
"Flowchart symbol type. Auto-detected if omitted. "
|
||||
"start (sage green stadium), terminal (dusty red stadium), "
|
||||
"process (blue-gray rect), decision (amber diamond), "
|
||||
"io (purple parallelogram), document (steel blue wavy rect), "
|
||||
"database (teal cylinder), subprocess (cyan subroutine), "
|
||||
"browser (deep blue hexagon — for GCU/browser "
|
||||
"sub-agents; must be a leaf node)"
|
||||
),
|
||||
},
|
||||
"tools": {
|
||||
@@ -2330,6 +2008,17 @@ def register_queen_lifecycle_tools(
|
||||
"type": "string",
|
||||
"description": "What success looks like for this node",
|
||||
},
|
||||
"sub_agents": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"IDs of GCU/browser sub-agent nodes managed by this node. "
|
||||
"At build time, sub-agent nodes are dissolved into this list. "
|
||||
"Set this on the PARENT node — e.g. the orchestrator that "
|
||||
"delegates to GCU leaves. Visual delegation edges are "
|
||||
"synthesized automatically."
|
||||
),
|
||||
},
|
||||
"decision_clause": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
@@ -2447,8 +2136,22 @@ def register_queen_lifecycle_tools(
|
||||
phase_state.draft_graph = converted
|
||||
phase_state.flowchart_map = fmap
|
||||
|
||||
# Note: flowchart file is persisted later, in initialize_and_build_agent
|
||||
# (after the agent folder is scaffolded) or in load_built_agent.
|
||||
# Create agent folder early so flowchart and agent_path are available
|
||||
# throughout the entire BUILDING phase.
|
||||
_agent_name = phase_state.draft_graph.get("agent_name", "").strip()
|
||||
if _agent_name:
|
||||
_agent_folder = Path("exports") / _agent_name
|
||||
_agent_folder.mkdir(parents=True, exist_ok=True)
|
||||
_save_flowchart_file(_agent_folder, original_copy, fmap)
|
||||
phase_state.agent_path = str(_agent_folder)
|
||||
_update_meta_json(
|
||||
session_manager,
|
||||
manager_session_id,
|
||||
{
|
||||
"agent_path": str(_agent_folder),
|
||||
"agent_name": _agent_name.replace("_", " ").title(),
|
||||
},
|
||||
)
|
||||
|
||||
dissolved_count = len(original_nodes) - len(converted.get("nodes", []))
|
||||
decision_count = sum(1 for n in original_nodes if n.get("flowchart_type") == "decision")
|
||||
@@ -2580,6 +2283,7 @@ def register_queen_lifecycle_tools(
|
||||
if fallback_path:
|
||||
phase_state.agent_path = str(fallback_path)
|
||||
await phase_state.switch_to_building(source="tool")
|
||||
_update_meta_json(session_manager, manager_session_id, {"phase": "building"})
|
||||
if phase_state.inject_notification:
|
||||
await phase_state.inject_notification(
|
||||
"[PHASE CHANGE] Switched to BUILDING phase. "
|
||||
@@ -2622,8 +2326,13 @@ def register_queen_lifecycle_tools(
|
||||
if parsed.get("success", True):
|
||||
if phase_state is not None:
|
||||
# Set agent_path so the frontend can query credentials
|
||||
phase_state.agent_path = str(Path("exports") / agent_name)
|
||||
phase_state.agent_path = phase_state.agent_path or str(
|
||||
Path("exports") / agent_name
|
||||
)
|
||||
await phase_state.switch_to_building(source="tool")
|
||||
_update_meta_json(
|
||||
session_manager, manager_session_id, {"phase": "building"}
|
||||
)
|
||||
# Reset draft state after successful scaffolding
|
||||
phase_state.build_confirmed = False
|
||||
# Persist flowchart now that the agent folder exists
|
||||
@@ -2671,6 +2380,7 @@ def register_queen_lifecycle_tools(
|
||||
# Switch to staging phase
|
||||
if phase_state is not None:
|
||||
await phase_state.switch_to_staging()
|
||||
_update_meta_json(session_manager, manager_session_id, {"phase": "staging"})
|
||||
|
||||
result = json.loads(stop_result)
|
||||
result["phase"] = "staging"
|
||||
@@ -2699,6 +2409,30 @@ def register_queen_lifecycle_tools(
|
||||
"""Get the session's event bus for querying history."""
|
||||
return getattr(session, "event_bus", None)
|
||||
|
||||
def _get_worker_name() -> str | None:
|
||||
"""Return the worker agent directory name, used for diary lookups."""
|
||||
p = getattr(session, "worker_path", None)
|
||||
return p.name if p else None
|
||||
|
||||
def _format_diary(max_runs: int) -> str:
|
||||
"""Read recent run digests from disk — no EventBus required."""
|
||||
agent_name = _get_worker_name()
|
||||
if not agent_name:
|
||||
return "No worker loaded — diary unavailable."
|
||||
from framework.agents.worker_memory import read_recent_digests
|
||||
|
||||
entries = read_recent_digests(agent_name, max_runs)
|
||||
if not entries:
|
||||
return (
|
||||
f"No run digests for '{agent_name}' yet. "
|
||||
"Digests are written at the end of each completed run."
|
||||
)
|
||||
lines = [f"Worker '{agent_name}' — {len(entries)} recent run digest(s):", ""]
|
||||
for _run_id, content in entries:
|
||||
lines.append(content)
|
||||
lines.append("")
|
||||
return "\n".join(lines).rstrip()
|
||||
|
||||
# Tiered cooldowns: summary is free, detail has short cooldown, full keeps 30s
|
||||
_COOLDOWN_FULL = 30.0
|
||||
_COOLDOWN_DETAIL = 10.0
|
||||
@@ -3301,16 +3035,17 @@ def register_queen_lifecycle_tools(
|
||||
import time as _time
|
||||
|
||||
# --- Tiered cooldown ---
|
||||
# diary is free (file reads only), summary is free, detail has 10s, full has 30s
|
||||
now = _time.monotonic()
|
||||
if focus == "full":
|
||||
cooldown = _COOLDOWN_FULL
|
||||
tier = "full"
|
||||
elif focus is not None:
|
||||
elif focus == "diary" or focus is None:
|
||||
cooldown = 0.0
|
||||
tier = focus or "summary"
|
||||
else:
|
||||
cooldown = _COOLDOWN_DETAIL
|
||||
tier = "detail"
|
||||
else:
|
||||
cooldown = 0.0
|
||||
tier = "summary"
|
||||
|
||||
elapsed_since = now - _status_last_called.get(tier, 0.0)
|
||||
if elapsed_since < cooldown:
|
||||
@@ -3326,6 +3061,10 @@ def register_queen_lifecycle_tools(
|
||||
)
|
||||
_status_last_called[tier] = now
|
||||
|
||||
# --- Diary: pure file reads, no runtime required ---
|
||||
if focus == "diary":
|
||||
return _format_diary(last_n)
|
||||
|
||||
# --- Runtime check ---
|
||||
runtime = _get_runtime()
|
||||
if runtime is None:
|
||||
@@ -3375,7 +3114,7 @@ def register_queen_lifecycle_tools(
|
||||
else:
|
||||
return (
|
||||
f"Unknown focus '{focus}'. "
|
||||
"Valid options: activity, memory, tools, issues, progress, full."
|
||||
"Valid options: diary, activity, memory, tools, issues, progress, full."
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("get_worker_status error")
|
||||
@@ -3386,6 +3125,8 @@ def register_queen_lifecycle_tools(
|
||||
description=(
|
||||
"Check on the worker. Returns a brief prose summary by default. "
|
||||
"Use 'focus' to drill into specifics:\n"
|
||||
"- diary: persistent run digests from past executions — read this first "
|
||||
"before digging into live runtime logs\n"
|
||||
"- activity: current node, transitions, latest LLM output\n"
|
||||
"- memory: worker's accumulated knowledge and state\n"
|
||||
"- tools: running and recent tool calls\n"
|
||||
@@ -3398,8 +3139,11 @@ def register_queen_lifecycle_tools(
|
||||
"properties": {
|
||||
"focus": {
|
||||
"type": "string",
|
||||
"enum": ["activity", "memory", "tools", "issues", "progress", "full"],
|
||||
"description": ("Aspect to inspect. Omit for a brief summary."),
|
||||
"enum": ["diary", "activity", "memory", "tools", "issues", "progress", "full"],
|
||||
"description": (
|
||||
"Aspect to inspect. Omit for a brief summary. "
|
||||
"Use 'diary' to read persistent run history before checking live logs."
|
||||
),
|
||||
},
|
||||
"last_n": {
|
||||
"type": "integer",
|
||||
@@ -3798,6 +3542,7 @@ def register_queen_lifecycle_tools(
|
||||
if phase_state is not None:
|
||||
phase_state.agent_path = str(resolved_path)
|
||||
await phase_state.switch_to_staging()
|
||||
_update_meta_json(session_manager, manager_session_id, {"phase": "staging"})
|
||||
|
||||
worker_name = info.name if info else updated_session.worker_id
|
||||
return json.dumps(
|
||||
@@ -3917,6 +3662,7 @@ def register_queen_lifecycle_tools(
|
||||
# Switch to running phase
|
||||
if phase_state is not None:
|
||||
await phase_state.switch_to_running()
|
||||
_update_meta_json(session_manager, manager_session_id, {"phase": "running"})
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
@@ -4054,6 +3800,8 @@ def register_queen_lifecycle_tools(
|
||||
_save_trigger_to_agent(session, trigger_id, tdef)
|
||||
bus = getattr(session, "event_bus", None)
|
||||
if bus:
|
||||
_runner = getattr(session, "runner", None)
|
||||
_graph_entry = _runner.graph.entry_node if _runner else None
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.TRIGGER_ACTIVATED,
|
||||
@@ -4062,6 +3810,8 @@ def register_queen_lifecycle_tools(
|
||||
"trigger_id": trigger_id,
|
||||
"trigger_type": t_type,
|
||||
"trigger_config": t_config,
|
||||
"name": tdef.description or trigger_id,
|
||||
**({"entry_node": _graph_entry} if _graph_entry else {}),
|
||||
},
|
||||
)
|
||||
)
|
||||
@@ -4114,6 +3864,8 @@ def register_queen_lifecycle_tools(
|
||||
# Emit event
|
||||
bus = getattr(session, "event_bus", None)
|
||||
if bus:
|
||||
_runner = getattr(session, "runner", None)
|
||||
_graph_entry = _runner.graph.entry_node if _runner else None
|
||||
await bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.TRIGGER_ACTIVATED,
|
||||
@@ -4122,6 +3874,8 @@ def register_queen_lifecycle_tools(
|
||||
"trigger_id": trigger_id,
|
||||
"trigger_type": t_type,
|
||||
"trigger_config": t_config,
|
||||
"name": tdef.description or trigger_id,
|
||||
**({"entry_node": _graph_entry} if _graph_entry else {}),
|
||||
},
|
||||
)
|
||||
)
|
||||
@@ -4220,7 +3974,10 @@ def register_queen_lifecycle_tools(
|
||||
AgentEvent(
|
||||
type=EventType.TRIGGER_DEACTIVATED,
|
||||
stream_id="queen",
|
||||
data={"trigger_id": trigger_id},
|
||||
data={
|
||||
"trigger_id": trigger_id,
|
||||
"name": tdef.description or trigger_id if tdef else trigger_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Tool for the queen to write to her episodic memory.
|
||||
"""Tools for the queen to read and write episodic memory.
|
||||
|
||||
The queen can consciously record significant moments during a session — like
|
||||
writing in a diary. Semantic memory (MEMORY.md) is updated automatically at
|
||||
session end and is never written by the queen directly.
|
||||
writing in a diary — and recall past diary entries when needed. Semantic
|
||||
memory (MEMORY.md) is updated automatically at session end and is never
|
||||
written by the queen directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -33,6 +34,67 @@ def write_to_diary(entry: str) -> str:
|
||||
return "Diary entry recorded."
|
||||
|
||||
|
||||
def recall_diary(query: str = "", days_back: int = 7) -> str:
|
||||
"""Search recent diary entries (episodic memory).
|
||||
|
||||
Use this when the user asks about what happened in the past — "what did we
|
||||
do yesterday?", "what happened last week?", "remind me about the pipeline
|
||||
issue", etc. Also use it proactively when you need context from recent
|
||||
sessions to answer a question or make a decision.
|
||||
|
||||
Args:
|
||||
query: Optional keyword or phrase to filter entries. If empty, all
|
||||
recent entries are returned.
|
||||
days_back: How many days to look back (1–30). Defaults to 7.
|
||||
"""
|
||||
from datetime import date, timedelta
|
||||
|
||||
from framework.agents.queen.queen_memory import read_episodic_memory
|
||||
|
||||
days_back = max(1, min(days_back, 30))
|
||||
today = date.today()
|
||||
results: list[str] = []
|
||||
total_chars = 0
|
||||
char_budget = 12_000
|
||||
|
||||
for offset in range(days_back):
|
||||
d = today - timedelta(days=offset)
|
||||
content = read_episodic_memory(d)
|
||||
if not content:
|
||||
continue
|
||||
# If a query is given, only include entries that mention it
|
||||
if query:
|
||||
# Check each section (split by ###) for relevance
|
||||
sections = content.split("### ")
|
||||
matched = [s for s in sections if query.lower() in s.lower()]
|
||||
if not matched:
|
||||
continue
|
||||
content = "### ".join(matched)
|
||||
label = d.strftime("%B %-d, %Y")
|
||||
if d == today:
|
||||
label = f"Today — {label}"
|
||||
entry = f"## {label}\n\n{content}"
|
||||
if total_chars + len(entry) > char_budget:
|
||||
remaining = char_budget - total_chars
|
||||
if remaining > 200:
|
||||
# Fit a partial entry within budget
|
||||
trimmed = content[: remaining - 100] + "\n\n…(truncated)"
|
||||
results.append(f"## {label}\n\n{trimmed}")
|
||||
else:
|
||||
results.append(f"## {label}\n\n(truncated — hit size limit)")
|
||||
break
|
||||
results.append(entry)
|
||||
total_chars += len(entry)
|
||||
|
||||
if not results:
|
||||
if query:
|
||||
return f"No diary entries matching '{query}' in the last {days_back} days."
|
||||
return f"No diary entries found in the last {days_back} days."
|
||||
|
||||
return "\n\n---\n\n".join(results)
|
||||
|
||||
|
||||
def register_queen_memory_tools(registry: ToolRegistry) -> None:
|
||||
"""Register the episodic memory tool into the queen's tool registry."""
|
||||
"""Register the episodic memory tools into the queen's tool registry."""
|
||||
registry.register_function(write_to_diary)
|
||||
registry.register_function(recall_diary)
|
||||
|
||||
@@ -64,10 +64,14 @@ export const sessionsApi = {
|
||||
`/sessions/${sessionId}/entry-points`,
|
||||
),
|
||||
|
||||
updateTriggerTask: (sessionId: string, triggerId: string, task: string) =>
|
||||
api.patch<{ trigger_id: string; task: string }>(
|
||||
updateTrigger: (
|
||||
sessionId: string,
|
||||
triggerId: string,
|
||||
patch: { task?: string; trigger_config?: Record<string, unknown> },
|
||||
) =>
|
||||
api.patch<{ trigger_id: string; task: string; trigger_config: Record<string, unknown> }>(
|
||||
`/sessions/${sessionId}/triggers/${triggerId}`,
|
||||
{ task },
|
||||
patch,
|
||||
),
|
||||
|
||||
graphs: (sessionId: string) =>
|
||||
|
||||
@@ -324,6 +324,7 @@ export type EventTypeName =
|
||||
| "node_retry"
|
||||
| "edge_traversed"
|
||||
| "context_compacted"
|
||||
| "context_usage_updated"
|
||||
| "webhook_received"
|
||||
| "custom"
|
||||
| "escalation_requested"
|
||||
@@ -337,7 +338,8 @@ export type EventTypeName =
|
||||
| "trigger_activated"
|
||||
| "trigger_deactivated"
|
||||
| "trigger_fired"
|
||||
| "trigger_removed";
|
||||
| "trigger_removed"
|
||||
| "trigger_updated";
|
||||
|
||||
export interface AgentEvent {
|
||||
type: EventTypeName;
|
||||
|
||||
@@ -1,770 +0,0 @@
|
||||
import { memo, useMemo, useState, useRef, useEffect, useCallback } from "react";
|
||||
import { Play, Pause, Loader2, CheckCircle2 } from "lucide-react";
|
||||
|
||||
export type NodeStatus = "running" | "complete" | "pending" | "error" | "looping";
|
||||
|
||||
export type NodeType = "execution" | "trigger";
|
||||
|
||||
export interface GraphNode {
|
||||
id: string;
|
||||
label: string;
|
||||
status: NodeStatus;
|
||||
nodeType?: NodeType;
|
||||
triggerType?: string;
|
||||
triggerConfig?: Record<string, unknown>;
|
||||
next?: string[];
|
||||
backEdges?: string[];
|
||||
iterations?: number;
|
||||
maxIterations?: number;
|
||||
statusLabel?: string;
|
||||
edgeLabels?: Record<string, string>;
|
||||
}
|
||||
|
||||
export type RunState = "idle" | "deploying" | "running";
|
||||
|
||||
interface AgentGraphProps {
|
||||
nodes: GraphNode[];
|
||||
title: string;
|
||||
onNodeClick?: (node: GraphNode) => void;
|
||||
onRun?: () => void;
|
||||
onPause?: () => void;
|
||||
version?: string;
|
||||
runState?: RunState;
|
||||
building?: boolean;
|
||||
queenPhase?: "planning" | "building" | "staging" | "running";
|
||||
}
|
||||
|
||||
// --- Extracted RunButton so hover state survives parent re-renders ---
|
||||
export interface RunButtonProps {
|
||||
runState: RunState;
|
||||
disabled: boolean;
|
||||
onRun: () => void;
|
||||
onPause: () => void;
|
||||
btnRef: React.Ref<HTMLButtonElement>;
|
||||
}
|
||||
|
||||
export const RunButton = memo(function RunButton({ runState, disabled, onRun, onPause, btnRef }: RunButtonProps) {
|
||||
const [hovered, setHovered] = useState(false);
|
||||
const showPause = runState === "running" && hovered;
|
||||
|
||||
return (
|
||||
<button
|
||||
ref={btnRef}
|
||||
onClick={runState === "running" ? onPause : onRun}
|
||||
disabled={runState === "deploying" || disabled}
|
||||
onMouseEnter={() => setHovered(true)}
|
||||
onMouseLeave={() => setHovered(false)}
|
||||
className={`flex items-center gap-1.5 px-2.5 py-1 rounded-md text-[11px] font-semibold transition-all duration-200 ${
|
||||
showPause
|
||||
? "bg-amber-500/15 text-amber-400 border border-amber-500/40 hover:bg-amber-500/25 active:scale-95 cursor-pointer"
|
||||
: runState === "running"
|
||||
? "bg-green-500/15 text-green-400 border border-green-500/30 cursor-pointer"
|
||||
: runState === "deploying"
|
||||
? "bg-primary/10 text-primary border border-primary/20 cursor-default"
|
||||
: disabled
|
||||
? "bg-muted/30 text-muted-foreground/40 border border-border/20 cursor-not-allowed"
|
||||
: "bg-primary/10 text-primary border border-primary/20 hover:bg-primary/20 hover:border-primary/40 active:scale-95"
|
||||
}`}
|
||||
>
|
||||
{runState === "deploying" ? (
|
||||
<Loader2 className="w-3 h-3 animate-spin" />
|
||||
) : showPause ? (
|
||||
<Pause className="w-3 h-3 fill-current" />
|
||||
) : runState === "running" ? (
|
||||
<CheckCircle2 className="w-3 h-3" />
|
||||
) : (
|
||||
<Play className="w-3 h-3 fill-current" />
|
||||
)}
|
||||
{runState === "deploying" ? "Deploying\u2026" : showPause ? "Pause" : runState === "running" ? "Running" : "Run"}
|
||||
</button>
|
||||
);
|
||||
});
|
||||
|
||||
const NODE_W_MAX = 180;
|
||||
const NODE_H = 44;
|
||||
const GAP_Y = 48;
|
||||
const TOP_Y = 30;
|
||||
const MARGIN_LEFT = 20;
|
||||
const MARGIN_RIGHT = 50; // space for back-edge arcs
|
||||
const SVG_BASE_W = 320;
|
||||
const GAP_X = 12;
|
||||
|
||||
// Read a CSS custom property value (space-separated HSL components)
|
||||
function cssVar(name: string): string {
|
||||
return getComputedStyle(document.documentElement).getPropertyValue(name).trim();
|
||||
}
|
||||
|
||||
type StatusColorSet = Record<NodeStatus, { dot: string; bg: string; border: string; glow: string }>;
|
||||
type TriggerColorSet = { bg: string; border: string; text: string; icon: string };
|
||||
|
||||
function buildStatusColors(): StatusColorSet {
|
||||
const running = cssVar("--node-running") || "45 95% 58%";
|
||||
const looping = cssVar("--node-looping") || "38 90% 55%";
|
||||
const complete = cssVar("--node-complete") || "43 70% 45%";
|
||||
const pending = cssVar("--node-pending") || "35 15% 28%";
|
||||
const pendingBg = cssVar("--node-pending-bg") || "35 10% 12%";
|
||||
const pendingBorder = cssVar("--node-pending-border") || "35 10% 20%";
|
||||
const error = cssVar("--node-error") || "0 65% 55%";
|
||||
|
||||
return {
|
||||
running: {
|
||||
dot: `hsl(${running})`,
|
||||
bg: `hsl(${running} / 0.08)`,
|
||||
border: `hsl(${running} / 0.5)`,
|
||||
glow: `hsl(${running} / 0.15)`,
|
||||
},
|
||||
looping: {
|
||||
dot: `hsl(${looping})`,
|
||||
bg: `hsl(${looping} / 0.08)`,
|
||||
border: `hsl(${looping} / 0.5)`,
|
||||
glow: `hsl(${looping} / 0.15)`,
|
||||
},
|
||||
complete: {
|
||||
dot: `hsl(${complete})`,
|
||||
bg: `hsl(${complete} / 0.05)`,
|
||||
border: `hsl(${complete} / 0.25)`,
|
||||
glow: "none",
|
||||
},
|
||||
pending: {
|
||||
dot: `hsl(${pending})`,
|
||||
bg: `hsl(${pendingBg})`,
|
||||
border: `hsl(${pendingBorder})`,
|
||||
glow: "none",
|
||||
},
|
||||
error: {
|
||||
dot: `hsl(${error})`,
|
||||
bg: `hsl(${error} / 0.06)`,
|
||||
border: `hsl(${error} / 0.3)`,
|
||||
glow: `hsl(${error} / 0.1)`,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function buildTriggerColors(): TriggerColorSet {
|
||||
const bg = cssVar("--trigger-bg") || "210 25% 14%";
|
||||
const border = cssVar("--trigger-border") || "210 30% 30%";
|
||||
const text = cssVar("--trigger-text") || "210 30% 65%";
|
||||
const icon = cssVar("--trigger-icon") || "210 40% 55%";
|
||||
return {
|
||||
bg: `hsl(${bg})`,
|
||||
border: `hsl(${border})`,
|
||||
text: `hsl(${text})`,
|
||||
icon: `hsl(${icon})`,
|
||||
};
|
||||
}
|
||||
|
||||
/** Hook that reads node/trigger colors from CSS vars and updates on theme changes. */
|
||||
function useThemeColors() {
|
||||
const [statusColors, setStatusColors] = useState<StatusColorSet>(buildStatusColors);
|
||||
const [triggerColors, setTriggerColors] = useState<TriggerColorSet>(buildTriggerColors);
|
||||
|
||||
useEffect(() => {
|
||||
const rebuild = () => {
|
||||
setStatusColors(buildStatusColors());
|
||||
setTriggerColors(buildTriggerColors());
|
||||
};
|
||||
const obs = new MutationObserver(rebuild);
|
||||
obs.observe(document.documentElement, { attributes: true, attributeFilter: ["class", "style"] });
|
||||
return () => obs.disconnect();
|
||||
}, []);
|
||||
|
||||
return { statusColors, triggerColors };
|
||||
}
|
||||
|
||||
// Active trigger — brighter, more saturated blue
|
||||
const activeTriggerColors = {
|
||||
bg: "hsl(210,30%,18%)",
|
||||
border: "hsl(210,50%,50%)",
|
||||
text: "hsl(210,40%,75%)",
|
||||
icon: "hsl(210,60%,65%)",
|
||||
};
|
||||
|
||||
const triggerIcons: Record<string, string> = {
|
||||
webhook: "\u26A1", // lightning bolt
|
||||
timer: "\u23F1", // stopwatch
|
||||
api: "\u2192", // right arrow
|
||||
event: "\u223F", // sine wave
|
||||
};
|
||||
|
||||
/** Truncate label to fit within `availablePx` at the given fontSize. */
|
||||
function truncateLabel(label: string, availablePx: number, fontSize: number): string {
|
||||
const avgCharW = fontSize * 0.58;
|
||||
const maxChars = Math.floor(availablePx / avgCharW);
|
||||
if (label.length <= maxChars) return label;
|
||||
return label.slice(0, Math.max(maxChars - 1, 1)) + "\u2026";
|
||||
}
|
||||
|
||||
// ─── Pan & Zoom wrapper ───
|
||||
function PanZoomSvg({ svgW, svgH, className, children }: { svgW: number; svgH: number; className?: string; children: React.ReactNode }) {
|
||||
const [zoom, setZoom] = useState(1);
|
||||
const [pan, setPan] = useState({ x: 0, y: 0 });
|
||||
const [dragging, setDragging] = useState(false);
|
||||
const dragStart = useRef({ x: 0, y: 0, panX: 0, panY: 0 });
|
||||
|
||||
const MIN_ZOOM = 0.4;
|
||||
const MAX_ZOOM = 3;
|
||||
|
||||
const handleWheel = useCallback((e: React.WheelEvent) => {
|
||||
e.preventDefault();
|
||||
const delta = e.deltaY > 0 ? 0.9 : 1.1;
|
||||
setZoom(z => Math.min(MAX_ZOOM, Math.max(MIN_ZOOM, z * delta)));
|
||||
}, []);
|
||||
|
||||
const handleMouseDown = useCallback((e: React.MouseEvent) => {
|
||||
if (e.button !== 0) return;
|
||||
setDragging(true);
|
||||
dragStart.current = { x: e.clientX, y: e.clientY, panX: pan.x, panY: pan.y };
|
||||
}, [pan]);
|
||||
|
||||
const handleMouseMove = useCallback((e: React.MouseEvent) => {
|
||||
if (!dragging) return;
|
||||
setPan({
|
||||
x: dragStart.current.panX + (e.clientX - dragStart.current.x),
|
||||
y: dragStart.current.panY + (e.clientY - dragStart.current.y),
|
||||
});
|
||||
}, [dragging]);
|
||||
|
||||
const handleMouseUp = useCallback(() => setDragging(false), []);
|
||||
|
||||
const resetView = useCallback(() => {
|
||||
setZoom(1);
|
||||
setPan({ x: 0, y: 0 });
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div className="flex-1 relative overflow-hidden px-1 pb-5">
|
||||
<div
|
||||
onWheel={handleWheel}
|
||||
onMouseDown={handleMouseDown}
|
||||
onMouseMove={handleMouseMove}
|
||||
onMouseUp={handleMouseUp}
|
||||
onMouseLeave={handleMouseUp}
|
||||
className="w-full h-full"
|
||||
style={{ cursor: dragging ? "grabbing" : "grab" }}
|
||||
>
|
||||
<svg
|
||||
width="100%"
|
||||
viewBox={`0 0 ${svgW} ${svgH}`}
|
||||
preserveAspectRatio="xMidYMin meet"
|
||||
className={`select-none ${className || ""}`}
|
||||
style={{
|
||||
fontFamily: "'Inter', system-ui, sans-serif",
|
||||
transform: `translate(${pan.x}px, ${pan.y}px) scale(${zoom})`,
|
||||
transformOrigin: "center top",
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</svg>
|
||||
</div>
|
||||
|
||||
{/* Zoom controls */}
|
||||
<div className="absolute bottom-7 right-3 flex items-center gap-1 bg-card/80 backdrop-blur-sm border border-border/40 rounded-lg p-0.5 shadow-sm">
|
||||
<button
|
||||
onClick={() => setZoom(z => Math.min(MAX_ZOOM, z * 1.2))}
|
||||
className="w-6 h-6 flex items-center justify-center rounded text-muted-foreground hover:text-foreground hover:bg-muted/60 transition-colors text-xs font-bold"
|
||||
aria-label="Zoom in"
|
||||
>+</button>
|
||||
<button
|
||||
onClick={resetView}
|
||||
className="px-1.5 h-6 flex items-center justify-center rounded text-[10px] font-mono text-muted-foreground hover:text-foreground hover:bg-muted/60 transition-colors"
|
||||
aria-label="Reset zoom"
|
||||
>{Math.round(zoom * 100)}%</button>
|
||||
<button
|
||||
onClick={() => setZoom(z => Math.max(MIN_ZOOM, z * 0.8))}
|
||||
className="w-6 h-6 flex items-center justify-center rounded text-muted-foreground hover:text-foreground hover:bg-muted/60 transition-colors text-xs font-bold"
|
||||
aria-label="Zoom out"
|
||||
>{"\u2212"}</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default function AgentGraph({ nodes, title: _title, onNodeClick, onRun, onPause, version, runState: externalRunState, building, queenPhase }: AgentGraphProps) {
|
||||
const [localRunState, setLocalRunState] = useState<RunState>("idle");
|
||||
const runState = externalRunState ?? localRunState;
|
||||
const runBtnRef = useRef<HTMLButtonElement>(null);
|
||||
const { statusColors, triggerColors } = useThemeColors();
|
||||
|
||||
const handleRun = () => {
|
||||
if (runState !== "idle") return;
|
||||
if (onRun) {
|
||||
onRun();
|
||||
} else {
|
||||
setLocalRunState("deploying");
|
||||
setTimeout(() => setLocalRunState("running"), 1800);
|
||||
setTimeout(() => setLocalRunState("idle"), 5000);
|
||||
}
|
||||
};
|
||||
|
||||
const idxMap = useMemo(() => Object.fromEntries(nodes.map((n, i) => [n.id, i])), [nodes]);
|
||||
|
||||
const backEdges = useMemo(() => {
|
||||
const edges: { fromIdx: number; toIdx: number }[] = [];
|
||||
nodes.forEach((n, i) => {
|
||||
(n.next || []).forEach((toId) => {
|
||||
const toIdx = idxMap[toId];
|
||||
if (toIdx !== undefined && toIdx <= i) edges.push({ fromIdx: i, toIdx });
|
||||
});
|
||||
(n.backEdges || []).forEach((toId) => {
|
||||
const toIdx = idxMap[toId];
|
||||
if (toIdx !== undefined) edges.push({ fromIdx: i, toIdx });
|
||||
});
|
||||
});
|
||||
return edges;
|
||||
}, [nodes, idxMap]);
|
||||
|
||||
const forwardEdges = useMemo(() => {
|
||||
const edges: { fromIdx: number; toIdx: number; fanCount: number; fanIndex: number; label?: string }[] = [];
|
||||
nodes.forEach((n, i) => {
|
||||
const targets = (n.next || [])
|
||||
.map((toId) => ({ toId, toIdx: idxMap[toId] }))
|
||||
.filter((t): t is { toId: string; toIdx: number } => t.toIdx !== undefined && t.toIdx > i);
|
||||
targets.forEach(({ toId, toIdx }, fi) => {
|
||||
edges.push({
|
||||
fromIdx: i,
|
||||
toIdx,
|
||||
fanCount: targets.length,
|
||||
fanIndex: fi,
|
||||
label: n.edgeLabels?.[toId],
|
||||
});
|
||||
});
|
||||
});
|
||||
return edges;
|
||||
}, [nodes, idxMap]);
|
||||
|
||||
// --- Layer-based layout computation ---
|
||||
const layout = useMemo(() => {
|
||||
if (nodes.length === 0) {
|
||||
return { layers: [] as number[], cols: [] as number[], maxCols: 1, nodeW: NODE_W_MAX, colSpacing: 0, firstColX: MARGIN_LEFT };
|
||||
}
|
||||
|
||||
// 1. Build reverse adjacency from forward edges (who are the parents of each node)
|
||||
const parents = new Map<number, number[]>();
|
||||
nodes.forEach((_, i) => parents.set(i, []));
|
||||
forwardEdges.forEach((e) => {
|
||||
parents.get(e.toIdx)!.push(e.fromIdx);
|
||||
});
|
||||
|
||||
// 2. Assign layers via longest-path from entry
|
||||
const layers = new Array(nodes.length).fill(0);
|
||||
for (let i = 0; i < nodes.length; i++) {
|
||||
const pars = parents.get(i) || [];
|
||||
if (pars.length > 0) {
|
||||
layers[i] = Math.max(...pars.map((p) => layers[p])) + 1;
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Group nodes by layer
|
||||
const layerGroups = new Map<number, number[]>();
|
||||
layers.forEach((l, i) => {
|
||||
const group = layerGroups.get(l) || [];
|
||||
group.push(i);
|
||||
layerGroups.set(l, group);
|
||||
});
|
||||
|
||||
// 4. Compute max columns and dynamic node width
|
||||
let maxCols = 1;
|
||||
layerGroups.forEach((group) => {
|
||||
maxCols = Math.max(maxCols, group.length);
|
||||
});
|
||||
|
||||
const usableW = SVG_BASE_W - MARGIN_LEFT - MARGIN_RIGHT;
|
||||
const nodeW = Math.min(NODE_W_MAX, Math.floor((usableW - (maxCols - 1) * GAP_X) / maxCols));
|
||||
const colSpacing = nodeW + GAP_X;
|
||||
const totalNodesW = maxCols * nodeW + (maxCols - 1) * GAP_X;
|
||||
const firstColX = MARGIN_LEFT + (usableW - totalNodesW) / 2;
|
||||
|
||||
// 5. Assign columns within each layer (centered, ordered by parent column)
|
||||
const cols = new Array(nodes.length).fill(0);
|
||||
layerGroups.forEach((group) => {
|
||||
if (group.length === 1) {
|
||||
// Center single node: place at middle column
|
||||
cols[group[0]] = (maxCols - 1) / 2;
|
||||
} else {
|
||||
// Sort group by average parent column to reduce crossings
|
||||
const sorted = [...group].sort((a, b) => {
|
||||
const aParents = parents.get(a) || [];
|
||||
const bParents = parents.get(b) || [];
|
||||
const aAvg = aParents.length > 0 ? aParents.reduce((s, p) => s + cols[p], 0) / aParents.length : 0;
|
||||
const bAvg = bParents.length > 0 ? bParents.reduce((s, p) => s + cols[p], 0) / bParents.length : 0;
|
||||
return aAvg - bAvg;
|
||||
});
|
||||
// Spread evenly, centered within maxCols
|
||||
const offset = (maxCols - group.length) / 2;
|
||||
sorted.forEach((nodeIdx, i) => {
|
||||
cols[nodeIdx] = offset + i;
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
return { layers, cols, maxCols, nodeW, colSpacing, firstColX };
|
||||
}, [nodes, forwardEdges]);
|
||||
|
||||
if (nodes.length === 0) {
|
||||
return (
|
||||
<div className="flex flex-col h-full">
|
||||
<div className="px-5 pt-4 pb-2 flex items-center justify-between">
|
||||
<div className="flex items-center gap-2">
|
||||
<p className="text-[11px] text-muted-foreground font-medium uppercase tracking-wider">Pipeline</p>
|
||||
{version && (
|
||||
<span className="text-[10px] font-mono font-medium text-muted-foreground/60 border border-border/30 rounded px-1 py-0.5 leading-none">
|
||||
{version}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<RunButton runState={runState} disabled={nodes.length === 0 || queenPhase === "building" || queenPhase === "planning"} onRun={handleRun} onPause={onPause ?? (() => {})} btnRef={runBtnRef} />
|
||||
</div>
|
||||
<div className="flex-1 flex items-center justify-center px-5">
|
||||
{building ? (
|
||||
<div className="flex flex-col items-center gap-3">
|
||||
<Loader2 className="w-6 h-6 animate-spin text-primary/60" />
|
||||
<p className="text-xs text-muted-foreground/80 text-center">Building agent...</p>
|
||||
</div>
|
||||
) : (
|
||||
<p className="text-xs text-muted-foreground/60 text-center italic">No pipeline configured yet.<br/>Chat with the Queen to get started.</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const { layers, cols, nodeW, colSpacing, firstColX } = layout;
|
||||
|
||||
const nodePos = (i: number) => ({
|
||||
x: firstColX + cols[i] * colSpacing,
|
||||
y: TOP_Y + layers[i] * (NODE_H + GAP_Y),
|
||||
});
|
||||
|
||||
const maxLayer = nodes.length > 0 ? Math.max(...layers) : 0;
|
||||
const svgHeight = TOP_Y * 2 + (maxLayer + 1) * NODE_H + maxLayer * GAP_Y + 10;
|
||||
const backEdgeSpace = backEdges.length > 0 ? MARGIN_RIGHT + backEdges.length * 18 : 20;
|
||||
const svgWidth = Math.max(SVG_BASE_W, firstColX + layout.maxCols * nodeW + (layout.maxCols - 1) * GAP_X + backEdgeSpace);
|
||||
|
||||
// Check if a skip-level forward edge would collide with intermediate nodes
|
||||
const hasCollision = (fromLayer: number, toLayer: number, fromX: number, toX: number): boolean => {
|
||||
const minX = Math.min(fromX, toX);
|
||||
const maxX = Math.max(fromX, toX) + nodeW;
|
||||
for (let i = 0; i < nodes.length; i++) {
|
||||
const l = layers[i];
|
||||
if (l > fromLayer && l < toLayer) {
|
||||
const nx = firstColX + cols[i] * colSpacing;
|
||||
// Check horizontal overlap
|
||||
if (nx < maxX && nx + nodeW > minX) return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
const renderForwardEdge = (edge: { fromIdx: number; toIdx: number; fanCount: number; fanIndex: number; label?: string }, i: number) => {
|
||||
const from = nodePos(edge.fromIdx);
|
||||
const to = nodePos(edge.toIdx);
|
||||
const fromCenterX = from.x + nodeW / 2;
|
||||
const toCenterX = to.x + nodeW / 2;
|
||||
const y1 = from.y + NODE_H;
|
||||
const y2 = to.y;
|
||||
|
||||
// Fan-out: spread exit points across the source node's bottom
|
||||
let startX = fromCenterX;
|
||||
if (edge.fanCount > 1) {
|
||||
const spread = nodeW * 0.5;
|
||||
const step = edge.fanCount > 1 ? spread / (edge.fanCount - 1) : 0;
|
||||
startX = fromCenterX - spread / 2 + edge.fanIndex * step;
|
||||
}
|
||||
|
||||
const midY = (y1 + y2) / 2;
|
||||
const fromLayer = layers[edge.fromIdx];
|
||||
const toLayer = layers[edge.toIdx];
|
||||
const skipsLayers = toLayer - fromLayer > 1;
|
||||
|
||||
let d: string;
|
||||
if (skipsLayers && hasCollision(fromLayer, toLayer, from.x, to.x)) {
|
||||
// Route around intermediate nodes: orthogonal detour to the left
|
||||
const detourX = Math.min(from.x, to.x) - nodeW * 0.4;
|
||||
d = `M ${startX} ${y1} L ${startX} ${midY} L ${detourX} ${midY} L ${detourX} ${y2 - 10} L ${toCenterX} ${y2 - 10} L ${toCenterX} ${y2}`;
|
||||
} else if (Math.abs(startX - toCenterX) < 2) {
|
||||
// Straight vertical line when aligned
|
||||
d = `M ${startX} ${y1} L ${toCenterX} ${y2}`;
|
||||
} else {
|
||||
// Orthogonal: down, across, down
|
||||
d = `M ${startX} ${y1} L ${startX} ${midY} L ${toCenterX} ${midY} L ${toCenterX} ${y2}`;
|
||||
}
|
||||
|
||||
const fromNode = nodes[edge.fromIdx];
|
||||
const isActive = fromNode.status === "complete" || fromNode.status === "running" || fromNode.status === "looping";
|
||||
const strokeColor = isActive ? statusColors.complete.border : statusColors.pending.border;
|
||||
const arrowColor = isActive ? statusColors.complete.dot : statusColors.pending.border;
|
||||
|
||||
return (
|
||||
<g key={`fwd-${i}`}>
|
||||
<path d={d} fill="none" stroke={strokeColor} strokeWidth={1.5} />
|
||||
<polygon
|
||||
points={`${toCenterX - 4},${y2 - 6} ${toCenterX + 4},${y2 - 6} ${toCenterX},${y2 - 1}`}
|
||||
fill={arrowColor}
|
||||
/>
|
||||
{edge.label && (
|
||||
<text
|
||||
x={(startX + toCenterX) / 2 + 8}
|
||||
y={midY - 2}
|
||||
fill={statusColors.pending.dot}
|
||||
fontSize={9}
|
||||
fontStyle="italic"
|
||||
>
|
||||
{edge.label}
|
||||
</text>
|
||||
)}
|
||||
</g>
|
||||
);
|
||||
};
|
||||
|
||||
const renderBackEdge = (edge: { fromIdx: number; toIdx: number }, i: number) => {
|
||||
const from = nodePos(edge.fromIdx);
|
||||
const to = nodePos(edge.toIdx);
|
||||
|
||||
const rightX = Math.max(from.x, to.x) + nodeW;
|
||||
const rightOffset = 28 + i * 18;
|
||||
const startX = from.x + nodeW;
|
||||
const startY = from.y + NODE_H / 2;
|
||||
const endX = to.x + nodeW;
|
||||
const endY = to.y + NODE_H / 2;
|
||||
const curveX = rightX + rightOffset;
|
||||
const r = 12;
|
||||
|
||||
const fromNode = nodes[edge.fromIdx];
|
||||
const isActive = fromNode.status === "complete" || fromNode.status === "running" || fromNode.status === "looping";
|
||||
const color = isActive ? statusColors.looping.border : statusColors.pending.border;
|
||||
|
||||
// Bezier curve with rounded corners (kept as curves for back edges)
|
||||
const path = `M ${startX} ${startY} C ${startX + r} ${startY}, ${curveX} ${startY}, ${curveX} ${startY - r} L ${curveX} ${endY + r} C ${curveX} ${endY}, ${endX + r} ${endY}, ${endX + 6} ${endY}`;
|
||||
|
||||
return (
|
||||
<g key={`back-${i}`}>
|
||||
<path d={path} fill="none" stroke={color} strokeWidth={1.5} strokeDasharray="4 3" />
|
||||
<polygon
|
||||
points={`${endX + 6},${endY - 3} ${endX + 6},${endY + 3} ${endX},${endY}`}
|
||||
fill={isActive ? statusColors.looping.dot : statusColors.pending.border}
|
||||
/>
|
||||
</g>
|
||||
);
|
||||
};
|
||||
|
||||
const renderTriggerNode = (node: GraphNode, i: number) => {
|
||||
const pos = nodePos(i);
|
||||
const icon = triggerIcons[node.triggerType || ""] || "\u26A1";
|
||||
const triggerFontSize = nodeW < 140 ? 10.5 : 11.5;
|
||||
const triggerAvailW = nodeW - 38;
|
||||
const triggerDisplayLabel = truncateLabel(node.label, triggerAvailW, triggerFontSize);
|
||||
const nextFireIn = node.triggerConfig?.next_fire_in as number | undefined;
|
||||
const isActive = node.status === "running" || node.status === "complete";
|
||||
const colors = isActive ? activeTriggerColors : triggerColors;
|
||||
|
||||
// Format countdown for display below node
|
||||
let countdownLabel: string | null = null;
|
||||
if (isActive && nextFireIn != null && nextFireIn > 0) {
|
||||
const h = Math.floor(nextFireIn / 3600);
|
||||
const m = Math.floor((nextFireIn % 3600) / 60);
|
||||
const s = Math.floor(nextFireIn % 60);
|
||||
countdownLabel = h > 0
|
||||
? `next in ${h}h ${String(m).padStart(2, "0")}m`
|
||||
: `next in ${m}m ${String(s).padStart(2, "0")}s`;
|
||||
}
|
||||
|
||||
// Status label below countdown
|
||||
const statusLabel = isActive ? "active" : "inactive";
|
||||
const statusColor = isActive ? "hsl(140,40%,50%)" : "hsl(210,20%,40%)";
|
||||
|
||||
return (
|
||||
<g key={node.id} onClick={() => onNodeClick?.(node)} style={{ cursor: onNodeClick ? "pointer" : "default" }}>
|
||||
<title>{node.label}</title>
|
||||
{/* Pill-shaped background — solid border when active, dashed when inactive */}
|
||||
<rect
|
||||
x={pos.x} y={pos.y}
|
||||
width={nodeW} height={NODE_H}
|
||||
rx={NODE_H / 2}
|
||||
fill={colors.bg}
|
||||
stroke={colors.border}
|
||||
strokeWidth={isActive ? 1.5 : 1}
|
||||
strokeDasharray={isActive ? undefined : "4 2"}
|
||||
/>
|
||||
|
||||
{/* Trigger type icon */}
|
||||
<text
|
||||
x={pos.x + 18} y={pos.y + NODE_H / 2}
|
||||
fill={colors.icon} fontSize={13}
|
||||
textAnchor="middle" dominantBaseline="middle"
|
||||
>
|
||||
{icon}
|
||||
</text>
|
||||
|
||||
{/* Label */}
|
||||
<text
|
||||
x={pos.x + 32} y={pos.y + NODE_H / 2}
|
||||
fill={colors.text}
|
||||
fontSize={triggerFontSize}
|
||||
fontWeight={500}
|
||||
dominantBaseline="middle"
|
||||
letterSpacing="0.01em"
|
||||
>
|
||||
{triggerDisplayLabel}
|
||||
</text>
|
||||
|
||||
{/* Countdown label below node */}
|
||||
{countdownLabel && (
|
||||
<text
|
||||
x={pos.x + nodeW / 2} y={pos.y + NODE_H + 13}
|
||||
fill={triggerColors.text} fontSize={9.5}
|
||||
textAnchor="middle" fontStyle="italic" opacity={0.7}
|
||||
>
|
||||
{countdownLabel}
|
||||
</text>
|
||||
)}
|
||||
|
||||
{/* Status label */}
|
||||
<text
|
||||
x={pos.x + nodeW / 2} y={pos.y + NODE_H + (countdownLabel ? 25 : 13)}
|
||||
fill={statusColor} fontSize={9}
|
||||
textAnchor="middle" opacity={0.8}
|
||||
>
|
||||
{statusLabel}
|
||||
</text>
|
||||
</g>
|
||||
);
|
||||
};
|
||||
|
||||
const renderNode = (node: GraphNode, i: number) => {
|
||||
if (node.nodeType === "trigger") return renderTriggerNode(node, i);
|
||||
|
||||
const pos = nodePos(i);
|
||||
const isActive = node.status === "running" || node.status === "looping";
|
||||
const isDone = node.status === "complete";
|
||||
const colors = statusColors[node.status];
|
||||
|
||||
const fontSize = nodeW < 140 ? 10.5 : 12.5;
|
||||
const labelAvailW = nodeW - 38;
|
||||
const displayLabel = truncateLabel(node.label, labelAvailW, fontSize);
|
||||
|
||||
return (
|
||||
<g key={node.id} onClick={() => onNodeClick?.(node)} style={{ cursor: onNodeClick ? "pointer" : "default" }}>
|
||||
<title>{node.label}</title>
|
||||
{/* Ambient glow for active nodes */}
|
||||
{isActive && (
|
||||
<>
|
||||
<rect
|
||||
x={pos.x - 4} y={pos.y - 4}
|
||||
width={nodeW + 8} height={NODE_H + 8}
|
||||
rx={16} fill={colors.glow}
|
||||
/>
|
||||
<rect
|
||||
x={pos.x - 2} y={pos.y - 2}
|
||||
width={nodeW + 4} height={NODE_H + 4}
|
||||
rx={14} fill="none" stroke={colors.dot} strokeWidth={1} opacity={0.25}
|
||||
style={{ animation: "pulse-ring 2.5s ease-out infinite" }}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Node background */}
|
||||
<rect
|
||||
x={pos.x} y={pos.y}
|
||||
width={nodeW} height={NODE_H}
|
||||
rx={12}
|
||||
fill={colors.bg}
|
||||
stroke={colors.border}
|
||||
strokeWidth={isActive ? 1.5 : 1}
|
||||
/>
|
||||
|
||||
{/* Status dot */}
|
||||
<circle cx={pos.x + 18} cy={pos.y + NODE_H / 2} r={4.5} fill={colors.dot} />
|
||||
{isActive && (
|
||||
<circle cx={pos.x + 18} cy={pos.y + NODE_H / 2} r={7} fill="none" stroke={colors.dot} strokeWidth={1} opacity={0.3}>
|
||||
<animate attributeName="r" values="7;11;7" dur="2s" repeatCount="indefinite" />
|
||||
<animate attributeName="opacity" values="0.3;0;0.3" dur="2s" repeatCount="indefinite" />
|
||||
</circle>
|
||||
)}
|
||||
|
||||
{/* Check mark for complete */}
|
||||
{isDone && (
|
||||
<text
|
||||
x={pos.x + 18} y={pos.y + NODE_H / 2 + 1}
|
||||
fill={colors.dot} fontSize={8} fontWeight={700}
|
||||
textAnchor="middle" dominantBaseline="middle"
|
||||
>
|
||||
✓
|
||||
</text>
|
||||
)}
|
||||
|
||||
{/* Label -- truncated with ellipsis for narrow nodes */}
|
||||
<text
|
||||
x={pos.x + 32} y={pos.y + NODE_H / 2}
|
||||
fill={isActive ? statusColors.running.dot : isDone ? statusColors.complete.dot : statusColors.pending.dot}
|
||||
fontSize={fontSize}
|
||||
fontWeight={isActive ? 600 : isDone ? 500 : 400}
|
||||
dominantBaseline="middle"
|
||||
letterSpacing="0.01em"
|
||||
>
|
||||
{displayLabel}
|
||||
</text>
|
||||
|
||||
{/* Status label for active nodes */}
|
||||
{node.statusLabel && isActive && (
|
||||
<text
|
||||
x={pos.x + nodeW + 10} y={pos.y + NODE_H / 2}
|
||||
fill={statusColors.running.dot} fontSize={10.5} fontStyle="italic"
|
||||
dominantBaseline="middle" opacity={0.8}
|
||||
>
|
||||
{node.statusLabel}
|
||||
</text>
|
||||
)}
|
||||
|
||||
{/* Iteration badge */}
|
||||
{node.iterations !== undefined && node.iterations > 0 && (
|
||||
<g>
|
||||
<rect
|
||||
x={pos.x + nodeW - 36} y={pos.y + NODE_H / 2 - 8}
|
||||
width={26} height={16} rx={8}
|
||||
fill={colors.dot} opacity={0.15}
|
||||
/>
|
||||
<text
|
||||
x={pos.x + nodeW - 23} y={pos.y + NODE_H / 2}
|
||||
fill={colors.dot} fontSize={9} fontWeight={600}
|
||||
textAnchor="middle" dominantBaseline="middle" opacity={0.8}
|
||||
>
|
||||
{node.iterations}{node.maxIterations ? `/${node.maxIterations}` : "\u00d7"}
|
||||
</text>
|
||||
</g>
|
||||
)}
|
||||
</g>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full">
|
||||
{/* Compact sub-label */}
|
||||
<div className="px-5 pt-4 pb-2 flex items-center justify-between">
|
||||
<div className="flex items-center gap-2">
|
||||
<p className="text-[11px] text-muted-foreground font-medium uppercase tracking-wider">Pipeline</p>
|
||||
{version && (
|
||||
<span className="text-[10px] font-mono font-medium text-muted-foreground/60 border border-border/30 rounded px-1 py-0.5 leading-none">
|
||||
{version}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<RunButton runState={runState} disabled={nodes.length === 0} onRun={handleRun} onPause={onPause ?? (() => {})} btnRef={runBtnRef} />
|
||||
</div>
|
||||
|
||||
{/* Graph */}
|
||||
<PanZoomSvg svgW={svgWidth} svgH={svgHeight} className={building ? "opacity-30" : ""}>
|
||||
{forwardEdges.map((e, i) => renderForwardEdge(e, i))}
|
||||
{backEdges.map((e, i) => renderBackEdge(e, i))}
|
||||
{nodes.map((n, i) => renderNode(n, i))}
|
||||
</PanZoomSvg>
|
||||
{building && (
|
||||
<div className="absolute inset-0 flex items-center justify-center">
|
||||
<div className="flex flex-col items-center gap-3">
|
||||
<Loader2 className="w-6 h-6 animate-spin text-primary/60" />
|
||||
<p className="text-xs text-muted-foreground/80">Rebuilding agent...</p>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,8 +1,16 @@
|
||||
import { memo, useState, useRef, useEffect } from "react";
|
||||
import { memo, useState, useRef, useEffect, useMemo } from "react";
|
||||
import { Send, Square, Crown, Cpu, Check, Loader2 } from "lucide-react";
|
||||
|
||||
export interface ContextUsageEntry {
|
||||
usagePct: number;
|
||||
messageCount: number;
|
||||
estimatedTokens: number;
|
||||
maxTokens: number;
|
||||
}
|
||||
import MarkdownContent from "@/components/MarkdownContent";
|
||||
import QuestionWidget from "@/components/QuestionWidget";
|
||||
import MultiQuestionWidget from "@/components/MultiQuestionWidget";
|
||||
import ParallelSubagentBubble, { type SubagentGroup } from "@/components/ParallelSubagentBubble";
|
||||
|
||||
export interface ChatMessage {
|
||||
id: string;
|
||||
@@ -18,6 +26,10 @@ export interface ChatMessage {
|
||||
createdAt?: number;
|
||||
/** Queen phase active when this message was created */
|
||||
phase?: "planning" | "building" | "staging" | "running";
|
||||
/** Backend node_id that produced this message — used for subagent grouping */
|
||||
nodeId?: string;
|
||||
/** Backend execution_id for this message */
|
||||
executionId?: string;
|
||||
}
|
||||
|
||||
interface ChatPanelProps {
|
||||
@@ -47,6 +59,8 @@ interface ChatPanelProps {
|
||||
onQuestionDismiss?: () => void;
|
||||
/** Queen operating phase — shown as a tag on queen messages */
|
||||
queenPhase?: "planning" | "building" | "staging" | "running";
|
||||
/** Context window usage for queen and workers */
|
||||
contextUsage?: Record<string, ContextUsageEntry>;
|
||||
}
|
||||
|
||||
const queenColor = "hsl(45,95%,58%)";
|
||||
@@ -241,7 +255,7 @@ const MessageBubble = memo(function MessageBubble({ msg, queenPhase }: { msg: Ch
|
||||
);
|
||||
}, (prev, next) => prev.msg.id === next.msg.id && prev.msg.content === next.msg.content && prev.msg.phase === next.msg.phase && prev.queenPhase === next.queenPhase);
|
||||
|
||||
export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting, isBusy, activeThread, disabled, onCancel, pendingQuestion, pendingOptions, pendingQuestions, onQuestionSubmit, onMultiQuestionSubmit, onQuestionDismiss, queenPhase }: ChatPanelProps) {
|
||||
export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting, isBusy, activeThread, disabled, onCancel, pendingQuestion, pendingOptions, pendingQuestions, onQuestionSubmit, onMultiQuestionSubmit, onQuestionDismiss, queenPhase, contextUsage }: ChatPanelProps) {
|
||||
const [input, setInput] = useState("");
|
||||
const [readMap, setReadMap] = useState<Record<string, number>>({});
|
||||
const bottomRef = useRef<HTMLDivElement>(null);
|
||||
@@ -251,9 +265,93 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
|
||||
|
||||
const threadMessages = messages.filter((m) => {
|
||||
if (m.type === "system" && !m.thread) return false;
|
||||
return m.thread === activeThread;
|
||||
if (m.thread !== activeThread) return false;
|
||||
// Hide queen messages whose content is whitespace-only — these are
|
||||
// tool-use-only turns that have no visible text. During live operation
|
||||
// tool pills provide context, but on resume the pills are gone so
|
||||
// the empty bubble is meaningless.
|
||||
if (m.role === "queen" && !m.type && (!m.content || !m.content.trim())) return false;
|
||||
return true;
|
||||
});
|
||||
|
||||
// Group subagent messages into parallel bubbles.
|
||||
// A subagent message has nodeId containing ":subagent:".
|
||||
// The run only ends on hard boundaries (user messages, run_dividers)
|
||||
// so interleaved queen/tool/system messages don't fragment the bubble.
|
||||
type RenderItem =
|
||||
| { kind: "message"; msg: ChatMessage }
|
||||
| { kind: "parallel"; groupId: string; groups: SubagentGroup[] };
|
||||
|
||||
const renderItems = useMemo<RenderItem[]>(() => {
|
||||
const items: RenderItem[] = [];
|
||||
let i = 0;
|
||||
while (i < threadMessages.length) {
|
||||
const msg = threadMessages[i];
|
||||
const isSubagent = msg.nodeId?.includes(":subagent:");
|
||||
if (!isSubagent) {
|
||||
items.push({ kind: "message", msg });
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Start a subagent run. Collect all subagent messages, allowing
|
||||
// non-subagent messages in between (they render as normal items
|
||||
// before the bubble). Only break on hard boundaries.
|
||||
const subagentMsgs: ChatMessage[] = [];
|
||||
const interleaved: { idx: number; msg: ChatMessage }[] = [];
|
||||
const firstId = msg.id;
|
||||
|
||||
while (i < threadMessages.length) {
|
||||
const m = threadMessages[i];
|
||||
const isSa = m.nodeId?.includes(":subagent:");
|
||||
|
||||
if (isSa) {
|
||||
subagentMsgs.push(m);
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Hard boundary — stop the run
|
||||
if (m.type === "user" || m.type === "run_divider") break;
|
||||
|
||||
// Worker message from a non-subagent node means the graph has
|
||||
// moved on to the next stage. Close the bubble even if some
|
||||
// subagents are still streaming in the background.
|
||||
if (m.role === "worker" && m.nodeId && !m.nodeId.includes(":subagent:")) break;
|
||||
|
||||
// Soft interruption (queen output, system, tool_status without
|
||||
// nodeId) — render it normally but keep the subagent run going
|
||||
interleaved.push({ idx: items.length + interleaved.length, msg: m });
|
||||
i++;
|
||||
}
|
||||
|
||||
// Emit interleaved messages first (before the bubble)
|
||||
for (const { msg: im } of interleaved) {
|
||||
items.push({ kind: "message", msg: im });
|
||||
}
|
||||
|
||||
// Build the single parallel bubble from all collected subagent msgs
|
||||
if (subagentMsgs.length > 0) {
|
||||
const byNode = new Map<string, ChatMessage[]>();
|
||||
for (const m of subagentMsgs) {
|
||||
const nid = m.nodeId!;
|
||||
if (!byNode.has(nid)) byNode.set(nid, []);
|
||||
byNode.get(nid)!.push(m);
|
||||
}
|
||||
const groups: SubagentGroup[] = [];
|
||||
for (const [nodeId, msgs] of byNode) {
|
||||
groups.push({
|
||||
nodeId,
|
||||
messages: msgs,
|
||||
contextUsage: contextUsage?.[nodeId],
|
||||
});
|
||||
}
|
||||
items.push({ kind: "parallel", groupId: `par-${firstId}`, groups });
|
||||
}
|
||||
}
|
||||
return items;
|
||||
}, [threadMessages, contextUsage]);
|
||||
|
||||
// Mark current thread as read
|
||||
useEffect(() => {
|
||||
const count = messages.filter((m) => m.thread === activeThread).length;
|
||||
@@ -299,11 +397,17 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
|
||||
|
||||
{/* Messages */}
|
||||
<div ref={scrollRef} onScroll={handleScroll} className="flex-1 overflow-auto px-5 py-4 space-y-3">
|
||||
{threadMessages.map((msg) => (
|
||||
<div key={msg.id}>
|
||||
<MessageBubble msg={msg} queenPhase={queenPhase} />
|
||||
</div>
|
||||
))}
|
||||
{renderItems.map((item) =>
|
||||
item.kind === "parallel" ? (
|
||||
<div key={item.groupId}>
|
||||
<ParallelSubagentBubble groupId={item.groupId} groups={item.groups} />
|
||||
</div>
|
||||
) : (
|
||||
<div key={item.msg.id}>
|
||||
<MessageBubble msg={item.msg} queenPhase={queenPhase} />
|
||||
</div>
|
||||
)
|
||||
)}
|
||||
|
||||
{/* Show typing indicator while waiting for first queen response (disabled + empty chat) */}
|
||||
{(isWaiting || (disabled && threadMessages.length === 0)) && (
|
||||
@@ -350,6 +454,57 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
|
||||
<div ref={bottomRef} />
|
||||
</div>
|
||||
|
||||
{/* Context window usage bar — sits between messages and input */}
|
||||
{(() => {
|
||||
if (!contextUsage) return null;
|
||||
const queenUsage = contextUsage["__queen__"];
|
||||
const workerEntries = Object.entries(contextUsage).filter(([k]) => k !== "__queen__");
|
||||
const workerUsage = workerEntries.length > 0
|
||||
? workerEntries.reduce((best, [, v]) => (v.usagePct > best.usagePct ? v : best), workerEntries[0][1])
|
||||
: undefined;
|
||||
if (!queenUsage && !workerUsage) return null;
|
||||
return (
|
||||
<div className="flex items-center gap-3 mx-4 px-3 py-1 rounded-lg bg-muted/30 border border-border/20 group/ctx flex-shrink-0">
|
||||
{queenUsage && (
|
||||
<div className="flex items-center gap-2 flex-1 min-w-0" title={`Queen: ${(queenUsage.estimatedTokens / 1000).toFixed(1)}k / ${(queenUsage.maxTokens / 1000).toFixed(0)}k tokens \u00b7 ${queenUsage.messageCount} messages`}>
|
||||
<Crown className="w-3 h-3 flex-shrink-0" style={{ color: "hsl(45,95%,58%)" }} />
|
||||
<div className="flex-1 h-1.5 rounded-full bg-muted/50 overflow-hidden min-w-[60px]">
|
||||
<div
|
||||
className="h-full rounded-full transition-all duration-500 ease-out"
|
||||
style={{
|
||||
width: `${Math.min(queenUsage.usagePct, 100)}%`,
|
||||
backgroundColor: queenUsage.usagePct >= 90 ? "hsl(0,65%,55%)" : queenUsage.usagePct >= 70 ? "hsl(35,90%,55%)" : "hsl(45,95%,58%)",
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<span className="text-[10px] text-muted-foreground/70 flex-shrink-0 tabular-nums">
|
||||
<span className="group-hover/ctx:hidden">{queenUsage.usagePct}%</span>
|
||||
<span className="hidden group-hover/ctx:inline">{(queenUsage.estimatedTokens / 1000).toFixed(1)}k / {(queenUsage.maxTokens / 1000).toFixed(0)}k</span>
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
{workerUsage && (
|
||||
<div className="flex items-center gap-2 flex-1 min-w-0" title={`Worker: ${(workerUsage.estimatedTokens / 1000).toFixed(1)}k / ${(workerUsage.maxTokens / 1000).toFixed(0)}k tokens \u00b7 ${workerUsage.messageCount} messages`}>
|
||||
<Cpu className="w-3 h-3 flex-shrink-0" style={{ color: "hsl(220,60%,55%)" }} />
|
||||
<div className="flex-1 h-1.5 rounded-full bg-muted/50 overflow-hidden min-w-[60px]">
|
||||
<div
|
||||
className="h-full rounded-full transition-all duration-500 ease-out"
|
||||
style={{
|
||||
width: `${Math.min(workerUsage.usagePct, 100)}%`,
|
||||
backgroundColor: workerUsage.usagePct >= 90 ? "hsl(0,65%,55%)" : workerUsage.usagePct >= 70 ? "hsl(35,90%,55%)" : "hsl(220,60%,55%)",
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<span className="text-[10px] text-muted-foreground/70 flex-shrink-0 tabular-nums">
|
||||
<span className="group-hover/ctx:hidden">{workerUsage.usagePct}%</span>
|
||||
<span className="hidden group-hover/ctx:inline">{(workerUsage.estimatedTokens / 1000).toFixed(1)}k / {(workerUsage.maxTokens / 1000).toFixed(0)}k</span>
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})()}
|
||||
|
||||
{/* Input area — question widget replaces textarea when a question is pending */}
|
||||
{pendingQuestions && pendingQuestions.length >= 2 && onMultiQuestionSubmit ? (
|
||||
<MultiQuestionWidget
|
||||
|
||||
@@ -1,13 +1,25 @@
|
||||
import { useEffect, useMemo, useRef, useState, useCallback } from "react";
|
||||
import { useEffect, useLayoutEffect, useMemo, useRef, useState, useCallback } from "react";
|
||||
import { Loader2 } from "lucide-react";
|
||||
import type { DraftGraph as DraftGraphData, DraftNode } from "@/api/types";
|
||||
import { RunButton } from "./AgentGraph";
|
||||
import type { GraphNode, RunState } from "./AgentGraph";
|
||||
import { RunButton } from "./RunButton";
|
||||
import type { GraphNode, RunState } from "./graph-types";
|
||||
import {
|
||||
cssVar,
|
||||
truncateLabel,
|
||||
TRIGGER_ICONS,
|
||||
ACTIVE_TRIGGER_COLORS,
|
||||
useTriggerColors,
|
||||
} from "@/lib/graphUtils";
|
||||
|
||||
// Read a CSS custom property value (space-separated HSL components)
|
||||
function cssVar(name: string): string {
|
||||
return getComputedStyle(document.documentElement).getPropertyValue(name).trim();
|
||||
}
|
||||
// ── Trigger layout constants ──
|
||||
const TRIGGER_H = 38; // pill height
|
||||
const TRIGGER_PILL_GAP_X = 16; // horizontal gap between multiple trigger pills
|
||||
const TRIGGER_ICON_X = 16; // icon center offset from pill left edge
|
||||
const TRIGGER_LABEL_X = 30; // label start offset from pill left edge
|
||||
const TRIGGER_LABEL_INSET = 38; // icon + padding subtracted from pill width for label space
|
||||
const TRIGGER_TEXT_Y = 11; // y-offset below pill for first text line (countdown or status)
|
||||
const TRIGGER_TEXT_STEP = 11; // additional y-offset for second text line when countdown present
|
||||
const TRIGGER_CLEARANCE = 30; // vertical space below pill for countdown + status text
|
||||
|
||||
interface DraftChromeColors {
|
||||
edge: string;
|
||||
@@ -73,7 +85,9 @@ function useDraftChromeColors() {
|
||||
type DraftNodeStatus = "pending" | "running" | "complete" | "error";
|
||||
|
||||
interface DraftGraphProps {
|
||||
draft: DraftGraphData;
|
||||
draft: DraftGraphData | null;
|
||||
/** The post-build originalDraft — animation fires when this changes to a new non-null value. */
|
||||
originalDraft?: DraftGraphData | null;
|
||||
onNodeClick?: (node: DraftNode) => void;
|
||||
/** Runtime node ID → list of original draft node IDs (post-dissolution mapping). */
|
||||
flowchartMap?: Record<string, string[]>;
|
||||
@@ -83,6 +97,8 @@ interface DraftGraphProps {
|
||||
onRuntimeNodeClick?: (runtimeNodeId: string) => void;
|
||||
/** True while the queen is building the agent from the draft. */
|
||||
building?: boolean;
|
||||
/** Message to show with a spinner while loading/designing. Null = no spinner. */
|
||||
loadingMessage?: string | null;
|
||||
/** Called when the user clicks Run. */
|
||||
onRun?: () => void;
|
||||
/** Called when the user clicks Pause. */
|
||||
@@ -103,13 +119,6 @@ function formatNodeId(id: string): string {
|
||||
return id.split("-").map(w => w.charAt(0).toUpperCase() + w.slice(1)).join(" ");
|
||||
}
|
||||
|
||||
function truncateLabel(label: string, availablePx: number, fontSize: number): string {
|
||||
const avgCharW = fontSize * 0.58;
|
||||
const maxChars = Math.floor(availablePx / avgCharW);
|
||||
if (label.length <= maxChars) return label;
|
||||
return label.slice(0, Math.max(maxChars - 1, 1)) + "\u2026";
|
||||
}
|
||||
|
||||
/** Return the bounding-rect corner radius for a given flowchart shape. */
|
||||
/**
|
||||
* Render an ISO 5807 flowchart shape as an SVG element.
|
||||
@@ -142,13 +151,9 @@ function FlowchartShape({
|
||||
case "rectangle":
|
||||
return <rect x={x} y={y} width={w} height={h} rx={4} {...common} />;
|
||||
|
||||
case "rounded_rect":
|
||||
return <rect x={x} y={y} width={w} height={h} rx={12} {...common} />;
|
||||
|
||||
case "diamond": {
|
||||
const cx = x + w / 2;
|
||||
const cy = y + h / 2;
|
||||
// Keep diamond within bounding box
|
||||
return (
|
||||
<polygon
|
||||
points={`${cx},${y} ${x + w},${cy} ${cx},${y + h} ${x},${cy}`}
|
||||
@@ -172,18 +177,6 @@ function FlowchartShape({
|
||||
return <path d={d} {...common} />;
|
||||
}
|
||||
|
||||
case "multi_document": {
|
||||
const off = 3;
|
||||
const d = `M ${x} ${y + 4 + off} Q ${x} ${y + off}, ${x + 8} ${y + off} L ${x + w - 8 - off} ${y + off} Q ${x + w - off} ${y + off}, ${x + w - off} ${y + 4 + off} L ${x + w - off} ${y + h - 8} C ${x + (w - off) * 0.75} ${y + h + 2}, ${x + (w - off) * 0.25} ${y + h - 10}, ${x} ${y + h - 4} Z`;
|
||||
return (
|
||||
<g>
|
||||
<rect x={x + off * 2} y={y} width={w - off * 2} height={h - off} rx={4} fill={fill} stroke={stroke} strokeWidth={1.2} opacity={0.4} />
|
||||
<rect x={x + off} y={y + off / 2} width={w - off} height={h - off} rx={4} fill={fill} stroke={stroke} strokeWidth={1.2} opacity={0.6} />
|
||||
<path d={d} {...common} />
|
||||
</g>
|
||||
);
|
||||
}
|
||||
|
||||
case "subroutine": {
|
||||
const inset = 7;
|
||||
return (
|
||||
@@ -205,34 +198,6 @@ function FlowchartShape({
|
||||
);
|
||||
}
|
||||
|
||||
case "manual_input":
|
||||
return (
|
||||
<polygon
|
||||
points={`${x},${y + 10} ${x + w},${y} ${x + w},${y + h} ${x},${y + h}`}
|
||||
{...common}
|
||||
/>
|
||||
);
|
||||
|
||||
case "trapezoid": {
|
||||
const inset = 12;
|
||||
return (
|
||||
<polygon
|
||||
points={`${x},${y} ${x + w},${y} ${x + w - inset},${y + h} ${x + inset},${y + h}`}
|
||||
{...common}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
case "delay": {
|
||||
const d = `M ${x} ${y + 4} Q ${x} ${y}, ${x + 4} ${y} L ${x + w * 0.65} ${y} A ${w * 0.35} ${h / 2} 0 0 1 ${x + w * 0.65} ${y + h} L ${x + 4} ${y + h} Q ${x} ${y + h}, ${x} ${y + h - 4} Z`;
|
||||
return <path d={d} {...common} />;
|
||||
}
|
||||
|
||||
case "display": {
|
||||
const d = `M ${x + 16} ${y} L ${x + w * 0.65} ${y} A ${w * 0.35} ${h / 2} 0 0 1 ${x + w * 0.65} ${y + h} L ${x + 16} ${y + h} L ${x} ${y + h / 2} Z`;
|
||||
return <path d={d} {...common} />;
|
||||
}
|
||||
|
||||
case "cylinder": {
|
||||
const ry = 7;
|
||||
return (
|
||||
@@ -247,88 +212,6 @@ function FlowchartShape({
|
||||
);
|
||||
}
|
||||
|
||||
case "stored_data": {
|
||||
const d = `M ${x + 14} ${y} L ${x + w} ${y} A 10 ${h / 2} 0 0 0 ${x + w} ${y + h} L ${x + 14} ${y + h} A 10 ${h / 2} 0 0 1 ${x + 14} ${y} Z`;
|
||||
return <path d={d} {...common} />;
|
||||
}
|
||||
|
||||
case "internal_storage":
|
||||
return (
|
||||
<g>
|
||||
<rect x={x} y={y} width={w} height={h} rx={4} {...common} />
|
||||
<line x1={x + 10} y1={y} x2={x + 10} y2={y + h} stroke={stroke} strokeWidth={0.8} opacity={0.5} />
|
||||
<line x1={x} y1={y + 10} x2={x + w} y2={y + 10} stroke={stroke} strokeWidth={0.8} opacity={0.5} />
|
||||
</g>
|
||||
);
|
||||
|
||||
case "circle": {
|
||||
const r = Math.min(w, h) / 2 - 2;
|
||||
return <circle cx={x + w / 2} cy={y + h / 2} r={r} {...common} />;
|
||||
}
|
||||
|
||||
case "pentagon":
|
||||
return (
|
||||
<polygon
|
||||
points={`${x},${y} ${x + w},${y} ${x + w},${y + h * 0.6} ${x + w / 2},${y + h} ${x},${y + h * 0.6}`}
|
||||
{...common}
|
||||
/>
|
||||
);
|
||||
|
||||
case "triangle_inv":
|
||||
return (
|
||||
<polygon
|
||||
points={`${x},${y} ${x + w},${y} ${x + w / 2},${y + h}`}
|
||||
{...common}
|
||||
/>
|
||||
);
|
||||
|
||||
case "triangle":
|
||||
return (
|
||||
<polygon
|
||||
points={`${x + w / 2},${y} ${x + w},${y + h} ${x},${y + h}`}
|
||||
{...common}
|
||||
/>
|
||||
);
|
||||
|
||||
case "hourglass":
|
||||
return (
|
||||
<polygon
|
||||
points={`${x},${y} ${x + w},${y} ${x + w / 2},${y + h / 2} ${x + w},${y + h} ${x},${y + h} ${x + w / 2},${y + h / 2}`}
|
||||
{...common}
|
||||
/>
|
||||
);
|
||||
|
||||
case "circle_cross": {
|
||||
const r = Math.min(w, h) / 2 - 2;
|
||||
const cx = x + w / 2;
|
||||
const cy = y + h / 2;
|
||||
return (
|
||||
<g>
|
||||
<circle cx={cx} cy={cy} r={r} {...common} />
|
||||
<line x1={cx - r * 0.7} y1={cy - r * 0.7} x2={cx + r * 0.7} y2={cy + r * 0.7} stroke={stroke} strokeWidth={1} />
|
||||
<line x1={cx + r * 0.7} y1={cy - r * 0.7} x2={cx - r * 0.7} y2={cy + r * 0.7} stroke={stroke} strokeWidth={1} />
|
||||
</g>
|
||||
);
|
||||
}
|
||||
|
||||
case "circle_bar": {
|
||||
const r = Math.min(w, h) / 2 - 2;
|
||||
const cx = x + w / 2;
|
||||
const cy = y + h / 2;
|
||||
return (
|
||||
<g>
|
||||
<circle cx={cx} cy={cy} r={r} {...common} />
|
||||
<line x1={cx} y1={cy - r} x2={cx} y2={cy + r} stroke={stroke} strokeWidth={1} />
|
||||
<line x1={cx - r} y1={cy} x2={cx + r} y2={cy} stroke={stroke} strokeWidth={1} />
|
||||
</g>
|
||||
);
|
||||
}
|
||||
|
||||
case "flag": {
|
||||
const d = `M ${x} ${y} L ${x + w} ${y} L ${x + w - 8} ${y + h / 2} L ${x + w} ${y + h} L ${x} ${y + h} Z`;
|
||||
return <path d={d} {...common} />;
|
||||
}
|
||||
|
||||
default:
|
||||
return <rect x={x} y={y} width={w} height={h} rx={8} {...common} />;
|
||||
}
|
||||
@@ -355,13 +238,51 @@ function Tooltip({ node, style }: { node: DraftNode; style: React.CSSProperties
|
||||
);
|
||||
}
|
||||
|
||||
export default function DraftGraph({ draft, onNodeClick, flowchartMap, runtimeNodes, onRuntimeNodeClick, building, onRun, onPause, runState = "idle" }: DraftGraphProps) {
|
||||
export default function DraftGraph({ draft, originalDraft, onNodeClick, flowchartMap, runtimeNodes, onRuntimeNodeClick, building, loadingMessage, onRun, onPause, runState = "idle" }: DraftGraphProps) {
|
||||
const [hoveredNode, setHoveredNode] = useState<string | null>(null);
|
||||
const [mousePos, setMousePos] = useState<{ x: number; y: number } | null>(null);
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const runBtnRef = useRef<HTMLButtonElement>(null);
|
||||
const [containerW, setContainerW] = useState(484);
|
||||
const chrome = useDraftChromeColors();
|
||||
const triggerColors = useTriggerColors();
|
||||
|
||||
// Extract trigger nodes from runtimeNodes
|
||||
const triggerNodes = useMemo(
|
||||
() => (runtimeNodes ?? []).filter(n => n.nodeType === "trigger"),
|
||||
[runtimeNodes],
|
||||
);
|
||||
|
||||
// ── Entrance animation — fires when originalDraft becomes a new non-null value ──
|
||||
// This covers: agent loaded, build finished, queen modifies flowchart.
|
||||
// Tab switches remount via React key={activeWorker}, resetting all refs.
|
||||
const prevOriginalDraft = useRef<DraftGraphData | null>(null);
|
||||
const pendingAnimation = useRef(false);
|
||||
const [entrancePhase, setEntrancePhase] = useState<"idle" | "hidden" | "visible">("idle");
|
||||
|
||||
const nodes = draft?.nodes ?? [];
|
||||
|
||||
useLayoutEffect(() => {
|
||||
const prev = prevOriginalDraft.current;
|
||||
prevOriginalDraft.current = originalDraft ?? null;
|
||||
|
||||
// Detect a new non-null originalDraft (object identity — each API/SSE response is a fresh object)
|
||||
if (originalDraft && originalDraft !== prev) {
|
||||
pendingAnimation.current = true;
|
||||
}
|
||||
|
||||
// Fire when we have a pending animation, nodes are ready, and not mid-build
|
||||
if (pendingAnimation.current && nodes.length > 0 && !building) {
|
||||
pendingAnimation.current = false;
|
||||
setEntrancePhase("hidden");
|
||||
let raf1 = 0, raf2 = 0;
|
||||
raf1 = requestAnimationFrame(() => {
|
||||
raf2 = requestAnimationFrame(() => setEntrancePhase("visible"));
|
||||
});
|
||||
const t = setTimeout(() => setEntrancePhase("idle"), nodes.length * 120 + 1000);
|
||||
return () => { clearTimeout(t); cancelAnimationFrame(raf1); cancelAnimationFrame(raf2); };
|
||||
}
|
||||
}, [originalDraft, nodes.length, building]);
|
||||
|
||||
// Shift-to-pin tooltip
|
||||
const shiftHeld = useRef(false);
|
||||
@@ -463,7 +384,7 @@ export default function DraftGraph({ draft, onNodeClick, flowchartMap, runtimeNo
|
||||
|
||||
const hasStatusOverlay = Object.keys(nodeStatuses).length > 0;
|
||||
|
||||
const { nodes, edges } = draft;
|
||||
const edges = draft?.edges ?? [];
|
||||
|
||||
const idxMap = useMemo(
|
||||
() => Object.fromEntries(nodes.map((n, i) => [n.id, i])),
|
||||
@@ -536,6 +457,11 @@ export default function DraftGraph({ draft, onNodeClick, flowchartMap, runtimeNo
|
||||
layerGroups.forEach((group) => {
|
||||
maxCols = Math.max(maxCols, group.length);
|
||||
});
|
||||
// Ensure maxCols accommodates any parent's children fan-out
|
||||
// (prevents fan-out scaling from collapsing to zero)
|
||||
children.forEach((kids) => {
|
||||
maxCols = Math.max(maxCols, kids.length);
|
||||
});
|
||||
|
||||
// Compute node width — keep back-edge overflow out of node sizing so nodes
|
||||
// get full width. The viewBox is expanded later to fit back-edge curves.
|
||||
@@ -641,6 +567,17 @@ export default function DraftGraph({ draft, onNodeClick, flowchartMap, runtimeNo
|
||||
}
|
||||
}
|
||||
|
||||
// Post-process: enforce minimum spacing within each layer
|
||||
for (const [, group] of layerGroups) {
|
||||
if (group.length <= 1) continue;
|
||||
const sorted = [...group].sort((a, b) => colPos[a] - colPos[b]);
|
||||
for (let j = 1; j < sorted.length; j++) {
|
||||
if (colPos[sorted[j]] < colPos[sorted[j - 1]] + 1) {
|
||||
colPos[sorted[j]] = colPos[sorted[j - 1]] + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert fractional column positions to pixel X positions
|
||||
const colSpacing = nodeW + GAP_X;
|
||||
const usedMin = Math.min(...colPos);
|
||||
@@ -656,25 +593,6 @@ export default function DraftGraph({ draft, onNodeClick, flowchartMap, runtimeNo
|
||||
return { layers, nodeW, firstColX, nodeXPositions, backEdgeOverflow, maxContentRight };
|
||||
}, [nodes, forwardEdges, backEdges.length, containerW, flowchartMap, idxMap]);
|
||||
|
||||
if (nodes.length === 0) {
|
||||
return (
|
||||
<div className="flex flex-col h-full">
|
||||
<div className="px-4 pt-4 pb-2">
|
||||
<p className="text-[11px] text-muted-foreground font-medium uppercase tracking-wider">
|
||||
Draft
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex-1 flex items-center justify-center px-4">
|
||||
<p className="text-xs text-muted-foreground/60 text-center italic">
|
||||
No draft graph yet.
|
||||
<br />
|
||||
Describe your workflow to get started.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const { layers, nodeW, nodeXPositions, backEdgeOverflow, maxContentRight } = layout;
|
||||
|
||||
const maxLayer = nodes.length > 0 ? Math.max(...layers) : 0;
|
||||
@@ -803,22 +721,27 @@ export default function DraftGraph({ draft, onNodeClick, flowchartMap, runtimeNo
|
||||
return { nodeYOffset: offsets, totalExtraY: totalExtra, groupBoxMaxX: maxGroupX };
|
||||
}, [nodes, maxLayer, flowchartMap, idxMap, layers, nodeXPositions, nodeW]);
|
||||
|
||||
// When triggers are present, push the entire draft graph down to make room
|
||||
const triggerOffsetY = triggerNodes.length > 0
|
||||
? TRIGGER_H + TRIGGER_TEXT_Y + TRIGGER_TEXT_STEP + TRIGGER_CLEARANCE
|
||||
: 0;
|
||||
|
||||
const nodePos = (i: number) => ({
|
||||
x: nodeXPositions[i],
|
||||
y: TOP_Y + layers[i] * (NODE_H + GAP_Y) + nodeYOffset[i],
|
||||
y: TOP_Y + triggerOffsetY + layers[i] * (NODE_H + GAP_Y) + nodeYOffset[i],
|
||||
});
|
||||
|
||||
const svgHeight = TOP_Y + (maxLayer + 1) * NODE_H + maxLayer * GAP_Y + totalExtraY + 16;
|
||||
const svgHeight = TOP_Y + triggerOffsetY + (maxLayer + 1) * NODE_H + maxLayer * GAP_Y + totalExtraY + 16;
|
||||
|
||||
// Compute group areas for runtime node boundaries on the draft
|
||||
const groupAreas = useMemo(() => {
|
||||
if (!flowchartMap || !runtimeNodes?.length) return [];
|
||||
if (!flowchartMap) return [];
|
||||
const groups: { runtimeId: string; label: string; draftIds: string[] }[] = [];
|
||||
for (const [runtimeId, draftIds] of Object.entries(flowchartMap)) {
|
||||
groups.push({ runtimeId, label: formatNodeId(runtimeId), draftIds });
|
||||
}
|
||||
return groups;
|
||||
}, [flowchartMap, runtimeNodes]);
|
||||
}, [flowchartMap]);
|
||||
|
||||
// Legend
|
||||
const usedTypes = (() => {
|
||||
@@ -856,12 +779,27 @@ export default function DraftGraph({ draft, onNodeClick, flowchartMap, runtimeNo
|
||||
? `M ${startX} ${y1} L ${toCenterX} ${y2}`
|
||||
: `M ${startX} ${y1} L ${startX} ${midY} L ${toCenterX} ${midY} L ${toCenterX} ${y2}`;
|
||||
|
||||
// Edge draw-in animation (stroke-dashoffset)
|
||||
const isAnimating = entrancePhase !== "idle";
|
||||
const pathLength = Math.abs(y2 - y1) + Math.abs(startX - toCenterX) + 1;
|
||||
const edgeDelay = 200 + i * 80;
|
||||
const edgeStyle: React.CSSProperties | undefined = isAnimating ? {
|
||||
strokeDasharray: pathLength,
|
||||
strokeDashoffset: entrancePhase === "hidden" ? pathLength : 0,
|
||||
transition: `stroke-dashoffset 400ms ease-in-out ${edgeDelay}ms`,
|
||||
} : undefined;
|
||||
const edgeEndStyle: React.CSSProperties | undefined = isAnimating ? {
|
||||
opacity: entrancePhase === "hidden" ? 0 : 1,
|
||||
transition: `opacity 100ms ease-out ${edgeDelay + 350}ms`,
|
||||
} : undefined;
|
||||
|
||||
return (
|
||||
<g key={`fwd-${i}`}>
|
||||
<path d={d} fill="none" stroke={chrome.edge} strokeWidth={1.2} />
|
||||
<path d={d} fill="none" stroke={chrome.edge} strokeWidth={1.2} style={edgeStyle} />
|
||||
<polygon
|
||||
points={`${toCenterX - 3},${y2 - 5} ${toCenterX + 3},${y2 - 5} ${toCenterX},${y2 - 1}`}
|
||||
fill={chrome.edgeArrow}
|
||||
style={edgeEndStyle}
|
||||
/>
|
||||
{edge.label && (
|
||||
<text
|
||||
@@ -871,6 +809,7 @@ export default function DraftGraph({ draft, onNodeClick, flowchartMap, runtimeNo
|
||||
fontSize={9}
|
||||
fontStyle="italic"
|
||||
textAnchor="middle"
|
||||
style={edgeEndStyle}
|
||||
>
|
||||
{truncateLabel(edge.label, 80, 9)}
|
||||
</text>
|
||||
@@ -893,12 +832,26 @@ export default function DraftGraph({ draft, onNodeClick, flowchartMap, runtimeNo
|
||||
|
||||
const path = `M ${startX} ${startY} C ${startX + r} ${startY}, ${curveX} ${startY}, ${curveX} ${startY - r} L ${curveX} ${endY + r} C ${curveX} ${endY}, ${endX + r} ${endY}, ${endX + 5} ${endY}`;
|
||||
|
||||
// Back-edge draw-in animation (starts after forward edges)
|
||||
const isAnimating = entrancePhase !== "idle";
|
||||
const backPathLength = Math.abs(curveX - startX) + Math.abs(startY - endY) + Math.abs(curveX - endX) + 20;
|
||||
const backDelay = nodes.length * 120 + 300 + i * 80;
|
||||
const backEdgeStyle: React.CSSProperties | undefined = isAnimating ? {
|
||||
strokeDashoffset: entrancePhase === "hidden" ? backPathLength : 0,
|
||||
transition: `stroke-dashoffset 400ms ease-in-out ${backDelay}ms`,
|
||||
} : undefined;
|
||||
const backEndStyle: React.CSSProperties | undefined = isAnimating ? {
|
||||
opacity: entrancePhase === "hidden" ? 0 : 1,
|
||||
transition: `opacity 100ms ease-out ${backDelay + 350}ms`,
|
||||
} : undefined;
|
||||
|
||||
return (
|
||||
<g key={`back-${i}`}>
|
||||
<path d={path} fill="none" stroke={chrome.backEdge} strokeWidth={1.2} strokeDasharray="4 3" />
|
||||
<path d={path} fill="none" stroke={chrome.backEdge} strokeWidth={1.2} strokeDasharray={isAnimating ? backPathLength : "4 3"} style={backEdgeStyle} />
|
||||
<polygon
|
||||
points={`${endX + 5},${endY - 2.5} ${endX + 5},${endY + 2.5} ${endX},${endY}`}
|
||||
fill={chrome.edge}
|
||||
style={backEndStyle}
|
||||
/>
|
||||
</g>
|
||||
);
|
||||
@@ -911,6 +864,131 @@ export default function DraftGraph({ draft, onNodeClick, flowchartMap, runtimeNo
|
||||
pending: "",
|
||||
};
|
||||
|
||||
// ── Trigger node rendering ──
|
||||
|
||||
const triggerW = Math.min(nodeW, 180);
|
||||
|
||||
// Shared trigger pill X position (used by both node and edge renderers)
|
||||
const triggerPillX = (idx: number) => {
|
||||
const totalW = triggerNodes.length * triggerW + (triggerNodes.length - 1) * TRIGGER_PILL_GAP_X;
|
||||
return (containerW - totalW) / 2 + idx * (triggerW + TRIGGER_PILL_GAP_X);
|
||||
};
|
||||
|
||||
const renderTriggerNode = (node: GraphNode, triggerIdx: number) => {
|
||||
const icon = TRIGGER_ICONS[node.triggerType || ""] || "\u26A1";
|
||||
const isActive = node.status === "running" || node.status === "complete";
|
||||
const colors = isActive ? ACTIVE_TRIGGER_COLORS : triggerColors;
|
||||
const nextFireIn = node.triggerConfig?.next_fire_in as number | undefined;
|
||||
|
||||
const tx = triggerPillX(triggerIdx);
|
||||
const ty = TOP_Y;
|
||||
|
||||
const fontSize = triggerW < 140 ? 10.5 : 11.5;
|
||||
const displayLabel = truncateLabel(node.label, triggerW - TRIGGER_LABEL_INSET, fontSize);
|
||||
|
||||
// Countdown
|
||||
let countdownLabel: string | null = null;
|
||||
if (isActive && nextFireIn != null && nextFireIn > 0) {
|
||||
const h = Math.floor(nextFireIn / 3600);
|
||||
const m = Math.floor((nextFireIn % 3600) / 60);
|
||||
const s = Math.floor(nextFireIn % 60);
|
||||
countdownLabel = h > 0
|
||||
? `next in ${h}h ${String(m).padStart(2, "0")}m`
|
||||
: `next in ${m}m ${String(s).padStart(2, "0")}s`;
|
||||
}
|
||||
|
||||
const statusLabel = isActive ? "active" : "inactive";
|
||||
const statusColor = isActive ? "hsl(140,40%,50%)" : "hsl(210,20%,40%)";
|
||||
|
||||
return (
|
||||
<g
|
||||
key={node.id}
|
||||
onClick={() => onRuntimeNodeClick?.(node.id)}
|
||||
style={{ cursor: onRuntimeNodeClick ? "pointer" : "default" }}
|
||||
>
|
||||
<title>{node.label}</title>
|
||||
{/* Pill-shaped background */}
|
||||
<rect
|
||||
x={tx} y={ty}
|
||||
width={triggerW} height={TRIGGER_H}
|
||||
rx={TRIGGER_H / 2}
|
||||
fill={colors.bg}
|
||||
stroke={colors.border}
|
||||
strokeWidth={isActive ? 1.5 : 1}
|
||||
strokeDasharray={isActive ? undefined : "4 2"}
|
||||
/>
|
||||
{/* Icon */}
|
||||
<text
|
||||
x={tx + TRIGGER_ICON_X} y={ty + TRIGGER_H / 2}
|
||||
fill={colors.icon} fontSize={13}
|
||||
textAnchor="middle" dominantBaseline="middle"
|
||||
>
|
||||
{icon}
|
||||
</text>
|
||||
{/* Label */}
|
||||
<text
|
||||
x={tx + TRIGGER_LABEL_X} y={ty + TRIGGER_H / 2}
|
||||
fill={colors.text}
|
||||
fontSize={fontSize}
|
||||
fontWeight={500}
|
||||
dominantBaseline="middle"
|
||||
letterSpacing="0.01em"
|
||||
>
|
||||
{displayLabel}
|
||||
</text>
|
||||
{/* Countdown */}
|
||||
{countdownLabel && (
|
||||
<text
|
||||
x={tx + triggerW / 2} y={ty + TRIGGER_H + TRIGGER_TEXT_Y}
|
||||
fill={colors.text} fontSize={9}
|
||||
textAnchor="middle" fontStyle="italic" opacity={0.7}
|
||||
>
|
||||
{countdownLabel}
|
||||
</text>
|
||||
)}
|
||||
{/* Status */}
|
||||
<text
|
||||
x={tx + triggerW / 2} y={ty + TRIGGER_H + (countdownLabel ? TRIGGER_TEXT_Y + TRIGGER_TEXT_STEP : TRIGGER_TEXT_Y)}
|
||||
fill={statusColor} fontSize={8.5}
|
||||
textAnchor="middle" opacity={0.8}
|
||||
>
|
||||
{statusLabel}
|
||||
</text>
|
||||
</g>
|
||||
);
|
||||
};
|
||||
|
||||
const renderTriggerEdge = (triggerIdx: number) => {
|
||||
if (nodes.length === 0) return null;
|
||||
const triggerNode = triggerNodes[triggerIdx];
|
||||
const runtimeTargetId = triggerNode?.next?.[0];
|
||||
const targetDraftId = runtimeTargetId
|
||||
? flowchartMap?.[runtimeTargetId]?.[0] ?? runtimeTargetId
|
||||
: draft?.entry_node;
|
||||
const targetIdx = targetDraftId ? idxMap[targetDraftId] ?? 0 : 0;
|
||||
const targetPos = nodePos(targetIdx);
|
||||
const targetX = targetPos.x + nodeW / 2;
|
||||
const targetY = targetPos.y;
|
||||
|
||||
const tx = triggerPillX(triggerIdx) + triggerW / 2;
|
||||
const ty = TOP_Y + TRIGGER_H + TRIGGER_TEXT_Y + TRIGGER_TEXT_STEP + 4;
|
||||
|
||||
const midY = (ty + targetY) / 2;
|
||||
const d = Math.abs(tx - targetX) < 2
|
||||
? `M ${tx} ${ty} L ${targetX} ${targetY}`
|
||||
: `M ${tx} ${ty} L ${tx} ${midY} L ${targetX} ${midY} L ${targetX} ${targetY}`;
|
||||
|
||||
return (
|
||||
<g key={`trigger-edge-${triggerIdx}`}>
|
||||
<path d={d} fill="none" stroke={chrome.edge} strokeWidth={1.2} strokeDasharray="4 3" />
|
||||
<polygon
|
||||
points={`${targetX - 3},${targetY - 5} ${targetX + 3},${targetY - 5} ${targetX},${targetY - 1}`}
|
||||
fill={chrome.edgeArrow}
|
||||
/>
|
||||
</g>
|
||||
);
|
||||
};
|
||||
|
||||
const renderNode = (node: DraftNode, i: number) => {
|
||||
const pos = nodePos(i);
|
||||
const isHovered = hoveredNode === node.id;
|
||||
@@ -942,7 +1020,13 @@ export default function DraftGraph({ draft, onNodeClick, flowchartMap, runtimeNo
|
||||
if (rect) setMousePos({ x: e.clientX - rect.left, y: e.clientY - rect.top });
|
||||
}}
|
||||
onMouseLeave={() => { if (!shiftHeld.current) { setHoveredNode(null); setMousePos(null); } }}
|
||||
style={{ cursor: "pointer" }}
|
||||
style={{
|
||||
cursor: "pointer",
|
||||
...(entrancePhase !== "idle" ? {
|
||||
opacity: entrancePhase === "hidden" ? 0 : 1,
|
||||
transition: `opacity 300ms ease-out ${i * 120}ms`,
|
||||
} : {}),
|
||||
}}
|
||||
>
|
||||
|
||||
<FlowchartShape
|
||||
@@ -982,6 +1066,30 @@ export default function DraftGraph({ draft, onNodeClick, flowchartMap, runtimeNo
|
||||
);
|
||||
};
|
||||
|
||||
if (!draft || nodes.length === 0) {
|
||||
return (
|
||||
<div className="flex flex-col h-full">
|
||||
<div className="px-4 pt-3 pb-1.5 flex items-center gap-2">
|
||||
<p className="text-[11px] text-muted-foreground font-medium uppercase tracking-wider">Draft</p>
|
||||
</div>
|
||||
<div className="flex-1 flex flex-col items-center justify-center gap-3">
|
||||
{loadingMessage ? (
|
||||
<>
|
||||
<Loader2 className="w-5 h-5 animate-spin text-muted-foreground/40" />
|
||||
<p className="text-xs text-muted-foreground/50">{loadingMessage}</p>
|
||||
</>
|
||||
) : (
|
||||
<p className="text-xs text-muted-foreground/60 text-center italic">
|
||||
No draft graph yet.
|
||||
<br />
|
||||
Describe your workflow to get started.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full">
|
||||
{/* Header */}
|
||||
@@ -995,6 +1103,11 @@ export default function DraftGraph({ draft, onNodeClick, flowchartMap, runtimeNo
|
||||
<Loader2 className="w-2.5 h-2.5 animate-spin" />
|
||||
building
|
||||
</span>
|
||||
) : loadingMessage ? (
|
||||
<span className="text-[9px] font-mono font-medium rounded px-1 py-0.5 leading-none border text-amber-500/60 border-amber-500/20 flex items-center gap-1">
|
||||
<Loader2 className="w-2.5 h-2.5 animate-spin" />
|
||||
updating
|
||||
</span>
|
||||
) : (
|
||||
<span className={`text-[9px] font-mono font-medium rounded px-1 py-0.5 leading-none border ${hasStatusOverlay ? "text-emerald-500/60 border-emerald-500/20" : "text-amber-500/60 border-amber-500/20"}`}>
|
||||
{hasStatusOverlay ? "live" : "planning"}
|
||||
@@ -1014,12 +1127,16 @@ export default function DraftGraph({ draft, onNodeClick, flowchartMap, runtimeNo
|
||||
onMouseMove={handleMouseMove}
|
||||
onMouseUp={handleMouseUp}
|
||||
onMouseLeave={handleMouseUp}
|
||||
className={`w-full h-full${building ? " opacity-30" : ""}`}
|
||||
style={{ cursor: dragging ? "grabbing" : "grab" }}
|
||||
className="w-full h-full"
|
||||
style={{
|
||||
opacity: building || loadingMessage ? 0.3 : 1,
|
||||
transition: building || loadingMessage ? "none" : "opacity 300ms ease-out",
|
||||
cursor: dragging ? "grabbing" : "grab",
|
||||
}}
|
||||
>
|
||||
<svg
|
||||
width="100%"
|
||||
viewBox={`0 0 ${Math.max((maxContentRight ?? 0), groupBoxMaxX) + (backEdgeOverflow ?? 0)} ${totalH}`}
|
||||
viewBox={`0 0 ${Math.max((maxContentRight ?? 0), groupBoxMaxX, triggerNodes.length > 0 ? triggerPillX(triggerNodes.length - 1) + triggerW : 0) + (backEdgeOverflow ?? 0)} ${totalH}`}
|
||||
preserveAspectRatio="xMidYMin meet"
|
||||
className="select-none"
|
||||
style={{
|
||||
@@ -1103,6 +1220,11 @@ export default function DraftGraph({ draft, onNodeClick, flowchartMap, runtimeNo
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Trigger edges (dashed lines from trigger pills to first draft node) */}
|
||||
{triggerNodes.map((_, i) => renderTriggerEdge(i))}
|
||||
{/* Trigger pill nodes */}
|
||||
{triggerNodes.map((tn, i) => renderTriggerNode(tn, i))}
|
||||
|
||||
{forwardEdges.map((e, i) => renderEdge(e, i))}
|
||||
{backEdges.map((e, i) => renderBackEdge(e, i))}
|
||||
{nodes.map((n, i) => renderNode(n, i))}
|
||||
@@ -1141,6 +1263,15 @@ export default function DraftGraph({ draft, onNodeClick, flowchartMap, runtimeNo
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!building && loadingMessage && (
|
||||
<div className="absolute inset-0 flex items-center justify-center">
|
||||
<div className="flex flex-col items-center gap-3">
|
||||
<Loader2 className="w-6 h-6 animate-spin text-muted-foreground/40" />
|
||||
<p className="text-xs text-muted-foreground/50">{loadingMessage}</p>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Zoom controls */}
|
||||
<div className="absolute bottom-3 right-3 flex items-center gap-1 bg-card/80 backdrop-blur-sm border border-border/40 rounded-lg p-0.5 shadow-sm">
|
||||
<button
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useState, useEffect, useRef } from "react";
|
||||
import { X, Cpu, Zap, Clock, RotateCcw, CheckCircle2, AlertCircle, Loader2, ChevronDown, ChevronRight, Copy, Check, Terminal, Wrench, BookOpen, GitBranch, Bot } from "lucide-react";
|
||||
import type { GraphNode, NodeStatus } from "./AgentGraph";
|
||||
import type { GraphNode, NodeStatus } from "./graph-types";
|
||||
import type { NodeSpec, ToolInfo, NodeCriteria } from "../api/types";
|
||||
import { graphsApi } from "../api/graphs";
|
||||
import { logsApi } from "../api/logs";
|
||||
@@ -28,6 +28,13 @@ export interface SubagentReport {
|
||||
status?: "running" | "complete" | "error";
|
||||
}
|
||||
|
||||
interface ContextUsage {
|
||||
usagePct: number;
|
||||
messageCount: number;
|
||||
estimatedTokens: number;
|
||||
maxTokens: number;
|
||||
}
|
||||
|
||||
interface NodeDetailPanelProps {
|
||||
node: GraphNode | null;
|
||||
nodeSpec?: NodeSpec | null;
|
||||
@@ -38,6 +45,7 @@ interface NodeDetailPanelProps {
|
||||
workerSessionId?: string | null;
|
||||
nodeLogs?: string[];
|
||||
actionPlan?: string;
|
||||
contextUsage?: ContextUsage;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
@@ -309,7 +317,7 @@ const tabs: { id: Tab; label: string; Icon: React.FC<{ className?: string }> }[]
|
||||
{ id: "subagents", label: "Subagents", Icon: ({ className }) => <Bot className={className} /> },
|
||||
];
|
||||
|
||||
export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagentReports, sessionId, graphId, workerSessionId, nodeLogs, actionPlan, onClose }: NodeDetailPanelProps) {
|
||||
export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagentReports, sessionId, graphId, workerSessionId, nodeLogs, actionPlan, contextUsage, onClose }: NodeDetailPanelProps) {
|
||||
const [activeTab, setActiveTab] = useState<Tab>("overview");
|
||||
const [realTools, setRealTools] = useState<ToolInfo[] | null>(null);
|
||||
const [realCriteria, setRealCriteria] = useState<NodeCriteria | null>(null);
|
||||
@@ -389,6 +397,43 @@ export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagent
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Context window usage */}
|
||||
{contextUsage && (
|
||||
<div className="px-4 py-2 border-b border-border/20 flex-shrink-0">
|
||||
<div className="flex items-center gap-2 mb-1">
|
||||
<span className="text-[10px] text-muted-foreground font-medium">Context</span>
|
||||
<span className="text-[10px] text-muted-foreground/70 ml-auto">
|
||||
{(contextUsage.estimatedTokens / 1000).toFixed(1)}k / {(contextUsage.maxTokens / 1000).toFixed(0)}k tokens
|
||||
</span>
|
||||
</div>
|
||||
<div className="w-full h-1.5 rounded-full bg-muted/50 overflow-hidden">
|
||||
<div
|
||||
className="h-full rounded-full transition-all duration-500 ease-out"
|
||||
style={{
|
||||
width: `${Math.min(contextUsage.usagePct, 100)}%`,
|
||||
backgroundColor: contextUsage.usagePct >= 90
|
||||
? "hsl(0,65%,55%)"
|
||||
: contextUsage.usagePct >= 70
|
||||
? "hsl(35,90%,55%)"
|
||||
: "hsl(45,95%,58%)",
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center gap-2 mt-1">
|
||||
<span className="text-[10px] text-muted-foreground/60">{contextUsage.messageCount} messages</span>
|
||||
<span className="text-[10px] font-medium ml-auto" style={{
|
||||
color: contextUsage.usagePct >= 90
|
||||
? "hsl(0,65%,55%)"
|
||||
: contextUsage.usagePct >= 70
|
||||
? "hsl(35,90%,55%)"
|
||||
: "hsl(45,95%,58%)",
|
||||
}}>
|
||||
{contextUsage.usagePct}%
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Tab bar */}
|
||||
<div className="flex border-b border-border/30 flex-shrink-0 px-2 pt-1 overflow-x-auto scrollbar-hide">
|
||||
{tabs.filter(t => t.id !== "subagents" || (nodeSpec?.sub_agents && nodeSpec.sub_agents.length > 0)).map(tab => (
|
||||
|
||||
@@ -0,0 +1,413 @@
|
||||
import { memo, useState, useRef, useEffect } from "react";
|
||||
import { ChevronDown, ChevronUp, Cpu } from "lucide-react";
|
||||
import type { ChatMessage, ContextUsageEntry } from "@/components/ChatPanel";
|
||||
import MarkdownContent from "@/components/MarkdownContent";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const workerColor = "hsl(220,60%,55%)";
|
||||
|
||||
const SUBAGENT_COLORS = [
|
||||
"hsl(220,60%,55%)",
|
||||
"hsl(260,50%,55%)",
|
||||
"hsl(180,50%,45%)",
|
||||
"hsl(30,70%,50%)",
|
||||
"hsl(340,55%,50%)",
|
||||
"hsl(150,45%,45%)",
|
||||
"hsl(45,80%,50%)",
|
||||
"hsl(290,45%,55%)",
|
||||
];
|
||||
|
||||
function colorForIndex(i: number): string {
|
||||
return SUBAGENT_COLORS[i % SUBAGENT_COLORS.length];
|
||||
}
|
||||
|
||||
function subagentLabel(nodeId: string): string {
|
||||
const parts = nodeId.split(":subagent:");
|
||||
const raw = parts.length >= 2 ? parts[1] : nodeId;
|
||||
return raw
|
||||
.replace(/:\d+$/, "") // strip instance suffix like ":3"
|
||||
.replace(/[_-]/g, " ")
|
||||
.replace(/\b\w/g, (c) => c.toUpperCase())
|
||||
.trim();
|
||||
}
|
||||
|
||||
function last<T>(arr: T[]): T | undefined {
|
||||
return arr[arr.length - 1];
|
||||
}
|
||||
|
||||
export interface SubagentGroup {
|
||||
nodeId: string;
|
||||
messages: ChatMessage[];
|
||||
contextUsage?: ContextUsageEntry;
|
||||
}
|
||||
|
||||
interface ParallelSubagentBubbleProps {
|
||||
groups: SubagentGroup[];
|
||||
groupId: string;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Thermometer — vertical context gauge on right edge of each pane
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tool overlay — shown when a tool_status message is active (not all done)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function ToolOverlay({
|
||||
toolName,
|
||||
color,
|
||||
visible,
|
||||
}: {
|
||||
toolName: string;
|
||||
color: string;
|
||||
visible: boolean;
|
||||
}) {
|
||||
return (
|
||||
<div
|
||||
className="absolute inset-0 top-[22px] flex items-center justify-center transition-opacity duration-200 z-10"
|
||||
style={{
|
||||
background: "rgba(8,8,14,0.82)",
|
||||
opacity: visible ? 1 : 0,
|
||||
pointerEvents: visible ? "auto" : "none",
|
||||
}}
|
||||
>
|
||||
<div className="text-center px-3 py-2 rounded-md border" style={{ borderColor: `${color}40` }}>
|
||||
<div className="text-[10px] font-medium" style={{ color }}>
|
||||
{toolName}
|
||||
</div>
|
||||
<div className="text-[11px] mt-0.5" style={{ color }}>
|
||||
{visible ? "..." : "\u2713"}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Single tmux pane
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function MuxPane({
|
||||
group,
|
||||
index,
|
||||
label,
|
||||
isFocused,
|
||||
isZoomed,
|
||||
onClickTitle,
|
||||
}: {
|
||||
group: SubagentGroup;
|
||||
index: number;
|
||||
label: string;
|
||||
isFocused: boolean;
|
||||
isZoomed: boolean;
|
||||
onClickTitle: () => void;
|
||||
}) {
|
||||
const bodyRef = useRef<HTMLDivElement>(null);
|
||||
const stickRef = useRef(true);
|
||||
const color = colorForIndex(index);
|
||||
const pct = group.contextUsage?.usagePct ?? 0;
|
||||
|
||||
const streamMsgs = group.messages.filter((m) => m.type !== "tool_status");
|
||||
const latestContent = last(streamMsgs)?.content ?? "";
|
||||
const msgCount = streamMsgs.length;
|
||||
|
||||
// Detect active tool and finished state from latest tool_status
|
||||
const latestTool = last(
|
||||
group.messages.filter((m) => m.type === "tool_status")
|
||||
);
|
||||
let activeToolName = "";
|
||||
let toolRunning = false;
|
||||
let isFinished = false;
|
||||
if (latestTool) {
|
||||
try {
|
||||
const parsed = JSON.parse(latestTool.content);
|
||||
const tools: { name: string; done: boolean }[] = parsed.tools || [];
|
||||
const allDone = parsed.allDone as boolean | undefined;
|
||||
const running = tools.find((t) => !t.done);
|
||||
if (running) {
|
||||
activeToolName = running.name;
|
||||
toolRunning = true;
|
||||
}
|
||||
// Finished when all tools are done and one of them is set_output
|
||||
// or report_to_parent (terminal tool calls)
|
||||
if (allDone && tools.length > 0) {
|
||||
const hasTerminal = tools.some(
|
||||
(t) =>
|
||||
t.done &&
|
||||
(t.name === "set_output" || t.name === "report_to_parent")
|
||||
);
|
||||
if (hasTerminal) isFinished = true;
|
||||
}
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-scroll
|
||||
useEffect(() => {
|
||||
if (stickRef.current && bodyRef.current) {
|
||||
bodyRef.current.scrollTop = bodyRef.current.scrollHeight;
|
||||
}
|
||||
}, [latestContent]);
|
||||
|
||||
const handleScroll = () => {
|
||||
const el = bodyRef.current;
|
||||
if (!el) return;
|
||||
stickRef.current = el.scrollHeight - el.scrollTop - el.clientHeight < 30;
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex flex-col min-h-0 overflow-hidden relative transition-all duration-200"
|
||||
style={{
|
||||
borderWidth: 1,
|
||||
borderStyle: "solid",
|
||||
borderColor: isFocused && !isFinished ? `${color}60` : "transparent",
|
||||
opacity: isFinished ? 0.4 : isFocused || isZoomed ? 1 : 0.55,
|
||||
...(isZoomed
|
||||
? { gridColumn: "1 / -1", gridRow: "1 / -1", zIndex: 10 }
|
||||
: {}),
|
||||
}}
|
||||
>
|
||||
{/* Title bar */}
|
||||
<div
|
||||
className="flex items-center gap-1.5 px-2 py-[3px] flex-shrink-0 cursor-pointer select-none"
|
||||
style={{ background: "#0e0e16", borderBottom: "1px solid #1a1a2a" }}
|
||||
onClick={onClickTitle}
|
||||
>
|
||||
{isFinished ? (
|
||||
<span className="text-[8px] flex-shrink-0 leading-none" style={{ color: "#4a4" }}>✓</span>
|
||||
) : (
|
||||
<div
|
||||
className="w-[6px] h-[6px] rounded-full flex-shrink-0"
|
||||
style={{ background: color }}
|
||||
/>
|
||||
)}
|
||||
<span className="text-[9px] flex-shrink-0" style={{ color: isFinished ? "#555" : color }}>
|
||||
{label}
|
||||
</span>
|
||||
<span className="flex-1" />
|
||||
<span className="text-[8px] tabular-nums flex-shrink-0" style={{ color: "#555" }}>
|
||||
{msgCount}
|
||||
</span>
|
||||
<div
|
||||
className="w-[36px] h-[3px] rounded-full overflow-hidden flex-shrink-0"
|
||||
style={{ background: "#1a1a2a" }}
|
||||
>
|
||||
<div
|
||||
className="h-full rounded-full transition-all duration-500"
|
||||
style={{
|
||||
width: `${Math.min(pct, 100)}%`,
|
||||
backgroundColor:
|
||||
pct >= 80 ? "hsl(0,65%,55%)" : pct >= 50 ? "hsl(35,90%,55%)" : color,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<span className="text-[8px] tabular-nums flex-shrink-0" style={{ color: "#555" }}>
|
||||
{pct}%
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{/* Body */}
|
||||
<div
|
||||
ref={bodyRef}
|
||||
onScroll={handleScroll}
|
||||
className="flex-1 min-h-0 overflow-y-auto px-2 py-1 text-[10px] leading-[1.7]"
|
||||
style={{ background: "#08080e", color: "#555", fontFamily: "monospace" }}
|
||||
>
|
||||
{latestContent ? (
|
||||
<div style={{ color: "#ccc" }}>
|
||||
<MarkdownContent content={latestContent} />
|
||||
</div>
|
||||
) : (
|
||||
<span style={{ color: "#333" }}>waiting...</span>
|
||||
)}
|
||||
{/* Blinking cursor — hidden when finished */}
|
||||
{!isFinished && (
|
||||
<span
|
||||
className="inline-block w-[6px] h-[11px] align-middle ml-0.5"
|
||||
style={{
|
||||
background: color,
|
||||
animation: "cursorBlink 1s step-end infinite",
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Tool overlay */}
|
||||
<ToolOverlay
|
||||
toolName={activeToolName}
|
||||
color={color}
|
||||
visible={toolRunning}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Main component
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const ParallelSubagentBubble = memo(
|
||||
function ParallelSubagentBubble({ groups }: ParallelSubagentBubbleProps) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const [zoomedIdx, setZoomedIdx] = useState<number | null>(null);
|
||||
|
||||
// Labels with instance numbers for duplicates
|
||||
const labels: string[] = (() => {
|
||||
const countByBase = new Map<string, number>();
|
||||
const bases = groups.map((g) => subagentLabel(g.nodeId));
|
||||
for (const b of bases)
|
||||
countByBase.set(b, (countByBase.get(b) ?? 0) + 1);
|
||||
const idxByBase = new Map<string, number>();
|
||||
return bases.map((b) => {
|
||||
if ((countByBase.get(b) ?? 1) <= 1) return b;
|
||||
const idx = (idxByBase.get(b) ?? 0) + 1;
|
||||
idxByBase.set(b, idx);
|
||||
return `${b} #${idx}`;
|
||||
});
|
||||
})();
|
||||
|
||||
// Latest-active pane
|
||||
const latestIdx = groups.reduce<number>((best, g, i) => {
|
||||
const filtered = g.messages.filter((m) => m.type !== "tool_status");
|
||||
const lm = last(filtered);
|
||||
if (!lm) return best;
|
||||
if (best < 0) return i;
|
||||
const bm = last(
|
||||
groups[best].messages.filter((m) => m.type !== "tool_status")
|
||||
);
|
||||
if (!bm) return i;
|
||||
return (lm.createdAt ?? 0) >= (bm.createdAt ?? 0) ? i : best;
|
||||
}, -1);
|
||||
|
||||
// Per-group finished detection (same logic as MuxPane)
|
||||
const finishedFlags = groups.map((g) => {
|
||||
const lt = last(g.messages.filter((m) => m.type === "tool_status"));
|
||||
if (!lt) return false;
|
||||
try {
|
||||
const p = JSON.parse(lt.content);
|
||||
const tools: { name: string; done: boolean }[] = p.tools || [];
|
||||
if (!p.allDone || tools.length === 0) return false;
|
||||
return tools.some(
|
||||
(t) => t.done && (t.name === "set_output" || t.name === "report_to_parent")
|
||||
);
|
||||
} catch { return false; }
|
||||
});
|
||||
const activeCount = finishedFlags.filter((f) => !f).length;
|
||||
|
||||
if (groups.length === 0) return null;
|
||||
|
||||
// Grid sizing: 2 columns, auto rows capped at a fixed height
|
||||
const rows = Math.ceil(groups.length / 2);
|
||||
const gridHeight = expanded
|
||||
? Math.min(rows * 200, 480)
|
||||
: Math.min(rows * 100, 240);
|
||||
|
||||
return (
|
||||
<div className="flex gap-3">
|
||||
{/* Left icon */}
|
||||
<div
|
||||
className="flex-shrink-0 w-7 h-7 rounded-xl flex items-center justify-center mt-1"
|
||||
style={{
|
||||
backgroundColor: `${workerColor}18`,
|
||||
border: `1.5px solid ${workerColor}35`,
|
||||
}}
|
||||
>
|
||||
<Cpu className="w-3.5 h-3.5" style={{ color: workerColor }} />
|
||||
</div>
|
||||
|
||||
<div className="flex-1 min-w-0 max-w-[90%]">
|
||||
{/* Header */}
|
||||
<div className="flex items-center gap-2 mb-1">
|
||||
<span className="font-medium text-xs" style={{ color: workerColor }}>
|
||||
{groups.length === 1 ? "Sub-agent" : "Parallel Agents"}
|
||||
</span>
|
||||
<span className="text-[10px] font-medium px-1.5 py-0.5 rounded-md bg-muted text-muted-foreground">
|
||||
{activeCount > 0 ? `${activeCount} running` : `${groups.length} done`}
|
||||
</span>
|
||||
<button
|
||||
onClick={() => {
|
||||
setExpanded((v) => !v);
|
||||
setZoomedIdx(null);
|
||||
}}
|
||||
className="ml-auto text-muted-foreground/60 hover:text-muted-foreground transition-colors p-0.5 rounded"
|
||||
title={expanded ? "Collapse" : "Expand"}
|
||||
>
|
||||
{expanded ? (
|
||||
<ChevronUp className="w-3.5 h-3.5" />
|
||||
) : (
|
||||
<ChevronDown className="w-3.5 h-3.5" />
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Mux frame */}
|
||||
<div
|
||||
className="rounded-lg overflow-hidden"
|
||||
style={{
|
||||
border: "2px solid #1a1a2a",
|
||||
background: "#08080e",
|
||||
}}
|
||||
>
|
||||
{/* Grid */}
|
||||
<div
|
||||
className="grid gap-px"
|
||||
style={{
|
||||
gridTemplateColumns:
|
||||
groups.length === 1 ? "1fr" : "1fr 1fr",
|
||||
gridTemplateRows: `repeat(${rows}, 1fr)`,
|
||||
height: gridHeight,
|
||||
background: "#111",
|
||||
}}
|
||||
>
|
||||
{groups.map((group, i) => (
|
||||
<MuxPane
|
||||
key={group.nodeId}
|
||||
group={group}
|
||||
index={i}
|
||||
label={labels[i]}
|
||||
isFocused={latestIdx === i}
|
||||
isZoomed={zoomedIdx === i}
|
||||
onClickTitle={() =>
|
||||
setZoomedIdx(zoomedIdx === i ? null : i)
|
||||
}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
(prev, next) =>
|
||||
prev.groupId === next.groupId &&
|
||||
prev.groups.length === next.groups.length &&
|
||||
prev.groups.every(
|
||||
(g, i) =>
|
||||
g.nodeId === next.groups[i].nodeId &&
|
||||
g.messages.length === next.groups[i].messages.length &&
|
||||
last(g.messages)?.content === last(next.groups[i].messages)?.content &&
|
||||
g.contextUsage?.usagePct === next.groups[i].contextUsage?.usagePct
|
||||
)
|
||||
);
|
||||
|
||||
export default ParallelSubagentBubble;
|
||||
|
||||
// Injected as a global style (keyframes can't be inline)
|
||||
if (typeof document !== "undefined") {
|
||||
const id = "parallel-subagent-keyframes";
|
||||
if (!document.getElementById(id)) {
|
||||
const style = document.createElement("style");
|
||||
style.id = id;
|
||||
style.textContent = `
|
||||
@keyframes cursorBlink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } }
|
||||
@keyframes thermoPulse { 0%, 100% { opacity: 1; } 50% { opacity: 0.4; } }
|
||||
`;
|
||||
document.head.appendChild(style);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
import { memo, useState } from "react";
|
||||
import { Play, Pause, Loader2, CheckCircle2 } from "lucide-react";
|
||||
import type { RunButtonProps } from "./graph-types";
|
||||
|
||||
export const RunButton = memo(function RunButton({ runState, disabled, onRun, onPause, btnRef }: RunButtonProps) {
|
||||
const [hovered, setHovered] = useState(false);
|
||||
const showPause = runState === "running" && hovered;
|
||||
|
||||
return (
|
||||
<button
|
||||
ref={btnRef}
|
||||
onClick={runState === "running" ? onPause : onRun}
|
||||
disabled={runState === "deploying" || disabled}
|
||||
onMouseEnter={() => setHovered(true)}
|
||||
onMouseLeave={() => setHovered(false)}
|
||||
className={`flex items-center gap-1.5 px-2.5 py-1 rounded-md text-[11px] font-semibold transition-all duration-200 ${
|
||||
showPause
|
||||
? "bg-amber-500/15 text-amber-400 border border-amber-500/40 hover:bg-amber-500/25 active:scale-95 cursor-pointer"
|
||||
: runState === "running"
|
||||
? "bg-green-500/15 text-green-400 border border-green-500/30 cursor-pointer"
|
||||
: runState === "deploying"
|
||||
? "bg-primary/10 text-primary border border-primary/20 cursor-default"
|
||||
: disabled
|
||||
? "bg-muted/30 text-muted-foreground/40 border border-border/20 cursor-not-allowed"
|
||||
: "bg-primary/10 text-primary border border-primary/20 hover:bg-primary/20 hover:border-primary/40 active:scale-95"
|
||||
}`}
|
||||
>
|
||||
{runState === "deploying" ? (
|
||||
<Loader2 className="w-3 h-3 animate-spin" />
|
||||
) : showPause ? (
|
||||
<Pause className="w-3 h-3 fill-current" />
|
||||
) : runState === "running" ? (
|
||||
<CheckCircle2 className="w-3 h-3" />
|
||||
) : (
|
||||
<Play className="w-3 h-3 fill-current" />
|
||||
)}
|
||||
{runState === "deploying" ? "Deploying\u2026" : showPause ? "Pause" : runState === "running" ? "Running" : "Run"}
|
||||
</button>
|
||||
);
|
||||
});
|
||||
@@ -0,0 +1,28 @@
|
||||
export type NodeStatus = "running" | "complete" | "pending" | "error" | "looping";
|
||||
|
||||
export type NodeType = "execution" | "trigger";
|
||||
|
||||
export interface GraphNode {
|
||||
id: string;
|
||||
label: string;
|
||||
status: NodeStatus;
|
||||
nodeType?: NodeType;
|
||||
triggerType?: string;
|
||||
triggerConfig?: Record<string, unknown>;
|
||||
next?: string[];
|
||||
backEdges?: string[];
|
||||
iterations?: number;
|
||||
maxIterations?: number;
|
||||
statusLabel?: string;
|
||||
edgeLabels?: Record<string, string>;
|
||||
}
|
||||
|
||||
export type RunState = "idle" | "deploying" | "running";
|
||||
|
||||
export interface RunButtonProps {
|
||||
runState: RunState;
|
||||
disabled: boolean;
|
||||
onRun: () => void;
|
||||
onPause: () => void;
|
||||
btnRef: React.Ref<HTMLButtonElement>;
|
||||
}
|
||||
@@ -196,6 +196,102 @@ describe("sseEventToChatMessage", () => {
|
||||
);
|
||||
});
|
||||
|
||||
it("different inner_turn values produce different message IDs", () => {
|
||||
const e1 = makeEvent({
|
||||
type: "client_output_delta",
|
||||
node_id: "queen",
|
||||
execution_id: "exec-1",
|
||||
data: { snapshot: "first response", iteration: 0, inner_turn: 0 },
|
||||
});
|
||||
const e2 = makeEvent({
|
||||
type: "client_output_delta",
|
||||
node_id: "queen",
|
||||
execution_id: "exec-1",
|
||||
data: { snapshot: "after tool call", iteration: 0, inner_turn: 1 },
|
||||
});
|
||||
const r1 = sseEventToChatMessage(e1, "t");
|
||||
const r2 = sseEventToChatMessage(e2, "t");
|
||||
expect(r1!.id).not.toBe(r2!.id);
|
||||
});
|
||||
|
||||
it("same inner_turn produces same ID (streaming upsert within one LLM call)", () => {
|
||||
const e1 = makeEvent({
|
||||
type: "client_output_delta",
|
||||
node_id: "queen",
|
||||
execution_id: "exec-1",
|
||||
data: { snapshot: "partial", iteration: 0, inner_turn: 1 },
|
||||
});
|
||||
const e2 = makeEvent({
|
||||
type: "client_output_delta",
|
||||
node_id: "queen",
|
||||
execution_id: "exec-1",
|
||||
data: { snapshot: "partial response", iteration: 0, inner_turn: 1 },
|
||||
});
|
||||
expect(sseEventToChatMessage(e1, "t")!.id).toBe(
|
||||
sseEventToChatMessage(e2, "t")!.id,
|
||||
);
|
||||
});
|
||||
|
||||
it("absent inner_turn produces same ID as inner_turn=0 (backward compat)", () => {
|
||||
const withField = makeEvent({
|
||||
type: "client_output_delta",
|
||||
node_id: "queen",
|
||||
execution_id: "exec-1",
|
||||
data: { snapshot: "hello", iteration: 2, inner_turn: 0 },
|
||||
});
|
||||
const withoutField = makeEvent({
|
||||
type: "client_output_delta",
|
||||
node_id: "queen",
|
||||
execution_id: "exec-1",
|
||||
data: { snapshot: "hello", iteration: 2 },
|
||||
});
|
||||
expect(sseEventToChatMessage(withField, "t")!.id).toBe(
|
||||
sseEventToChatMessage(withoutField, "t")!.id,
|
||||
);
|
||||
});
|
||||
|
||||
it("inner_turn=0 produces no suffix (matches old ID format)", () => {
|
||||
const event = makeEvent({
|
||||
type: "client_output_delta",
|
||||
node_id: "queen",
|
||||
execution_id: "exec-1",
|
||||
data: { snapshot: "hello", iteration: 3, inner_turn: 0 },
|
||||
});
|
||||
const result = sseEventToChatMessage(event, "t");
|
||||
expect(result!.id).toBe("stream-exec-1-3-queen");
|
||||
});
|
||||
|
||||
it("inner_turn>0 adds -t suffix to ID", () => {
|
||||
const event = makeEvent({
|
||||
type: "client_output_delta",
|
||||
node_id: "queen",
|
||||
execution_id: "exec-1",
|
||||
data: { snapshot: "hello", iteration: 3, inner_turn: 2 },
|
||||
});
|
||||
const result = sseEventToChatMessage(event, "t");
|
||||
expect(result!.id).toBe("stream-exec-1-3-t2-queen");
|
||||
});
|
||||
|
||||
it("llm_text_delta also uses inner_turn for distinct IDs", () => {
|
||||
const e1 = makeEvent({
|
||||
type: "llm_text_delta",
|
||||
node_id: "research",
|
||||
execution_id: "exec-1",
|
||||
data: { snapshot: "first", inner_turn: 0 },
|
||||
});
|
||||
const e2 = makeEvent({
|
||||
type: "llm_text_delta",
|
||||
node_id: "research",
|
||||
execution_id: "exec-1",
|
||||
data: { snapshot: "second", inner_turn: 1 },
|
||||
});
|
||||
const r1 = sseEventToChatMessage(e1, "t");
|
||||
const r2 = sseEventToChatMessage(e2, "t");
|
||||
expect(r1!.id).not.toBe(r2!.id);
|
||||
expect(r1!.id).toBe("stream-exec-1-research");
|
||||
expect(r2!.id).toBe("stream-exec-1-t1-research");
|
||||
});
|
||||
|
||||
it("uses timestamp fallback when both turnId and execution_id are null", () => {
|
||||
const event = makeEvent({
|
||||
type: "client_output_delta",
|
||||
|
||||
@@ -56,10 +56,15 @@ export function sseEventToChatMessage(
|
||||
const iterTid = iter != null ? String(iter) : tid;
|
||||
const iterIdKey = eid && iterTid ? `${eid}-${iterTid}` : eid || iterTid || `t-${Date.now()}`;
|
||||
|
||||
// Distinguish multiple LLM calls within the same iteration (inner tool loop).
|
||||
// inner_turn=0 (or absent) produces no suffix for backward compat.
|
||||
const innerTurn = event.data?.inner_turn as number | undefined;
|
||||
const innerSuffix = innerTurn != null && innerTurn > 0 ? `-t${innerTurn}` : "";
|
||||
|
||||
const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
|
||||
if (!snapshot) return null;
|
||||
if (!snapshot.trim()) return null;
|
||||
return {
|
||||
id: `stream-${iterIdKey}-${event.node_id}`,
|
||||
id: `stream-${iterIdKey}${innerSuffix}-${event.node_id}`,
|
||||
agent: agentDisplayName || event.node_id || "Agent",
|
||||
agentColor: "",
|
||||
content: snapshot,
|
||||
@@ -67,6 +72,8 @@ export function sseEventToChatMessage(
|
||||
role: "worker",
|
||||
thread,
|
||||
createdAt,
|
||||
nodeId: event.node_id || undefined,
|
||||
executionId: event.execution_id || undefined,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -91,10 +98,13 @@ export function sseEventToChatMessage(
|
||||
}
|
||||
|
||||
case "llm_text_delta": {
|
||||
const llmInnerTurn = event.data?.inner_turn as number | undefined;
|
||||
const llmInnerSuffix = llmInnerTurn != null && llmInnerTurn > 0 ? `-t${llmInnerTurn}` : "";
|
||||
|
||||
const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
|
||||
if (!snapshot) return null;
|
||||
if (!snapshot.trim()) return null;
|
||||
return {
|
||||
id: `stream-${idKey}-${event.node_id}`,
|
||||
id: `stream-${idKey}${llmInnerSuffix}-${event.node_id}`,
|
||||
agent: event.node_id || "Agent",
|
||||
agentColor: "",
|
||||
content: snapshot,
|
||||
@@ -102,6 +112,8 @@ export function sseEventToChatMessage(
|
||||
role: "worker",
|
||||
thread,
|
||||
createdAt,
|
||||
nodeId: event.node_id || undefined,
|
||||
executionId: event.execution_id || undefined,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import type { GraphTopology, NodeSpec } from "@/api/types";
|
||||
import type { GraphNode, NodeStatus } from "@/components/AgentGraph";
|
||||
import type { GraphNode, NodeStatus } from "@/components/graph-types";
|
||||
|
||||
/**
|
||||
* Convert a backend GraphTopology (nodes + edges + entry_node) into
|
||||
* the GraphNode[] shape that AgentGraph renders.
|
||||
* the GraphNode[] shape that DraftGraph renders.
|
||||
*
|
||||
* Four jobs:
|
||||
* 1. Synthesize trigger nodes from non-manual entry_points
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
// ── Shared graph utilities ──
|
||||
// Common helpers used by both AgentGraph and DraftGraph.
|
||||
// AgentGraph still has its own copies for now (separate cleanup PR).
|
||||
|
||||
/** Read a CSS custom property value (space-separated HSL components). */
|
||||
export function cssVar(name: string): string {
|
||||
return getComputedStyle(document.documentElement).getPropertyValue(name).trim();
|
||||
}
|
||||
|
||||
/** Truncate label to fit within `availablePx` at the given fontSize. */
|
||||
export function truncateLabel(label: string, availablePx: number, fontSize: number): string {
|
||||
const avgCharW = fontSize * 0.58;
|
||||
const maxChars = Math.floor(availablePx / avgCharW);
|
||||
if (label.length <= maxChars) return label;
|
||||
return label.slice(0, Math.max(maxChars - 1, 1)) + "\u2026";
|
||||
}
|
||||
|
||||
// ── Trigger styling ──
|
||||
|
||||
export type TriggerColorSet = { bg: string; border: string; text: string; icon: string };
|
||||
|
||||
export function buildTriggerColors(): TriggerColorSet {
|
||||
const bg = cssVar("--trigger-bg") || "210 25% 14%";
|
||||
const border = cssVar("--trigger-border") || "210 30% 30%";
|
||||
const text = cssVar("--trigger-text") || "210 30% 65%";
|
||||
const icon = cssVar("--trigger-icon") || "210 40% 55%";
|
||||
return {
|
||||
bg: `hsl(${bg})`,
|
||||
border: `hsl(${border})`,
|
||||
text: `hsl(${text})`,
|
||||
icon: `hsl(${icon})`,
|
||||
};
|
||||
}
|
||||
|
||||
export const ACTIVE_TRIGGER_COLORS: TriggerColorSet = {
|
||||
bg: "hsl(210,30%,18%)",
|
||||
border: "hsl(210,50%,50%)",
|
||||
text: "hsl(210,40%,75%)",
|
||||
icon: "hsl(210,60%,65%)",
|
||||
};
|
||||
|
||||
export const TRIGGER_ICONS: Record<string, string> = {
|
||||
webhook: "\u26A1", // lightning bolt
|
||||
timer: "\u23F1", // stopwatch
|
||||
api: "\u2192", // right arrow
|
||||
event: "\u223F", // sine wave
|
||||
};
|
||||
|
||||
/** Format a cron expression into a human-readable schedule label. */
|
||||
export function cronToLabel(cron: string): string {
|
||||
const parts = cron.trim().split(/\s+/);
|
||||
if (parts.length !== 5) return cron;
|
||||
const [min, hour, dom, mon, dow] = parts;
|
||||
|
||||
// */N * * * * -> "Every Nm"
|
||||
if (min.startsWith("*/") && hour === "*" && dom === "*" && mon === "*" && dow === "*") {
|
||||
return `Every ${min.slice(2)}m`;
|
||||
}
|
||||
// 0 */N * * * -> "Every Nh"
|
||||
if (min === "0" && hour.startsWith("*/") && dom === "*" && mon === "*" && dow === "*") {
|
||||
return `Every ${hour.slice(2)}h`;
|
||||
}
|
||||
// 0 H * * * -> "Daily at Ham/pm"
|
||||
if (dom === "*" && mon === "*" && dow === "*" && !min.includes("*") && !hour.includes("*")) {
|
||||
const h = parseInt(hour, 10);
|
||||
const m = parseInt(min, 10);
|
||||
const suffix = h >= 12 ? "PM" : "AM";
|
||||
const h12 = h % 12 || 12;
|
||||
return m === 0 ? `Daily at ${h12}${suffix}` : `Daily at ${h12}:${String(m).padStart(2, "0")}${suffix}`;
|
||||
}
|
||||
return cron;
|
||||
}
|
||||
|
||||
/** Theme-reactive hook for inactive trigger colors. */
|
||||
export function useTriggerColors(): TriggerColorSet {
|
||||
const [colors, setColors] = useState<TriggerColorSet>(buildTriggerColors);
|
||||
|
||||
useEffect(() => {
|
||||
const rebuild = () => setColors(buildTriggerColors());
|
||||
const obs = new MutationObserver(rebuild);
|
||||
obs.observe(document.documentElement, { attributes: true, attributeFilter: ["class", "style"] });
|
||||
return () => obs.disconnect();
|
||||
}, []);
|
||||
|
||||
return colors;
|
||||
}
|
||||
@@ -4,7 +4,7 @@
|
||||
*/
|
||||
|
||||
import type { ChatMessage } from "@/components/ChatPanel";
|
||||
import type { GraphNode } from "@/components/AgentGraph";
|
||||
import type { GraphNode } from "@/components/graph-types";
|
||||
|
||||
export const TAB_STORAGE_KEY = "hive:workspace-tabs";
|
||||
|
||||
|
||||
@@ -27,7 +27,14 @@ export default function MyAgents() {
|
||||
agentsApi
|
||||
.discover()
|
||||
.then((result) => {
|
||||
setAgents(result["Your Agents"] || []);
|
||||
const entries = result["Your Agents"] || [];
|
||||
entries.sort((a, b) => {
|
||||
if (!a.last_active && !b.last_active) return 0;
|
||||
if (!a.last_active) return 1;
|
||||
if (!b.last_active) return -1;
|
||||
return b.last_active.localeCompare(a.last_active);
|
||||
});
|
||||
setAgents(entries);
|
||||
})
|
||||
.catch((err) => {
|
||||
setError(err.message || "Failed to load agents");
|
||||
|
||||
@@ -2,7 +2,7 @@ import { useState, useCallback, useRef, useEffect, useMemo } from "react";
|
||||
import ReactDOM from "react-dom";
|
||||
import { useSearchParams, useNavigate } from "react-router-dom";
|
||||
import { Plus, KeyRound, Sparkles, Layers, ChevronLeft, Bot, Loader2, WifiOff, X } from "lucide-react";
|
||||
import AgentGraph, { type GraphNode, type NodeStatus } from "@/components/AgentGraph";
|
||||
import type { GraphNode, NodeStatus } from "@/components/graph-types";
|
||||
import DraftGraph from "@/components/DraftGraph";
|
||||
import ChatPanel, { type ChatMessage } from "@/components/ChatPanel";
|
||||
import TopBar from "@/components/TopBar";
|
||||
@@ -17,6 +17,7 @@ import { useMultiSSE } from "@/hooks/use-sse";
|
||||
import type { LiveSession, AgentEvent, DiscoverEntry, NodeSpec, DraftGraph as DraftGraphData } from "@/api/types";
|
||||
import { sseEventToChatMessage, formatAgentDisplayName } from "@/lib/chat-helpers";
|
||||
import { topologyToGraphNodes } from "@/lib/graph-converter";
|
||||
import { cronToLabel } from "@/lib/graphUtils";
|
||||
import { ApiError } from "@/api/client";
|
||||
|
||||
const makeId = () => Math.random().toString(36).slice(2, 9);
|
||||
@@ -113,7 +114,13 @@ function NewTabPopover({ open, onClose, anchorRef, discoverAgents, onFromScratch
|
||||
useEffect(() => {
|
||||
if (open && anchorRef.current) {
|
||||
const rect = anchorRef.current.getBoundingClientRect();
|
||||
setPos({ top: rect.bottom + 4, left: rect.left });
|
||||
const POPUP_WIDTH = 240; // w-60 = 15rem = 240px
|
||||
const overflows = rect.left + POPUP_WIDTH > window.innerWidth - 8;
|
||||
console.log("Anchor rect:", rect, "Overflows:", overflows);
|
||||
setPos({
|
||||
top: rect.bottom + 4,
|
||||
left: overflows ? rect.right - POPUP_WIDTH : rect.left,
|
||||
});
|
||||
}
|
||||
}, [open, anchorRef]);
|
||||
|
||||
@@ -245,6 +252,10 @@ function truncate(s: string, max: number): string {
|
||||
type SessionRestoreResult = {
|
||||
messages: ChatMessage[];
|
||||
restoredPhase: "planning" | "building" | "staging" | "running" | null;
|
||||
/** Last flowchart map from events — used to restore flowchart overlay on cold resume. */
|
||||
flowchartMap: Record<string, string[]> | null;
|
||||
/** Last original draft from events — used to restore flowchart overlay on cold resume. */
|
||||
originalDraft: DraftGraphData | null;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -261,6 +272,8 @@ async function restoreSessionMessages(
|
||||
if (events.length > 0) {
|
||||
const messages: ChatMessage[] = [];
|
||||
let runningPhase: ChatMessage["phase"] = undefined;
|
||||
let flowchartMap: Record<string, string[]> | null = null;
|
||||
let originalDraft: DraftGraphData | null = null;
|
||||
for (const evt of events) {
|
||||
// Track phase transitions so each message gets the phase it was created in
|
||||
const p = evt.type === "queen_phase_changed" ? evt.data?.phase as string
|
||||
@@ -269,6 +282,12 @@ async function restoreSessionMessages(
|
||||
if (p && ["planning", "building", "staging", "running"].includes(p)) {
|
||||
runningPhase = p as ChatMessage["phase"];
|
||||
}
|
||||
// Track last flowchart state for cold restore
|
||||
if (evt.type === "flowchart_map_updated" && evt.data) {
|
||||
const mapData = evt.data as { map?: Record<string, string[]>; original_draft?: DraftGraphData };
|
||||
flowchartMap = mapData.map ?? null;
|
||||
originalDraft = mapData.original_draft ?? null;
|
||||
}
|
||||
const msg = sseEventToChatMessage(evt, thread, agentDisplayName);
|
||||
if (!msg) continue;
|
||||
if (evt.stream_id === "queen") {
|
||||
@@ -277,12 +296,12 @@ async function restoreSessionMessages(
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
return { messages, restoredPhase: runningPhase ?? null };
|
||||
return { messages, restoredPhase: runningPhase ?? null, flowchartMap, originalDraft };
|
||||
}
|
||||
} catch {
|
||||
// Event log not available — session will start fresh.
|
||||
}
|
||||
return { messages: [], restoredPhase: null };
|
||||
return { messages: [], restoredPhase: null, flowchartMap: null, originalDraft: null };
|
||||
}
|
||||
|
||||
// --- Per-agent backend state (consolidated) ---
|
||||
@@ -321,6 +340,8 @@ interface AgentBackendState {
|
||||
workerIsTyping: boolean;
|
||||
llmSnapshots: Record<string, string>;
|
||||
activeToolCalls: Record<string, { name: string; done: boolean; streamId: string }>;
|
||||
/** True while save_agent_draft tool is running (between tool_call_started and draft_graph_updated) */
|
||||
designingDraft: boolean;
|
||||
/** Agent folder path — set after scaffolding, used for credential queries */
|
||||
agentPath: string | null;
|
||||
/** Structured question text from ask_user with options */
|
||||
@@ -331,6 +352,8 @@ interface AgentBackendState {
|
||||
pendingQuestions: { id: string; prompt: string; options?: string[] }[] | null;
|
||||
/** Whether the pending question came from queen or worker */
|
||||
pendingQuestionSource: "queen" | "worker" | null;
|
||||
/** Per-node context window usage (from context_usage_updated events) */
|
||||
contextUsage: Record<string, { usagePct: number; messageCount: number; estimatedTokens: number; maxTokens: number }>;
|
||||
}
|
||||
|
||||
function defaultAgentState(): AgentBackendState {
|
||||
@@ -347,6 +370,7 @@ function defaultAgentState(): AgentBackendState {
|
||||
workerInputMessageId: null,
|
||||
queenBuilding: false,
|
||||
queenPhase: "planning",
|
||||
designingDraft: false,
|
||||
draftGraph: null,
|
||||
originalDraft: null,
|
||||
flowchartMap: null,
|
||||
@@ -367,6 +391,7 @@ function defaultAgentState(): AgentBackendState {
|
||||
pendingOptions: null,
|
||||
pendingQuestions: null,
|
||||
pendingQuestionSource: null,
|
||||
contextUsage: {},
|
||||
};
|
||||
}
|
||||
|
||||
@@ -548,9 +573,46 @@ export default function Workspace() {
|
||||
const [dismissedBanner, setDismissedBanner] = useState<string | null>(null);
|
||||
const [selectedNode, setSelectedNode] = useState<GraphNode | null>(null);
|
||||
const [triggerTaskDraft, setTriggerTaskDraft] = useState("");
|
||||
const [triggerCronDraft, setTriggerCronDraft] = useState("");
|
||||
const [triggerTaskSaving, setTriggerTaskSaving] = useState(false);
|
||||
const [triggerScheduleSaving, setTriggerScheduleSaving] = useState(false);
|
||||
const [triggerCronSaved, setTriggerCronSaved] = useState(false);
|
||||
const [triggerTaskSaved, setTriggerTaskSaved] = useState(false);
|
||||
const [newTabOpen, setNewTabOpen] = useState(false);
|
||||
const newTabBtnRef = useRef<HTMLButtonElement>(null);
|
||||
const [graphPanelPct, setGraphPanelPct] = useState(30);
|
||||
const savedGraphPanelPct = useRef(30);
|
||||
const resizing = useRef(false);
|
||||
|
||||
// Drag-to-resize the graph panel
|
||||
useEffect(() => {
|
||||
const onMouseMove = (e: MouseEvent) => {
|
||||
if (!resizing.current) return;
|
||||
const pct = (e.clientX / window.innerWidth) * 100;
|
||||
setGraphPanelPct(Math.max(15, Math.min(50, pct)));
|
||||
};
|
||||
const onMouseUp = () => {
|
||||
resizing.current = false;
|
||||
document.body.style.cursor = "";
|
||||
};
|
||||
window.addEventListener("mousemove", onMouseMove);
|
||||
window.addEventListener("mouseup", onMouseUp);
|
||||
return () => {
|
||||
window.removeEventListener("mousemove", onMouseMove);
|
||||
window.removeEventListener("mouseup", onMouseUp);
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Shrink graph panel when node detail opens, restore when it closes
|
||||
const nodeIsSelected = selectedNode !== null;
|
||||
useEffect(() => {
|
||||
if (nodeIsSelected) {
|
||||
savedGraphPanelPct.current = graphPanelPct;
|
||||
setGraphPanelPct(prev => Math.min(prev, 30));
|
||||
} else {
|
||||
setGraphPanelPct(savedGraphPanelPct.current);
|
||||
}
|
||||
}, [nodeIsSelected]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
// Ref mirror of sessionsByAgent so SSE callback can read current graph
|
||||
// state without adding sessionsByAgent to its dependency array.
|
||||
@@ -571,6 +633,13 @@ export default function Workspace() {
|
||||
// it was created in (avoids stale-closure when phase change and message
|
||||
// events arrive in the same React batch).
|
||||
const queenPhaseRef = useRef<Record<string, string>>({});
|
||||
// Accumulated queen text across inner_turns within the same iteration.
|
||||
// Key: `${agentType}:${execution_id}:${iteration}`, value: { [inner_turn]: snapshot }.
|
||||
// This lets us merge all inner_turn text into one chat bubble per iteration.
|
||||
const queenIterTextRef = useRef<Record<string, Record<number, string>>>({});
|
||||
// Timestamp when designingDraft was set — used to enforce minimum spinner duration.
|
||||
const designingDraftSinceRef = useRef<Record<string, number>>({});
|
||||
const designingDraftTimerRef = useRef<Record<string, ReturnType<typeof setTimeout>>>({});
|
||||
|
||||
// Synchronous ref to suppress the queen's auto-intro SSE messages
|
||||
// after a cold-restore (where we already restored the conversation from disk).
|
||||
@@ -749,6 +818,8 @@ export default function Workspace() {
|
||||
}
|
||||
|
||||
let restoredPhase: "planning" | "building" | "staging" | "running" | null = null;
|
||||
let restoredFlowchartMap: Record<string, string[]> | null = null;
|
||||
let restoredOriginalDraft: DraftGraphData | null = null;
|
||||
if (!liveSession) {
|
||||
// Fetch conversation history from disk BEFORE creating the new session.
|
||||
// SKIP if messages were already pre-populated by handleHistoryOpen.
|
||||
@@ -760,9 +831,22 @@ export default function Workspace() {
|
||||
const restored = await restoreSessionMessages(restoreFrom, agentType, "Queen Bee");
|
||||
preRestoredMsgs.push(...restored.messages);
|
||||
restoredPhase = restored.restoredPhase;
|
||||
restoredFlowchartMap = restored.flowchartMap;
|
||||
restoredOriginalDraft = restored.originalDraft;
|
||||
} catch {
|
||||
// Not available — will start fresh
|
||||
}
|
||||
} else if (restoreFrom && alreadyHasMessages) {
|
||||
// Messages already cached in localStorage — still fetch events for
|
||||
// non-message state (phase, flowchart) that isn't cached.
|
||||
try {
|
||||
const restored = await restoreSessionMessages(restoreFrom, agentType, "Queen Bee");
|
||||
restoredPhase = restored.restoredPhase;
|
||||
restoredFlowchartMap = restored.flowchartMap;
|
||||
restoredOriginalDraft = restored.originalDraft;
|
||||
} catch {
|
||||
// Not critical — UI will still show cached messages
|
||||
}
|
||||
}
|
||||
|
||||
// Suppress the queen's intro cycle whenever we are about to restore a
|
||||
@@ -785,7 +869,7 @@ export default function Workspace() {
|
||||
}));
|
||||
}
|
||||
restoredMessageCount = preRestoredMsgs.length;
|
||||
} else if (restoreFrom && activeId) {
|
||||
} else if (restoreFrom && activeId && !alreadyHasMessages) {
|
||||
// We had a stored session but no messages on disk — wipe stale localStorage cache
|
||||
setSessionsByAgent(prev => ({
|
||||
...prev,
|
||||
@@ -839,6 +923,9 @@ export default function Workspace() {
|
||||
queenReady: true,
|
||||
queenPhase: qPhase,
|
||||
queenBuilding: qPhase === "building",
|
||||
// Restore flowchart overlay from persisted events
|
||||
...(restoredFlowchartMap ? { flowchartMap: restoredFlowchartMap } : {}),
|
||||
...(restoredOriginalDraft ? { originalDraft: restoredOriginalDraft, draftGraph: null } : {}),
|
||||
});
|
||||
} catch (err: unknown) {
|
||||
const msg = err instanceof Error ? err.message : String(err);
|
||||
@@ -913,6 +1000,8 @@ export default function Workspace() {
|
||||
|
||||
// Track the last queen phase seen in the event log for cold restore
|
||||
let restoredPhase: "planning" | "building" | "staging" | "running" | null = null;
|
||||
let restoredFlowchartMap: Record<string, string[]> | null = null;
|
||||
let restoredOriginalDraft: DraftGraphData | null = null;
|
||||
|
||||
if (!liveSession) {
|
||||
// Reconnect failed — clear stale cached messages from localStorage restore.
|
||||
@@ -940,6 +1029,19 @@ export default function Workspace() {
|
||||
const restored = await restoreSessionMessages(coldRestoreId, agentType, displayNameTemp);
|
||||
preQueenMsgs = restored.messages;
|
||||
restoredPhase = restored.restoredPhase;
|
||||
restoredFlowchartMap = restored.flowchartMap;
|
||||
restoredOriginalDraft = restored.originalDraft;
|
||||
} else if (coldRestoreId && alreadyHasMessages) {
|
||||
// Messages already cached — still fetch events for non-message state (phase, flowchart)
|
||||
try {
|
||||
const displayNameTemp = formatAgentDisplayName(agentPath);
|
||||
const restored = await restoreSessionMessages(coldRestoreId, agentType, displayNameTemp);
|
||||
restoredPhase = restored.restoredPhase;
|
||||
restoredFlowchartMap = restored.flowchartMap;
|
||||
restoredOriginalDraft = restored.originalDraft;
|
||||
} catch {
|
||||
// Not critical — UI will still show cached messages
|
||||
}
|
||||
}
|
||||
|
||||
// Suppress intro whenever we are about to restore a previous conversation.
|
||||
@@ -1020,6 +1122,9 @@ export default function Workspace() {
|
||||
displayName,
|
||||
queenPhase: initialPhase,
|
||||
queenBuilding: initialPhase === "building",
|
||||
// Restore flowchart overlay from persisted events
|
||||
...(restoredFlowchartMap ? { flowchartMap: restoredFlowchartMap } : {}),
|
||||
...(restoredOriginalDraft ? { originalDraft: restoredOriginalDraft, draftGraph: null } : {}),
|
||||
});
|
||||
|
||||
// Update the session label + backendSessionId. Also set historySourceId
|
||||
@@ -1057,6 +1162,11 @@ export default function Workspace() {
|
||||
if (historyId && !coldRestoreId) {
|
||||
const restored = await restoreSessionMessages(historyId, agentType, displayName);
|
||||
restoredMsgs.push(...restored.messages);
|
||||
// Use flowchart from event log if not already set
|
||||
if (restored.flowchartMap && !restoredFlowchartMap) {
|
||||
restoredFlowchartMap = restored.flowchartMap;
|
||||
restoredOriginalDraft = restored.originalDraft;
|
||||
}
|
||||
|
||||
// Check worker status (needed for isWorkerRunning flag)
|
||||
try {
|
||||
@@ -1099,6 +1209,9 @@ export default function Workspace() {
|
||||
loading: false,
|
||||
queenReady: !!(isResumedSession || hasRestoredContent),
|
||||
...(isWorkerRunning ? { workerRunState: "running" } : {}),
|
||||
// Restore flowchart overlay from persisted events
|
||||
...(restoredFlowchartMap ? { flowchartMap: restoredFlowchartMap } : {}),
|
||||
...(restoredOriginalDraft ? { originalDraft: restoredOriginalDraft, draftGraph: null } : {}),
|
||||
});
|
||||
} catch (err: unknown) {
|
||||
const msg = err instanceof Error ? err.message : String(err);
|
||||
@@ -1180,8 +1293,8 @@ export default function Workspace() {
|
||||
graphsApi.draftGraph(state.sessionId).then(({ draft }) => {
|
||||
if (draft) updateAgentState(agentType, { draftGraph: draft });
|
||||
}).catch(() => {});
|
||||
} else {
|
||||
// Fetch flowchart map for non-planning phases (staging, running, building)
|
||||
} else if (state.queenPhase !== "building") {
|
||||
// Fetch flowchart map for non-building phases (staging, running)
|
||||
if (state.originalDraft) continue; // already have it
|
||||
if (fetchedFlowchartMapSessionsRef.current.has(state.sessionId)) continue;
|
||||
fetchedFlowchartMapSessionsRef.current.add(state.sessionId);
|
||||
@@ -1190,6 +1303,7 @@ export default function Workspace() {
|
||||
updateAgentState(agentType, {
|
||||
flowchartMap: map,
|
||||
originalDraft: original_draft,
|
||||
draftGraph: null,
|
||||
});
|
||||
}
|
||||
}).catch(() => {});
|
||||
@@ -1214,12 +1328,28 @@ export default function Workspace() {
|
||||
|
||||
const fireMap = new Map<string, number>();
|
||||
const taskMap = new Map<string, string>();
|
||||
const labelMap = new Map<string, string>();
|
||||
const targetMap = new Map<string, string>();
|
||||
for (const ep of triggerEps) {
|
||||
const nodeId = `__trigger_${ep.id}`;
|
||||
if (ep.next_fire_in != null) {
|
||||
fireMap.set(`__trigger_${ep.id}`, ep.next_fire_in);
|
||||
fireMap.set(nodeId, ep.next_fire_in);
|
||||
}
|
||||
if (ep.task != null) {
|
||||
taskMap.set(`__trigger_${ep.id}`, ep.task);
|
||||
taskMap.set(nodeId, ep.task);
|
||||
}
|
||||
const cron = ep.trigger_config?.cron as string | undefined;
|
||||
const interval = ep.trigger_config?.interval_minutes as number | undefined;
|
||||
const epLabel = cron
|
||||
? cronToLabel(cron)
|
||||
: interval
|
||||
? `Every ${interval >= 60 ? `${interval / 60}h` : `${interval}m`}`
|
||||
: ep.name || undefined;
|
||||
if (epLabel) {
|
||||
labelMap.set(nodeId, epLabel);
|
||||
}
|
||||
if (ep.entry_node) {
|
||||
targetMap.set(nodeId, ep.entry_node);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1228,14 +1358,18 @@ export default function Workspace() {
|
||||
if (!ss?.length) return prev;
|
||||
const existingIds = new Set(ss[0].graphNodes.map(n => n.id));
|
||||
|
||||
// Update existing trigger nodes
|
||||
// Update existing trigger nodes (countdown, task, label, target)
|
||||
let updated = ss[0].graphNodes.map((n) => {
|
||||
if (n.nodeType !== "trigger") return n;
|
||||
const nfi = fireMap.get(n.id);
|
||||
const task = taskMap.get(n.id);
|
||||
if (nfi == null && task == null) return n;
|
||||
const label = labelMap.get(n.id);
|
||||
const target = targetMap.get(n.id);
|
||||
if (nfi == null && task == null && !label && !target) return n;
|
||||
return {
|
||||
...n,
|
||||
...(label && label !== n.label ? { label } : {}),
|
||||
...(target ? { next: [target] } : {}),
|
||||
triggerConfig: {
|
||||
...n.triggerConfig,
|
||||
...(nfi != null ? { next_fire_in: nfi } : {}),
|
||||
@@ -1245,14 +1379,15 @@ export default function Workspace() {
|
||||
});
|
||||
|
||||
// Discover new triggers not yet in the graph
|
||||
const entryNode = ss[0].graphNodes.find(n => n.nodeType !== "trigger")?.id;
|
||||
const fallbackEntry = ss[0].graphNodes.find(n => n.nodeType !== "trigger")?.id;
|
||||
const newNodes: GraphNode[] = [];
|
||||
for (const ep of triggerEps) {
|
||||
const nodeId = `__trigger_${ep.id}`;
|
||||
if (existingIds.has(nodeId)) continue;
|
||||
const target = ep.entry_node || fallbackEntry;
|
||||
newNodes.push({
|
||||
id: nodeId,
|
||||
label: ep.name || ep.id,
|
||||
label: labelMap.get(nodeId) || ep.name || ep.id,
|
||||
status: "pending",
|
||||
nodeType: "trigger",
|
||||
triggerType: ep.trigger_type,
|
||||
@@ -1261,7 +1396,7 @@ export default function Workspace() {
|
||||
...(ep.next_fire_in != null ? { next_fire_in: ep.next_fire_in } : {}),
|
||||
...(ep.task ? { task: ep.task } : {}),
|
||||
},
|
||||
...(entryNode ? { next: [entryNode] } : {}),
|
||||
...(target ? { next: [target] } : {}),
|
||||
});
|
||||
}
|
||||
if (newNodes.length > 0) {
|
||||
@@ -1578,6 +1713,31 @@ export default function Workspace() {
|
||||
const chatMsg = sseEventToChatMessage(event, agentType, displayName, currentTurn);
|
||||
if (isQueen) console.log('[QUEEN] chatMsg:', chatMsg?.id, chatMsg?.content?.slice(0, 50), 'turn:', currentTurn);
|
||||
if (chatMsg && !suppressQueenMessages) {
|
||||
// Queen emits multiple client_output_delta / llm_text_delta snapshots
|
||||
// across iterations and inner tool-loop turns. Merge all inner_turns
|
||||
// within the same iteration into ONE bubble so the queen's multi-step
|
||||
// tool loop (text → tool → text → tool → text) appears as one cohesive
|
||||
// message rather than many small fragments.
|
||||
if (isQueen && (event.type === "client_output_delta" || event.type === "llm_text_delta") && event.execution_id) {
|
||||
const iter = event.data?.iteration ?? 0;
|
||||
const inner = (event.data?.inner_turn as number) ?? 0;
|
||||
const iterKey = `${agentType}:${event.execution_id}:${iter}`;
|
||||
|
||||
// Store the latest snapshot for this inner_turn
|
||||
if (!queenIterTextRef.current[iterKey]) {
|
||||
queenIterTextRef.current[iterKey] = {};
|
||||
}
|
||||
const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
|
||||
queenIterTextRef.current[iterKey][inner] = snapshot;
|
||||
|
||||
// Concatenate all inner_turn snapshots in order
|
||||
const parts = queenIterTextRef.current[iterKey];
|
||||
const sortedInners = Object.keys(parts).map(Number).sort((a, b) => a - b);
|
||||
chatMsg.content = sortedInners.map(k => parts[k]).join("\n");
|
||||
|
||||
// Single ID per iteration — no inner_turn in the ID
|
||||
chatMsg.id = `queen-stream-${event.execution_id}-${iter}`;
|
||||
}
|
||||
if (isQueen) {
|
||||
chatMsg.role = role;
|
||||
chatMsg.phase = queenPhaseRef.current[agentType] as ChatMessage["phase"];
|
||||
@@ -1823,6 +1983,15 @@ export default function Workspace() {
|
||||
const toolName = (event.data?.tool_name as string) || "unknown";
|
||||
const toolUseId = (event.data?.tool_use_id as string) || "";
|
||||
|
||||
// Flag when the queen starts designing/updating the flowchart
|
||||
if (isQueen && toolName === "save_agent_draft") {
|
||||
designingDraftSinceRef.current[agentType] = Date.now();
|
||||
// Clear any pending delayed-clear timer from a previous call
|
||||
const prev = designingDraftTimerRef.current[agentType];
|
||||
if (prev) clearTimeout(prev);
|
||||
updateAgentState(agentType, { designingDraft: true });
|
||||
}
|
||||
|
||||
// Track active (in-flight) tools and upsert activity row into chat
|
||||
const sid = event.stream_id;
|
||||
setAgentStates(prev => {
|
||||
@@ -1842,6 +2011,8 @@ export default function Workspace() {
|
||||
role,
|
||||
thread: agentType,
|
||||
createdAt: eventCreatedAt,
|
||||
nodeId: event.node_id || undefined,
|
||||
executionId: event.execution_id || undefined,
|
||||
});
|
||||
return {
|
||||
...prev,
|
||||
@@ -1913,6 +2084,8 @@ export default function Workspace() {
|
||||
role,
|
||||
thread: agentType,
|
||||
createdAt: eventCreatedAt,
|
||||
nodeId: event.node_id || undefined,
|
||||
executionId: event.execution_id || undefined,
|
||||
});
|
||||
return {
|
||||
...prev,
|
||||
@@ -1989,6 +2162,29 @@ export default function Workspace() {
|
||||
}
|
||||
break;
|
||||
|
||||
case "context_usage_updated": {
|
||||
const streamKey = isQueen ? "__queen__" : (event.node_id || streamId);
|
||||
const usagePct = (event.data?.usage_pct as number) ?? 0;
|
||||
const messageCount = (event.data?.message_count as number) ?? 0;
|
||||
const estimatedTokens = (event.data?.estimated_tokens as number) ?? 0;
|
||||
const maxTokens = (event.data?.max_context_tokens as number) ?? 0;
|
||||
setAgentStates(prev => {
|
||||
const state = prev[agentType];
|
||||
if (!state) return prev;
|
||||
return {
|
||||
...prev,
|
||||
[agentType]: {
|
||||
...state,
|
||||
contextUsage: {
|
||||
...state.contextUsage,
|
||||
[streamKey]: { usagePct, messageCount, estimatedTokens, maxTokens },
|
||||
},
|
||||
},
|
||||
};
|
||||
});
|
||||
}
|
||||
break;
|
||||
|
||||
case "node_action_plan":
|
||||
if (!isQueen && event.node_id) {
|
||||
const plan = (event.data?.plan as string) || "";
|
||||
@@ -2030,20 +2226,19 @@ export default function Workspace() {
|
||||
queenBuilding: newPhase === "building",
|
||||
// Sync workerRunState so the RunButton reflects the phase
|
||||
workerRunState: newPhase === "running" ? "running" : "idle",
|
||||
// Clear draft graph once we leave planning/building; keep it during
|
||||
// building so the DraftGraph can show a loading overlay.
|
||||
...(newPhase !== "planning" && newPhase !== "building"
|
||||
? { draftGraph: null }
|
||||
: newPhase === "planning"
|
||||
? { originalDraft: null, flowchartMap: null }
|
||||
: {}),
|
||||
// Clear originalDraft/flowchartMap when re-entering planning.
|
||||
// draftGraph is cleared later when originalDraft arrives, so the
|
||||
// entrance animation has data to render during the handoff.
|
||||
...(newPhase === "planning"
|
||||
? { originalDraft: null, flowchartMap: null }
|
||||
: {}),
|
||||
// Store agent path for credential queries
|
||||
...(eventAgentPath ? { agentPath: eventAgentPath } : {}),
|
||||
});
|
||||
{
|
||||
const sid = agentStates[agentType]?.sessionId;
|
||||
if (sid) {
|
||||
if (newPhase !== "planning") {
|
||||
if (newPhase !== "planning" && newPhase !== "building") {
|
||||
fetchedDraftSessionsRef.current.delete(sid);
|
||||
fetchedFlowchartMapSessionsRef.current.delete(sid);
|
||||
// Fetch the flowchart map (original draft + dissolution mapping)
|
||||
@@ -2053,7 +2248,8 @@ export default function Workspace() {
|
||||
originalDraft: original_draft,
|
||||
});
|
||||
}).catch(() => {});
|
||||
} else {
|
||||
} else if (newPhase === "planning") {
|
||||
// Only clear dedup sets when re-entering planning (not building)
|
||||
fetchedDraftSessionsRef.current.delete(sid);
|
||||
fetchedFlowchartMapSessionsRef.current.delete(sid);
|
||||
}
|
||||
@@ -2066,7 +2262,28 @@ export default function Workspace() {
|
||||
// The draft dict is published directly as event.data (not nested under a key)
|
||||
const draft = event.data as unknown as DraftGraphData | undefined;
|
||||
if (draft?.nodes) {
|
||||
updateAgentState(agentType, { draftGraph: draft });
|
||||
// Ensure the "Designing flowchart…" spinner stays visible for a
|
||||
// minimum duration so users see feedback before the draft appears.
|
||||
const MIN_SPINNER_MS = 600;
|
||||
const since = designingDraftSinceRef.current[agentType] || 0;
|
||||
const elapsed = Date.now() - since;
|
||||
const remaining = Math.max(0, MIN_SPINNER_MS - elapsed);
|
||||
|
||||
const applyDraft = () => {
|
||||
delete designingDraftTimerRef.current[agentType];
|
||||
updateAgentState(agentType, { draftGraph: draft, designingDraft: false });
|
||||
};
|
||||
|
||||
if (remaining > 0 && since > 0) {
|
||||
// Update draftGraph now (so data is ready) but keep spinner visible
|
||||
updateAgentState(agentType, { draftGraph: draft });
|
||||
designingDraftTimerRef.current[agentType] = setTimeout(() => {
|
||||
updateAgentState(agentType, { designingDraft: false });
|
||||
delete designingDraftTimerRef.current[agentType];
|
||||
}, remaining);
|
||||
} else {
|
||||
applyDraft();
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
@@ -2077,6 +2294,7 @@ export default function Workspace() {
|
||||
updateAgentState(agentType, {
|
||||
flowchartMap: mapData.map ?? null,
|
||||
originalDraft: mapData.original_draft ?? null,
|
||||
draftGraph: null,
|
||||
});
|
||||
}
|
||||
break;
|
||||
@@ -2150,10 +2368,18 @@ export default function Workspace() {
|
||||
// Synthesize new trigger node at the front of the graph
|
||||
const triggerType = (event.data?.trigger_type as string) || "timer";
|
||||
const triggerConfig = (event.data?.trigger_config as Record<string, unknown>) || {};
|
||||
const entryNode = s.graphNodes.find(n => n.nodeType !== "trigger")?.id;
|
||||
const entryNode = (event.data?.entry_node as string) || s.graphNodes.find(n => n.nodeType !== "trigger")?.id;
|
||||
const triggerName = (event.data?.name as string) || triggerId;
|
||||
const _cron = triggerConfig.cron as string | undefined;
|
||||
const _interval = triggerConfig.interval_minutes as number | undefined;
|
||||
const computedLabel = _cron
|
||||
? cronToLabel(_cron)
|
||||
: _interval
|
||||
? `Every ${_interval >= 60 ? `${_interval / 60}h` : `${_interval}m`}`
|
||||
: triggerName;
|
||||
const newNode: GraphNode = {
|
||||
id: nodeId,
|
||||
label: triggerId,
|
||||
label: computedLabel,
|
||||
status: "running",
|
||||
nodeType: "trigger",
|
||||
triggerType,
|
||||
@@ -2218,10 +2444,18 @@ export default function Workspace() {
|
||||
if (s.graphNodes.some(n => n.id === nodeId)) return s;
|
||||
const triggerType = (event.data?.trigger_type as string) || "timer";
|
||||
const triggerConfig = (event.data?.trigger_config as Record<string, unknown>) || {};
|
||||
const entryNode = s.graphNodes.find(n => n.nodeType !== "trigger")?.id;
|
||||
const entryNode = (event.data?.entry_node as string) || s.graphNodes.find(n => n.nodeType !== "trigger")?.id;
|
||||
const triggerName = (event.data?.name as string) || triggerId;
|
||||
const _cron2 = triggerConfig.cron as string | undefined;
|
||||
const _interval2 = triggerConfig.interval_minutes as number | undefined;
|
||||
const computedLabel2 = _cron2
|
||||
? cronToLabel(_cron2)
|
||||
: _interval2
|
||||
? `Every ${_interval2 >= 60 ? `${_interval2 / 60}h` : `${_interval2}m`}`
|
||||
: triggerName;
|
||||
const newNode: GraphNode = {
|
||||
id: nodeId,
|
||||
label: triggerId,
|
||||
label: computedLabel2,
|
||||
status: "pending",
|
||||
nodeType: "trigger",
|
||||
triggerType,
|
||||
@@ -2236,6 +2470,43 @@ export default function Workspace() {
|
||||
break;
|
||||
}
|
||||
|
||||
case "trigger_updated": {
|
||||
const triggerId = event.data?.trigger_id as string;
|
||||
if (triggerId) {
|
||||
const nodeId = `__trigger_${triggerId}`;
|
||||
const triggerConfig = (event.data?.trigger_config as Record<string, unknown>) || {};
|
||||
const cron = triggerConfig.cron as string | undefined;
|
||||
const interval = triggerConfig.interval_minutes as number | undefined;
|
||||
const newLabel = cron
|
||||
? cronToLabel(cron)
|
||||
: interval
|
||||
? `Every ${interval >= 60 ? `${interval / 60}h` : `${interval}m`}`
|
||||
: undefined;
|
||||
setSessionsByAgent(prev => {
|
||||
const sessions = prev[agentType] || [];
|
||||
const activeId = activeSessionRef.current[agentType] || sessions[0]?.id;
|
||||
return {
|
||||
...prev,
|
||||
[agentType]: sessions.map(s => {
|
||||
if (s.id !== activeId) return s;
|
||||
return {
|
||||
...s,
|
||||
graphNodes: s.graphNodes.map(n => {
|
||||
if (n.id !== nodeId) return n;
|
||||
return {
|
||||
...n,
|
||||
...(newLabel ? { label: newLabel } : {}),
|
||||
triggerConfig: { ...n.triggerConfig, ...triggerConfig },
|
||||
};
|
||||
}),
|
||||
};
|
||||
}),
|
||||
};
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case "trigger_removed": {
|
||||
const triggerId = event.data?.trigger_id as string;
|
||||
if (triggerId) {
|
||||
@@ -2289,14 +2560,43 @@ export default function Workspace() {
|
||||
const liveSelectedNode = selectedNode && currentGraph.nodes.find(n => n.id === selectedNode.id);
|
||||
const resolvedSelectedNode = liveSelectedNode || selectedNode;
|
||||
|
||||
// Sync trigger task draft when selected trigger node changes
|
||||
// Sync trigger drafts when selected trigger node changes
|
||||
useEffect(() => {
|
||||
if (resolvedSelectedNode?.nodeType === "trigger") {
|
||||
const tc = resolvedSelectedNode.triggerConfig as Record<string, unknown> | undefined;
|
||||
setTriggerTaskDraft((tc?.task as string) || "");
|
||||
setTriggerCronDraft((tc?.cron as string) || "");
|
||||
}
|
||||
}, [resolvedSelectedNode?.id]);
|
||||
|
||||
const patchTriggerNode = useCallback((agentType: string, triggerNodeId: string, patch: { task?: string; trigger_config?: Record<string, unknown>; label?: string }) => {
|
||||
setSessionsByAgent(prev => {
|
||||
const sessions = prev[agentType] || [];
|
||||
const activeId = activeSessionRef.current[agentType] || sessions[0]?.id;
|
||||
return {
|
||||
...prev,
|
||||
[agentType]: sessions.map(s => {
|
||||
if (s.id !== activeId) return s;
|
||||
return {
|
||||
...s,
|
||||
graphNodes: s.graphNodes.map(n => {
|
||||
if (n.id !== triggerNodeId) return n;
|
||||
return {
|
||||
...n,
|
||||
...(patch.label !== undefined ? { label: patch.label } : {}),
|
||||
triggerConfig: {
|
||||
...n.triggerConfig,
|
||||
...(patch.trigger_config || {}),
|
||||
...(patch.task !== undefined ? { task: patch.task } : {}),
|
||||
},
|
||||
};
|
||||
}),
|
||||
};
|
||||
}),
|
||||
};
|
||||
});
|
||||
}, []);
|
||||
|
||||
// Build a flat list of all agent-type tabs for the tab bar
|
||||
const agentTabs = Object.entries(sessionsByAgent)
|
||||
.filter(([, sessions]) => sessions.length > 0)
|
||||
@@ -2764,7 +3064,6 @@ export default function Workspace() {
|
||||
|
||||
const activeWorkerLabel = activeAgentState?.displayName || formatAgentDisplayName(baseAgentType(activeWorker));
|
||||
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-screen bg-background overflow-hidden">
|
||||
<TopBar
|
||||
@@ -2812,38 +3111,40 @@ export default function Workspace() {
|
||||
{/* Main content area */}
|
||||
<div className="flex flex-1 min-h-0">
|
||||
|
||||
{/* ── Pipeline graph + chat ──────────────────────────────────── */}
|
||||
<div className={`${((activeAgentState?.queenPhase === "planning" || activeAgentState?.queenPhase === "building") && activeAgentState?.draftGraph) || activeAgentState?.originalDraft ? "w-[500px] min-w-[400px]" : "w-[300px] min-w-[240px]"} bg-card/30 flex flex-col border-r border-border/30 transition-[width] duration-200`}>
|
||||
{/* ── Draft flowchart + chat ─────────────────────────────────── */}
|
||||
<div
|
||||
className="bg-card/30 flex flex-col border-r border-border/30 relative"
|
||||
style={{ width: `${graphPanelPct}%`, minWidth: 240, flexShrink: 0 }}
|
||||
>
|
||||
<div className="flex-1 min-h-0">
|
||||
{(activeAgentState?.queenPhase === "planning" || activeAgentState?.queenPhase === "building") && activeAgentState?.draftGraph ? (
|
||||
<DraftGraph draft={activeAgentState.draftGraph} building={activeAgentState?.queenBuilding} onRun={handleRun} onPause={handlePause} runState={activeAgentState?.workerRunState ?? "idle"} />
|
||||
) : activeAgentState?.originalDraft ? (
|
||||
<DraftGraph
|
||||
draft={activeAgentState.originalDraft}
|
||||
building={activeAgentState?.queenBuilding}
|
||||
onRun={handleRun}
|
||||
onPause={handlePause}
|
||||
runState={activeAgentState?.workerRunState ?? "idle"}
|
||||
flowchartMap={activeAgentState.flowchartMap ?? undefined}
|
||||
runtimeNodes={currentGraph.nodes}
|
||||
onRuntimeNodeClick={(runtimeNodeId) => {
|
||||
const node = currentGraph.nodes.find(n => n.id === runtimeNodeId);
|
||||
if (node) setSelectedNode(prev => prev?.id === node.id ? null : node);
|
||||
}}
|
||||
/>
|
||||
) : (
|
||||
<AgentGraph
|
||||
nodes={currentGraph.nodes}
|
||||
title={currentGraph.title}
|
||||
onNodeClick={(node) => setSelectedNode(prev => prev?.id === node.id ? null : node)}
|
||||
onRun={handleRun}
|
||||
onPause={handlePause}
|
||||
runState={activeAgentState?.workerRunState ?? "idle"}
|
||||
building={activeAgentState?.queenBuilding ?? false}
|
||||
queenPhase={activeAgentState?.queenPhase ?? "building"}
|
||||
/>
|
||||
)}
|
||||
<DraftGraph
|
||||
key={activeWorker}
|
||||
draft={activeAgentState?.originalDraft ?? activeAgentState?.draftGraph ?? null}
|
||||
originalDraft={activeAgentState?.originalDraft ?? null}
|
||||
loadingMessage={
|
||||
activeAgentState?.designingDraft
|
||||
? "Designing flowchart…"
|
||||
: !activeAgentState?.originalDraft && !activeAgentState?.draftGraph && activeAgentState?.queenPhase !== "planning"
|
||||
? "Loading flowchart…"
|
||||
: null
|
||||
}
|
||||
building={activeAgentState?.queenBuilding}
|
||||
onRun={handleRun}
|
||||
onPause={handlePause}
|
||||
runState={activeAgentState?.workerRunState ?? "idle"}
|
||||
flowchartMap={activeAgentState?.flowchartMap ?? undefined}
|
||||
runtimeNodes={currentGraph.nodes}
|
||||
onRuntimeNodeClick={(runtimeNodeId) => {
|
||||
const node = currentGraph.nodes.find(n => n.id === runtimeNodeId);
|
||||
if (node) setSelectedNode(prev => prev?.id === node.id ? null : node);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
{/* Resize handle */}
|
||||
<div
|
||||
className="absolute top-0 right-0 w-1 h-full cursor-col-resize hover:bg-primary/30 active:bg-primary/40 transition-colors z-10"
|
||||
onMouseDown={() => { resizing.current = true; document.body.style.cursor = "col-resize"; }}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex-1 min-w-0 flex">
|
||||
<div className="flex-1 min-w-0 relative">
|
||||
@@ -2922,6 +3223,7 @@ export default function Workspace() {
|
||||
}
|
||||
onMultiQuestionSubmit={handleMultiQuestionAnswer}
|
||||
onQuestionDismiss={handleQuestionDismiss}
|
||||
contextUsage={activeAgentState?.contextUsage}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
@@ -2964,18 +3266,64 @@ export default function Workspace() {
|
||||
const interval = tc?.interval_minutes as number | undefined;
|
||||
const eventTypes = tc?.event_types as string[] | undefined;
|
||||
const scheduleLabel = cron
|
||||
? `cron: ${cron}`
|
||||
? cronToLabel(cron)
|
||||
: interval
|
||||
? `Every ${interval >= 60 ? `${interval / 60}h` : `${interval}m`}`
|
||||
: eventTypes?.length
|
||||
? eventTypes.join(", ")
|
||||
: null;
|
||||
return scheduleLabel ? (
|
||||
const canEditCron = resolvedSelectedNode.triggerType === "timer";
|
||||
const cronChanged = canEditCron && triggerCronDraft.trim() !== (cron || "");
|
||||
return scheduleLabel || canEditCron ? (
|
||||
<div>
|
||||
<p className="text-[10px] font-medium text-muted-foreground uppercase tracking-wider mb-1.5">Schedule</p>
|
||||
<p className="text-xs text-foreground/80 font-mono bg-muted/30 rounded-lg px-3 py-2 border border-border/20">
|
||||
{scheduleLabel}
|
||||
</p>
|
||||
{scheduleLabel && (
|
||||
<p className="text-xs text-foreground/80 font-mono bg-muted/30 rounded-lg px-3 py-2 border border-border/20">
|
||||
{scheduleLabel}
|
||||
</p>
|
||||
)}
|
||||
{canEditCron && (
|
||||
<>
|
||||
<input
|
||||
value={triggerCronDraft}
|
||||
onChange={(e) => setTriggerCronDraft(e.target.value)}
|
||||
placeholder="0 5 * * *"
|
||||
className="mt-1.5 w-full text-xs text-foreground/80 bg-muted/30 rounded-lg px-3 py-2 border border-border/20 font-mono focus:outline-none focus:border-primary/40"
|
||||
/>
|
||||
<p className="text-[10px] text-muted-foreground/60 mt-1">
|
||||
Edit the cron expression for this timer trigger.
|
||||
</p>
|
||||
{(cronChanged || triggerCronSaved) && (
|
||||
<button
|
||||
disabled={triggerScheduleSaving || !cronChanged}
|
||||
onClick={async () => {
|
||||
const sessionId = activeAgentState?.sessionId;
|
||||
const triggerId = resolvedSelectedNode.id.replace("__trigger_", "");
|
||||
const nextCron = triggerCronDraft.trim();
|
||||
if (!sessionId || !nextCron) return;
|
||||
const nextTriggerConfig: Record<string, unknown> = { cron: nextCron };
|
||||
setTriggerScheduleSaving(true);
|
||||
try {
|
||||
await sessionsApi.updateTrigger(sessionId, triggerId, {
|
||||
trigger_config: nextTriggerConfig,
|
||||
});
|
||||
patchTriggerNode(activeWorker, resolvedSelectedNode.id, {
|
||||
trigger_config: nextTriggerConfig,
|
||||
label: cronToLabel(nextCron),
|
||||
});
|
||||
setTriggerCronSaved(true);
|
||||
setTimeout(() => setTriggerCronSaved(false), 2000);
|
||||
} finally {
|
||||
setTriggerScheduleSaving(false);
|
||||
}
|
||||
}}
|
||||
className="mt-1.5 w-full text-[11px] px-3 py-1.5 rounded-lg border border-primary/30 text-primary hover:bg-primary/10 transition-colors disabled:opacity-50"
|
||||
>
|
||||
{triggerScheduleSaving ? "Saving..." : triggerCronSaved ? "Saved" : "Save Cron"}
|
||||
</button>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
) : null;
|
||||
})()}
|
||||
@@ -3002,24 +3350,27 @@ export default function Workspace() {
|
||||
{(() => {
|
||||
const currentTask = (resolvedSelectedNode.triggerConfig as Record<string, unknown> | undefined)?.task as string || "";
|
||||
const hasChanged = triggerTaskDraft !== currentTask;
|
||||
if (!hasChanged) return null;
|
||||
if (!hasChanged && !triggerTaskSaved) return null;
|
||||
return (
|
||||
<button
|
||||
disabled={triggerTaskSaving}
|
||||
disabled={triggerTaskSaving || !hasChanged}
|
||||
onClick={async () => {
|
||||
const sessionId = activeAgentState?.sessionId;
|
||||
const triggerId = resolvedSelectedNode.id.replace("__trigger_", "");
|
||||
if (!sessionId) return;
|
||||
setTriggerTaskSaving(true);
|
||||
try {
|
||||
await sessionsApi.updateTriggerTask(sessionId, triggerId, triggerTaskDraft);
|
||||
await sessionsApi.updateTrigger(sessionId, triggerId, { task: triggerTaskDraft });
|
||||
patchTriggerNode(activeWorker, resolvedSelectedNode.id, { task: triggerTaskDraft });
|
||||
setTriggerTaskSaved(true);
|
||||
setTimeout(() => setTriggerTaskSaved(false), 2000);
|
||||
} finally {
|
||||
setTriggerTaskSaving(false);
|
||||
}
|
||||
}}
|
||||
className="mt-1.5 w-full text-[11px] px-3 py-1.5 rounded-lg border border-primary/30 text-primary hover:bg-primary/10 transition-colors disabled:opacity-50"
|
||||
>
|
||||
{triggerTaskSaving ? "Saving..." : "Save Task"}
|
||||
{triggerTaskSaving ? "Saving..." : triggerTaskSaved ? "Saved" : "Save Task"}
|
||||
</button>
|
||||
);
|
||||
})()}
|
||||
@@ -3076,6 +3427,7 @@ export default function Workspace() {
|
||||
workerSessionId={null}
|
||||
nodeLogs={activeAgentState?.nodeLogs[resolvedSelectedNode.id] || []}
|
||||
actionPlan={activeAgentState?.nodeActionPlans[resolvedSelectedNode.id]}
|
||||
contextUsage={activeAgentState?.contextUsage[resolvedSelectedNode.id]}
|
||||
onClose={() => setSelectedNode(null)}
|
||||
/>
|
||||
)}
|
||||
|
||||
+6
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "framework"
|
||||
version = "0.5.1"
|
||||
version = "0.7.1"
|
||||
description = "Goal-driven agent runtime with Builder-friendly observability"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
@@ -11,6 +11,7 @@ dependencies = [
|
||||
"litellm>=1.81.0",
|
||||
"mcp>=1.0.0",
|
||||
"fastmcp>=2.0.0",
|
||||
"croniter>=1.4.0",
|
||||
"tools",
|
||||
]
|
||||
|
||||
@@ -61,6 +62,10 @@ lint.isort.section-order = [
|
||||
"first-party",
|
||||
"local-folder",
|
||||
]
|
||||
[tool.pytest.ini_options]
|
||||
filterwarnings = [
|
||||
"ignore::DeprecationWarning:litellm.*"
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
|
||||
@@ -33,6 +33,7 @@ API_KEY_PROVIDERS = [
|
||||
("TOGETHER_API_KEY", "Together AI", "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo"),
|
||||
("DEEPSEEK_API_KEY", "DeepSeek", "deepseek-chat"),
|
||||
("MINIMAX_API_KEY", "MiniMax", "MiniMax-M2.5"),
|
||||
("HIVE_API_KEY", "Hive LLM", "hive/queen"),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,209 @@
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _load_check_llm_key_module():
|
||||
module_path = Path(__file__).resolve().parents[2] / "scripts" / "check_llm_key.py"
|
||||
spec = importlib.util.spec_from_file_location("check_llm_key_script", module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def _run_openrouter_check(monkeypatch, status_code: int):
|
||||
module = _load_check_llm_key_module()
|
||||
calls = {}
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, code):
|
||||
self.status_code = code
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self, timeout):
|
||||
calls["timeout"] = timeout
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def get(self, endpoint, headers):
|
||||
calls["endpoint"] = endpoint
|
||||
calls["headers"] = headers
|
||||
return FakeResponse(status_code)
|
||||
|
||||
monkeypatch.setattr(module.httpx, "Client", FakeClient)
|
||||
result = module.check_openrouter("test-key")
|
||||
return result, calls
|
||||
|
||||
|
||||
def _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
status_code: int,
|
||||
payload: dict | None = None,
|
||||
model: str = "openai/gpt-4o-mini",
|
||||
):
|
||||
module = _load_check_llm_key_module()
|
||||
calls = {}
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, code):
|
||||
self.status_code = code
|
||||
self._payload = payload
|
||||
self.text = ""
|
||||
|
||||
def json(self):
|
||||
if self._payload is None:
|
||||
raise ValueError("no json")
|
||||
return self._payload
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self, timeout):
|
||||
calls["timeout"] = timeout
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def get(self, endpoint, headers):
|
||||
calls["endpoint"] = endpoint
|
||||
calls["headers"] = headers
|
||||
return FakeResponse(status_code)
|
||||
|
||||
monkeypatch.setattr(module.httpx, "Client", FakeClient)
|
||||
result = module.check_openrouter_model("test-key", model)
|
||||
return result, calls
|
||||
|
||||
|
||||
def test_check_openrouter_200(monkeypatch):
|
||||
result, calls = _run_openrouter_check(monkeypatch, 200)
|
||||
assert result == {"valid": True, "message": "OpenRouter API key valid"}
|
||||
assert calls["endpoint"] == "https://openrouter.ai/api/v1/models"
|
||||
assert calls["headers"] == {"Authorization": "Bearer test-key"}
|
||||
|
||||
|
||||
def test_check_openrouter_401(monkeypatch):
|
||||
result, _ = _run_openrouter_check(monkeypatch, 401)
|
||||
assert result == {"valid": False, "message": "Invalid OpenRouter API key"}
|
||||
|
||||
|
||||
def test_check_openrouter_403(monkeypatch):
|
||||
result, _ = _run_openrouter_check(monkeypatch, 403)
|
||||
assert result == {"valid": False, "message": "OpenRouter API key lacks permissions"}
|
||||
|
||||
|
||||
def test_check_openrouter_429(monkeypatch):
|
||||
result, _ = _run_openrouter_check(monkeypatch, 429)
|
||||
assert result == {"valid": True, "message": "OpenRouter API key valid"}
|
||||
|
||||
|
||||
def test_check_openrouter_model_200(monkeypatch):
|
||||
result, calls = _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
200,
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"id": "openai/gpt-4o-mini",
|
||||
"canonical_slug": "openai/gpt-4o-mini",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
assert result == {
|
||||
"valid": True,
|
||||
"message": "OpenRouter model is available: openai/gpt-4o-mini",
|
||||
"model": "openai/gpt-4o-mini",
|
||||
}
|
||||
assert calls["endpoint"] == "https://openrouter.ai/api/v1/models/user"
|
||||
assert calls["headers"] == {"Authorization": "Bearer test-key"}
|
||||
|
||||
|
||||
def test_check_openrouter_model_200_matches_canonical_slug(monkeypatch):
|
||||
result, _ = _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
200,
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"id": "mistralai/mistral-small-4",
|
||||
"canonical_slug": "mistralai/mistral-small-2603",
|
||||
}
|
||||
]
|
||||
},
|
||||
model="mistralai/mistral-small-2603",
|
||||
)
|
||||
assert result == {
|
||||
"valid": True,
|
||||
"message": "OpenRouter model is available: mistralai/mistral-small-2603",
|
||||
"model": "mistralai/mistral-small-2603",
|
||||
}
|
||||
|
||||
|
||||
def test_check_openrouter_model_200_sanitizes_pasted_unicode(monkeypatch):
|
||||
result, _ = _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
200,
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"id": "z-ai/glm-5-turbo",
|
||||
"canonical_slug": "z-ai/glm-5-turbo",
|
||||
}
|
||||
]
|
||||
},
|
||||
model="openrouter/z-ai\u200b/glm\u20115\u2011turbo",
|
||||
)
|
||||
assert result == {
|
||||
"valid": True,
|
||||
"message": "OpenRouter model is available: z-ai/glm-5-turbo",
|
||||
"model": "z-ai/glm-5-turbo",
|
||||
}
|
||||
|
||||
|
||||
def test_check_openrouter_model_200_not_found_with_suggestions(monkeypatch):
|
||||
result, _ = _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
200,
|
||||
{
|
||||
"data": [
|
||||
{"id": "z-ai/glm-5-turbo"},
|
||||
{"id": "z-ai/glm-4.6v"},
|
||||
]
|
||||
},
|
||||
model="z-ai/glm-5-turb",
|
||||
)
|
||||
assert result == {
|
||||
"valid": False,
|
||||
"message": (
|
||||
"OpenRouter model is not available for this key/settings: z-ai/glm-5-turb. "
|
||||
"Closest matches: z-ai/glm-5-turbo"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def test_check_openrouter_model_404_with_error_message(monkeypatch):
|
||||
result, _ = _run_openrouter_model_check(
|
||||
monkeypatch,
|
||||
404,
|
||||
{"error": {"message": "No endpoints available for this model"}},
|
||||
)
|
||||
assert result == {
|
||||
"valid": False,
|
||||
"message": (
|
||||
"OpenRouter model is not available for this key/settings: openai/gpt-4o-mini. "
|
||||
"No endpoints available for this model"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def test_check_openrouter_model_429(monkeypatch):
|
||||
result, _ = _run_openrouter_model_check(monkeypatch, 429)
|
||||
assert result == {
|
||||
"valid": True,
|
||||
"message": "OpenRouter model check rate-limited; assuming model is reachable",
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import logging
|
||||
|
||||
from framework.config import get_hive_config
|
||||
from framework.config import get_api_base, get_hive_config, get_preferred_model
|
||||
|
||||
|
||||
class TestGetHiveConfig:
|
||||
@@ -21,3 +21,47 @@ class TestGetHiveConfig:
|
||||
assert result == {}
|
||||
assert "Failed to load Hive config" in caplog.text
|
||||
assert str(config_file) in caplog.text
|
||||
|
||||
|
||||
class TestOpenRouterConfig:
|
||||
"""OpenRouter config composition and fallback behavior."""
|
||||
|
||||
def test_get_preferred_model_for_openrouter(self, tmp_path, monkeypatch):
|
||||
config_file = tmp_path / "configuration.json"
|
||||
config_file.write_text(
|
||||
'{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta"}}',
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
|
||||
|
||||
assert get_preferred_model() == "openrouter/x-ai/grok-4.20-beta"
|
||||
|
||||
def test_get_preferred_model_normalizes_openrouter_prefixed_model(self, tmp_path, monkeypatch):
|
||||
config_file = tmp_path / "configuration.json"
|
||||
config_file.write_text(
|
||||
'{"llm":{"provider":"openrouter","model":"openrouter/x-ai/grok-4.20-beta"}}',
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
|
||||
|
||||
assert get_preferred_model() == "openrouter/x-ai/grok-4.20-beta"
|
||||
|
||||
def test_get_api_base_falls_back_to_openrouter_default(self, tmp_path, monkeypatch):
|
||||
config_file = tmp_path / "configuration.json"
|
||||
config_file.write_text(
|
||||
'{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta"}}',
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
|
||||
|
||||
assert get_api_base() == "https://openrouter.ai/api/v1"
|
||||
|
||||
def test_get_api_base_keeps_explicit_openrouter_api_base(self, tmp_path, monkeypatch):
|
||||
config_file = tmp_path / "configuration.json"
|
||||
config_file.write_text(
|
||||
'{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta","api_base":"https://proxy.example/v1"}}',
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
|
||||
|
||||
assert get_api_base() == "https://proxy.example/v1"
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
import os
|
||||
import sys
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
from framework.credentials import key_storage
|
||||
from framework.credentials.validation import ensure_credential_key_env
|
||||
|
||||
|
||||
def _install_fake_aden_modules(monkeypatch, check_fn, credential_specs):
|
||||
shell_config_module = ModuleType("aden_tools.credentials.shell_config")
|
||||
shell_config_module.check_env_var_in_shell_config = check_fn
|
||||
|
||||
credentials_module = ModuleType("aden_tools.credentials")
|
||||
credentials_module.CREDENTIAL_SPECS = credential_specs
|
||||
|
||||
monkeypatch.setitem(sys.modules, "aden_tools.credentials.shell_config", shell_config_module)
|
||||
monkeypatch.setitem(sys.modules, "aden_tools.credentials", credentials_module)
|
||||
|
||||
|
||||
def test_bootstrap_loads_configured_llm_env_var_from_shell_config(monkeypatch):
|
||||
monkeypatch.setattr(key_storage, "load_credential_key", lambda: None)
|
||||
monkeypatch.setattr(key_storage, "load_aden_api_key", lambda: None)
|
||||
monkeypatch.setattr(
|
||||
"framework.config.get_hive_config",
|
||||
lambda: {"llm": {"api_key_env_var": "OPENROUTER_API_KEY"}},
|
||||
)
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
|
||||
calls = []
|
||||
|
||||
def check_env(var_name):
|
||||
calls.append(var_name)
|
||||
if var_name == "OPENROUTER_API_KEY":
|
||||
return True, "or-key-123"
|
||||
return False, None
|
||||
|
||||
_install_fake_aden_modules(
|
||||
monkeypatch,
|
||||
check_env,
|
||||
{"anthropic": SimpleNamespace(env_var="ANTHROPIC_API_KEY")},
|
||||
)
|
||||
|
||||
ensure_credential_key_env()
|
||||
|
||||
assert os.environ.get("OPENROUTER_API_KEY") == "or-key-123"
|
||||
assert "OPENROUTER_API_KEY" in calls
|
||||
|
||||
|
||||
def test_bootstrap_does_not_override_existing_configured_llm_env_var(monkeypatch):
|
||||
monkeypatch.setattr(key_storage, "load_credential_key", lambda: None)
|
||||
monkeypatch.setattr(key_storage, "load_aden_api_key", lambda: None)
|
||||
monkeypatch.setattr(
|
||||
"framework.config.get_hive_config",
|
||||
lambda: {"llm": {"api_key_env_var": "OPENROUTER_API_KEY"}},
|
||||
)
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "already-set")
|
||||
|
||||
calls = []
|
||||
|
||||
def check_env(var_name):
|
||||
calls.append(var_name)
|
||||
return True, "new-value-should-not-apply"
|
||||
|
||||
_install_fake_aden_modules(monkeypatch, check_env, {})
|
||||
|
||||
ensure_credential_key_env()
|
||||
|
||||
assert os.environ.get("OPENROUTER_API_KEY") == "already-set"
|
||||
assert "OPENROUTER_API_KEY" not in calls
|
||||
@@ -0,0 +1,188 @@
|
||||
"""Tests for default skills — parsing, token budget, and configuration."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.skills.config import DefaultSkillConfig, SkillsConfig
|
||||
from framework.skills.defaults import (
|
||||
SHARED_MEMORY_KEYS,
|
||||
SKILL_REGISTRY,
|
||||
DefaultSkillManager,
|
||||
)
|
||||
from framework.skills.parser import parse_skill_md
|
||||
|
||||
_DEFAULT_SKILLS_DIR = (
|
||||
Path(__file__).resolve().parent.parent / "framework" / "skills" / "_default_skills"
|
||||
)
|
||||
|
||||
|
||||
class TestDefaultSkillFiles:
|
||||
"""Verify all 6 built-in SKILL.md files parse correctly."""
|
||||
|
||||
def test_all_six_skills_exist(self):
|
||||
assert len(SKILL_REGISTRY) == 6
|
||||
|
||||
@pytest.mark.parametrize("skill_name,dir_name", list(SKILL_REGISTRY.items()))
|
||||
def test_skill_parses(self, skill_name, dir_name):
|
||||
path = _DEFAULT_SKILLS_DIR / dir_name / "SKILL.md"
|
||||
assert path.is_file(), f"Missing SKILL.md at {path}"
|
||||
|
||||
parsed = parse_skill_md(path, source_scope="framework")
|
||||
assert parsed is not None, f"Failed to parse {path}"
|
||||
assert parsed.name == skill_name
|
||||
assert parsed.description
|
||||
assert parsed.body
|
||||
assert parsed.source_scope == "framework"
|
||||
|
||||
def test_combined_token_budget(self):
|
||||
"""All default skill bodies combined should be under 2000 tokens (~8000 chars)."""
|
||||
total_chars = 0
|
||||
for dir_name in SKILL_REGISTRY.values():
|
||||
path = _DEFAULT_SKILLS_DIR / dir_name / "SKILL.md"
|
||||
parsed = parse_skill_md(path, source_scope="framework")
|
||||
assert parsed is not None
|
||||
total_chars += len(parsed.body)
|
||||
|
||||
approx_tokens = total_chars // 4
|
||||
assert approx_tokens < 2000, (
|
||||
f"Combined default skill bodies are ~{approx_tokens} tokens "
|
||||
f"({total_chars} chars), exceeding the 2000 token budget"
|
||||
)
|
||||
|
||||
def test_shared_memory_keys_all_prefixed(self):
|
||||
"""All shared memory keys must start with underscore."""
|
||||
for key in SHARED_MEMORY_KEYS:
|
||||
assert key.startswith("_"), f"Shared memory key missing _ prefix: {key}"
|
||||
|
||||
|
||||
class TestDefaultSkillManager:
|
||||
def test_load_all_defaults(self):
|
||||
manager = DefaultSkillManager()
|
||||
manager.load()
|
||||
|
||||
assert len(manager.active_skill_names) == 6
|
||||
for name in SKILL_REGISTRY:
|
||||
assert name in manager.active_skill_names
|
||||
|
||||
def test_load_idempotent(self):
|
||||
manager = DefaultSkillManager()
|
||||
manager.load()
|
||||
first_skills = dict(manager.active_skills)
|
||||
manager.load()
|
||||
assert manager.active_skills == first_skills
|
||||
|
||||
def test_build_protocols_prompt(self):
|
||||
manager = DefaultSkillManager()
|
||||
manager.load()
|
||||
prompt = manager.build_protocols_prompt()
|
||||
|
||||
assert prompt.startswith("## Operational Protocols")
|
||||
# Should contain content from each active skill
|
||||
for name in SKILL_REGISTRY:
|
||||
skill = manager.active_skills[name]
|
||||
# At least some of the body should appear
|
||||
assert skill.body[:20] in prompt
|
||||
|
||||
def test_protocols_prompt_empty_when_all_disabled(self):
|
||||
config = SkillsConfig(all_defaults_disabled=True)
|
||||
manager = DefaultSkillManager(config)
|
||||
manager.load()
|
||||
|
||||
assert manager.build_protocols_prompt() == ""
|
||||
assert manager.active_skill_names == []
|
||||
|
||||
def test_disable_single_skill(self):
|
||||
config = SkillsConfig.from_agent_vars(
|
||||
default_skills={"hive.quality-monitor": {"enabled": False}}
|
||||
)
|
||||
manager = DefaultSkillManager(config)
|
||||
manager.load()
|
||||
|
||||
assert "hive.quality-monitor" not in manager.active_skill_names
|
||||
assert len(manager.active_skill_names) == 5
|
||||
|
||||
def test_disable_all_via_convention(self):
|
||||
config = SkillsConfig.from_agent_vars(default_skills={"_all": {"enabled": False}})
|
||||
manager = DefaultSkillManager(config)
|
||||
manager.load()
|
||||
|
||||
assert manager.active_skill_names == []
|
||||
|
||||
def test_log_active_skills(self, caplog):
|
||||
import logging
|
||||
|
||||
with caplog.at_level(logging.INFO, logger="framework.skills.defaults"):
|
||||
manager = DefaultSkillManager()
|
||||
manager.load()
|
||||
manager.log_active_skills()
|
||||
|
||||
assert "Default skills active:" in caplog.text
|
||||
|
||||
def test_log_all_disabled(self, caplog):
|
||||
import logging
|
||||
|
||||
config = SkillsConfig(all_defaults_disabled=True)
|
||||
with caplog.at_level(logging.INFO, logger="framework.skills.defaults"):
|
||||
manager = DefaultSkillManager(config)
|
||||
manager.load()
|
||||
manager.log_active_skills()
|
||||
|
||||
assert "all disabled" in caplog.text
|
||||
|
||||
|
||||
class TestSkillsConfig:
|
||||
def test_default_is_enabled(self):
|
||||
config = SkillsConfig()
|
||||
assert config.is_default_enabled("hive.note-taking") is True
|
||||
|
||||
def test_explicit_disable(self):
|
||||
config = SkillsConfig(
|
||||
default_skills={"hive.note-taking": DefaultSkillConfig(enabled=False)}
|
||||
)
|
||||
assert config.is_default_enabled("hive.note-taking") is False
|
||||
assert config.is_default_enabled("hive.batch-ledger") is True
|
||||
|
||||
def test_all_disabled_flag(self):
|
||||
config = SkillsConfig(all_defaults_disabled=True)
|
||||
assert config.is_default_enabled("hive.note-taking") is False
|
||||
assert config.is_default_enabled("anything") is False
|
||||
|
||||
def test_from_agent_vars_basic(self):
|
||||
config = SkillsConfig.from_agent_vars(
|
||||
default_skills={
|
||||
"hive.note-taking": {"enabled": True},
|
||||
"hive.quality-monitor": {"enabled": False},
|
||||
},
|
||||
skills=["deep-research"],
|
||||
)
|
||||
assert config.is_default_enabled("hive.note-taking") is True
|
||||
assert config.is_default_enabled("hive.quality-monitor") is False
|
||||
assert config.skills == ["deep-research"]
|
||||
|
||||
def test_from_agent_vars_bool_shorthand(self):
|
||||
config = SkillsConfig.from_agent_vars(default_skills={"hive.note-taking": False})
|
||||
assert config.is_default_enabled("hive.note-taking") is False
|
||||
|
||||
def test_from_agent_vars_all_disabled(self):
|
||||
config = SkillsConfig.from_agent_vars(default_skills={"_all": {"enabled": False}})
|
||||
assert config.all_defaults_disabled is True
|
||||
|
||||
def test_get_default_overrides(self):
|
||||
config = SkillsConfig.from_agent_vars(
|
||||
default_skills={
|
||||
"hive.batch-ledger": {"enabled": True, "checkpoint_every_n": 10},
|
||||
}
|
||||
)
|
||||
overrides = config.get_default_overrides("hive.batch-ledger")
|
||||
assert overrides == {"checkpoint_every_n": 10}
|
||||
|
||||
def test_get_default_overrides_empty(self):
|
||||
config = SkillsConfig()
|
||||
assert config.get_default_overrides("hive.note-taking") == {}
|
||||
|
||||
def test_from_agent_vars_none_inputs(self):
|
||||
config = SkillsConfig.from_agent_vars(default_skills=None, skills=None)
|
||||
assert config.skills == []
|
||||
assert config.default_skills == {}
|
||||
assert config.all_defaults_disabled is False
|
||||
@@ -1530,6 +1530,34 @@ class TestTransientErrorRetry:
|
||||
await node.execute(ctx)
|
||||
assert llm._call_index == 1 # only tried once
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_facing_non_transient_error_does_not_crash(
|
||||
self, runtime, node_spec, memory
|
||||
):
|
||||
"""Client-facing non-transient errors should wait for input, not crash on token vars."""
|
||||
node_spec.output_keys = []
|
||||
node_spec.client_facing = True
|
||||
llm = ErrorThenSuccessLLM(
|
||||
error=ValueError("bad request: blocked by policy"),
|
||||
fail_count=100, # always fails
|
||||
success_scenario=text_scenario("unreachable"),
|
||||
)
|
||||
ctx = build_ctx(runtime, node_spec, memory, llm)
|
||||
node = EventLoopNode(
|
||||
config=LoopConfig(
|
||||
max_iterations=1,
|
||||
max_stream_retries=0,
|
||||
stream_retry_backoff_base=0.01,
|
||||
),
|
||||
)
|
||||
node._await_user_input = AsyncMock(return_value=None)
|
||||
|
||||
result = await node.execute(ctx)
|
||||
|
||||
assert result.success is False
|
||||
assert "Max iterations" in (result.error or "")
|
||||
node._await_user_input.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transient_error_exhausts_retries(self, runtime, node_spec, memory):
|
||||
"""Transient errors that exhaust retries should raise."""
|
||||
|
||||
@@ -12,6 +12,7 @@ Covers:
|
||||
- Single-edge paths unaffected
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -77,6 +78,19 @@ class TimingNode(NodeProtocol):
|
||||
)
|
||||
|
||||
|
||||
class SlowNode(NodeProtocol):
|
||||
"""Sleeps before returning -- used for timeout testing."""
|
||||
|
||||
def __init__(self, delay: float = 10.0):
|
||||
self.delay = delay
|
||||
self.executed = False
|
||||
|
||||
async def execute(self, ctx: NodeContext) -> NodeResult:
|
||||
await asyncio.sleep(self.delay)
|
||||
self.executed = True
|
||||
return NodeResult(success=True, output={"result": "slow"}, tokens_used=1, latency_ms=1)
|
||||
|
||||
|
||||
# --- Fixtures ---
|
||||
|
||||
|
||||
@@ -492,3 +506,186 @@ async def test_parallel_disabled_uses_sequential(runtime, goal):
|
||||
# Only one branch should have executed (sequential follows first edge)
|
||||
executed_count = sum([b1_impl.executed, b2_impl.executed])
|
||||
assert executed_count == 1
|
||||
|
||||
|
||||
# === 12. Branch timeout cancels slow branch ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_branch_timeout_cancels_slow_branch(runtime, goal):
|
||||
"""A branch exceeding branch_timeout_seconds should be cancelled."""
|
||||
b1 = NodeSpec(
|
||||
id="b1", name="B1", description="slow", node_type="event_loop", output_keys=["b1_out"]
|
||||
)
|
||||
b2 = NodeSpec(
|
||||
id="b2", name="B2", description="fast", node_type="event_loop", output_keys=["b2_out"]
|
||||
)
|
||||
|
||||
graph = _make_fanout_graph([b1, b2])
|
||||
|
||||
config = ParallelExecutionConfig(branch_timeout_seconds=0.1, on_branch_failure="fail_all")
|
||||
executor = GraphExecutor(
|
||||
runtime=runtime, enable_parallel_execution=True, parallel_config=config
|
||||
)
|
||||
executor.register_node("source", SuccessNode({"data": "x"}))
|
||||
executor.register_node("b1", SlowNode(delay=10.0))
|
||||
executor.register_node("b2", SuccessNode({"b2_out": "ok"}))
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
# fail_all: one branch timed out → execution fails
|
||||
assert not result.success
|
||||
assert "failed" in result.error.lower()
|
||||
|
||||
|
||||
# === 13. Branch timeout with continue_others ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_branch_timeout_with_continue_others(runtime, goal):
|
||||
"""continue_others should let fast branches finish even when one times out."""
|
||||
b1 = NodeSpec(
|
||||
id="b1", name="B1", description="slow", node_type="event_loop", output_keys=["b1_out"]
|
||||
)
|
||||
b2 = NodeSpec(
|
||||
id="b2", name="B2", description="fast", node_type="event_loop", output_keys=["b2_out"]
|
||||
)
|
||||
|
||||
graph = _make_fanout_graph([b1, b2])
|
||||
|
||||
config = ParallelExecutionConfig(
|
||||
branch_timeout_seconds=0.1, on_branch_failure="continue_others"
|
||||
)
|
||||
executor = GraphExecutor(
|
||||
runtime=runtime, enable_parallel_execution=True, parallel_config=config
|
||||
)
|
||||
executor.register_node("source", SuccessNode({"data": "x"}))
|
||||
executor.register_node("b1", SlowNode(delay=10.0))
|
||||
b2_impl = SuccessNode({"b2_out": "ok"})
|
||||
executor.register_node("b2", b2_impl)
|
||||
|
||||
await executor.execute(graph, goal, {})
|
||||
|
||||
# continue_others tolerates the timeout
|
||||
assert b2_impl.executed
|
||||
|
||||
|
||||
# === 14. Branch timeout with fail_all (explicit) ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_branch_timeout_with_fail_all(runtime, goal):
|
||||
"""fail_all should propagate timeout as execution failure."""
|
||||
b1 = NodeSpec(
|
||||
id="b1", name="B1", description="slow", node_type="event_loop", output_keys=["b1_out"]
|
||||
)
|
||||
b2 = NodeSpec(
|
||||
id="b2", name="B2", description="also slow", node_type="event_loop", output_keys=["b2_out"]
|
||||
)
|
||||
|
||||
graph = _make_fanout_graph([b1, b2])
|
||||
|
||||
config = ParallelExecutionConfig(branch_timeout_seconds=0.1, on_branch_failure="fail_all")
|
||||
executor = GraphExecutor(
|
||||
runtime=runtime, enable_parallel_execution=True, parallel_config=config
|
||||
)
|
||||
executor.register_node("source", SuccessNode({"data": "x"}))
|
||||
executor.register_node("b1", SlowNode(delay=10.0))
|
||||
executor.register_node("b2", SlowNode(delay=10.0))
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
assert not result.success
|
||||
|
||||
|
||||
# === 15. Memory conflict: last_wins ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_conflict_last_wins(runtime, goal):
|
||||
"""last_wins should allow both branches to write the same key without error."""
|
||||
# Use distinct output_keys in spec (to pass graph validation) but have
|
||||
# the node impl write a shared key at runtime — this is the scenario
|
||||
# memory_conflict_strategy is designed to handle.
|
||||
b1 = NodeSpec(
|
||||
id="b1", name="B1", description="b1", node_type="event_loop", output_keys=["b1_out"]
|
||||
)
|
||||
b2 = NodeSpec(
|
||||
id="b2", name="B2", description="b2", node_type="event_loop", output_keys=["b2_out"]
|
||||
)
|
||||
|
||||
graph = _make_fanout_graph([b1, b2])
|
||||
|
||||
config = ParallelExecutionConfig(memory_conflict_strategy="last_wins")
|
||||
executor = GraphExecutor(
|
||||
runtime=runtime, enable_parallel_execution=True, parallel_config=config
|
||||
)
|
||||
executor.register_node("source", SuccessNode({"data": "x"}))
|
||||
# Both impls write "shared_key" — triggers conflict detection at runtime
|
||||
executor.register_node("b1", SuccessNode({"shared_key": "from_b1", "b1_out": "ok"}))
|
||||
executor.register_node("b2", SuccessNode({"shared_key": "from_b2", "b2_out": "ok"}))
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
assert result.success
|
||||
# The key should exist with one of the two values
|
||||
assert result.output.get("shared_key") in ("from_b1", "from_b2")
|
||||
|
||||
|
||||
# === 16. Memory conflict: first_wins ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_conflict_first_wins(runtime, goal):
|
||||
"""first_wins should keep the first branch's value and skip later writes."""
|
||||
b1 = NodeSpec(
|
||||
id="b1", name="B1", description="b1", node_type="event_loop", output_keys=["b1_out"]
|
||||
)
|
||||
b2 = NodeSpec(
|
||||
id="b2", name="B2", description="b2", node_type="event_loop", output_keys=["b2_out"]
|
||||
)
|
||||
|
||||
graph = _make_fanout_graph([b1, b2])
|
||||
|
||||
config = ParallelExecutionConfig(memory_conflict_strategy="first_wins")
|
||||
executor = GraphExecutor(
|
||||
runtime=runtime, enable_parallel_execution=True, parallel_config=config
|
||||
)
|
||||
executor.register_node("source", SuccessNode({"data": "x"}))
|
||||
executor.register_node("b1", SuccessNode({"shared_key": "from_b1", "b1_out": "ok"}))
|
||||
executor.register_node("b2", SuccessNode({"shared_key": "from_b2", "b2_out": "ok"}))
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
assert result.success
|
||||
|
||||
|
||||
# === 17. Memory conflict: error raises ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_conflict_error_raises(runtime, goal):
|
||||
"""error strategy should fail when two branches write the same key."""
|
||||
b1 = NodeSpec(
|
||||
id="b1", name="B1", description="b1", node_type="event_loop", output_keys=["b1_out"]
|
||||
)
|
||||
b2 = NodeSpec(
|
||||
id="b2", name="B2", description="b2", node_type="event_loop", output_keys=["b2_out"]
|
||||
)
|
||||
|
||||
graph = _make_fanout_graph([b1, b2])
|
||||
|
||||
config = ParallelExecutionConfig(memory_conflict_strategy="error")
|
||||
executor = GraphExecutor(
|
||||
runtime=runtime, enable_parallel_execution=True, parallel_config=config
|
||||
)
|
||||
executor.register_node("source", SuccessNode({"data": "x"}))
|
||||
executor.register_node("b1", SuccessNode({"shared_key": "from_b1", "b1_out": "ok"}))
|
||||
executor.register_node("b2", SuccessNode({"shared_key": "from_b2", "b2_out": "ok"}))
|
||||
|
||||
result = await executor.execute(graph, goal, {})
|
||||
|
||||
assert not result.success
|
||||
# The conflict RuntimeError is caught inside execute_single_branch,
|
||||
# which causes the branch to fail. fail_all then raises its own error.
|
||||
assert "failed" in result.error.lower()
|
||||
|
||||
@@ -0,0 +1,276 @@
|
||||
"""Tests for framework/tools/flowchart_utils.py."""
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
from framework.tools.flowchart_utils import (
|
||||
FLOWCHART_FILENAME,
|
||||
FLOWCHART_TYPES,
|
||||
classify_flowchart_node,
|
||||
generate_fallback_flowchart,
|
||||
load_flowchart_file,
|
||||
save_flowchart_file,
|
||||
synthesize_draft_from_runtime,
|
||||
)
|
||||
|
||||
|
||||
def _make_node(
|
||||
id,
|
||||
name="Node",
|
||||
description="",
|
||||
node_type="event_loop",
|
||||
tools=None,
|
||||
input_keys=None,
|
||||
output_keys=None,
|
||||
success_criteria="",
|
||||
sub_agents=None,
|
||||
):
|
||||
"""Create a minimal node-like object matching NodeSpec interface."""
|
||||
return SimpleNamespace(
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
node_type=node_type,
|
||||
tools=tools or [],
|
||||
input_keys=input_keys or [],
|
||||
output_keys=output_keys or [],
|
||||
success_criteria=success_criteria,
|
||||
sub_agents=sub_agents or [],
|
||||
)
|
||||
|
||||
|
||||
def _make_edge(source, target, condition="on_success", description=""):
|
||||
"""Create a minimal edge-like object matching EdgeSpec interface."""
|
||||
return SimpleNamespace(
|
||||
source=source,
|
||||
target=target,
|
||||
condition=SimpleNamespace(value=condition),
|
||||
description=description,
|
||||
)
|
||||
|
||||
|
||||
def _make_goal(
|
||||
name="Test Goal", description="A test goal", success_criteria=None, constraints=None
|
||||
):
|
||||
"""Create a minimal goal-like object matching Goal interface."""
|
||||
return SimpleNamespace(
|
||||
name=name,
|
||||
description=description,
|
||||
success_criteria=success_criteria or [],
|
||||
constraints=constraints or [],
|
||||
)
|
||||
|
||||
|
||||
def _make_graph(nodes, edges, entry_node=None, terminal_nodes=None):
|
||||
"""Create a minimal graph-like object matching GraphSpec interface."""
|
||||
return SimpleNamespace(
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
entry_node=entry_node or (nodes[0].id if nodes else ""),
|
||||
terminal_nodes=terminal_nodes or [],
|
||||
)
|
||||
|
||||
|
||||
class TestClassifyFlowchartNode:
|
||||
"""Test flowchart node classification logic."""
|
||||
|
||||
def test_first_node_is_start(self):
|
||||
node = {"id": "n1", "node_type": "event_loop", "tools": []}
|
||||
result = classify_flowchart_node(node, 0, 3, [], set())
|
||||
assert result == "start"
|
||||
|
||||
def test_terminal_node(self):
|
||||
node = {"id": "n3", "node_type": "event_loop", "tools": []}
|
||||
edges = [{"source": "n1", "target": "n3"}]
|
||||
result = classify_flowchart_node(node, 2, 3, edges, {"n3"})
|
||||
assert result == "terminal"
|
||||
|
||||
def test_gcu_node_is_browser(self):
|
||||
node = {"id": "n2", "node_type": "gcu", "tools": []}
|
||||
edges = [{"source": "n1", "target": "n2"}]
|
||||
result = classify_flowchart_node(node, 1, 3, edges, set())
|
||||
assert result == "browser"
|
||||
|
||||
def test_subprocess_node(self):
|
||||
node = {"id": "n2", "node_type": "event_loop", "tools": [], "sub_agents": ["sub1"]}
|
||||
edges = [{"source": "n1", "target": "n2"}, {"source": "n2", "target": "n3"}]
|
||||
result = classify_flowchart_node(node, 1, 3, edges, set())
|
||||
assert result == "subprocess"
|
||||
|
||||
def test_default_is_process(self):
|
||||
node = {"id": "n2", "node_type": "event_loop", "tools": [], "description": "do stuff"}
|
||||
edges = [{"source": "n1", "target": "n2"}, {"source": "n2", "target": "n3"}]
|
||||
result = classify_flowchart_node(node, 1, 3, edges, set())
|
||||
assert result == "process"
|
||||
|
||||
def test_explicit_override(self):
|
||||
node = {"id": "n2", "node_type": "event_loop", "tools": [], "flowchart_type": "database"}
|
||||
edges = [{"source": "n1", "target": "n2"}]
|
||||
result = classify_flowchart_node(node, 1, 3, edges, set())
|
||||
assert result == "database"
|
||||
|
||||
def test_decision_node_with_branching(self):
|
||||
node = {"id": "n2", "node_type": "event_loop", "tools": []}
|
||||
edges = [
|
||||
{"source": "n1", "target": "n2"},
|
||||
{"source": "n2", "target": "n3", "condition": "on_success"},
|
||||
{"source": "n2", "target": "n4", "condition": "on_failure"},
|
||||
]
|
||||
result = classify_flowchart_node(node, 1, 4, edges, set())
|
||||
assert result == "decision"
|
||||
|
||||
|
||||
class TestSynthesizeDraftFromRuntime:
|
||||
"""Test runtime graph to DraftGraph conversion."""
|
||||
|
||||
def test_basic_linear_graph(self):
|
||||
nodes = [
|
||||
_make_node("intake", "Intake"),
|
||||
_make_node("process", "Process"),
|
||||
_make_node("deliver", "Deliver"),
|
||||
]
|
||||
edges = [
|
||||
_make_edge("intake", "process"),
|
||||
_make_edge("process", "deliver"),
|
||||
]
|
||||
draft, fmap = synthesize_draft_from_runtime(
|
||||
nodes, edges, agent_name="test_agent", goal_name="Test"
|
||||
)
|
||||
|
||||
assert draft["agent_name"] == "test_agent"
|
||||
assert draft["goal"] == "Test"
|
||||
assert len(draft["nodes"]) == 3
|
||||
assert len(draft["edges"]) == 2
|
||||
assert draft["entry_node"] == "intake"
|
||||
assert "deliver" in draft["terminal_nodes"]
|
||||
|
||||
# First node should be start type
|
||||
assert draft["nodes"][0]["flowchart_type"] == "start"
|
||||
# Last node (terminal) should be terminal type
|
||||
assert draft["nodes"][2]["flowchart_type"] == "terminal"
|
||||
# Middle node should be process
|
||||
assert draft["nodes"][1]["flowchart_type"] == "process"
|
||||
|
||||
# All nodes should have shape and color
|
||||
for node in draft["nodes"]:
|
||||
assert "flowchart_shape" in node
|
||||
assert "flowchart_color" in node
|
||||
|
||||
# Flowchart map should be identity
|
||||
assert fmap == {"intake": ["intake"], "process": ["process"], "deliver": ["deliver"]}
|
||||
|
||||
# Legend should contain all types
|
||||
assert draft["flowchart_legend"] == {
|
||||
k: {"shape": v["shape"], "color": v["color"]} for k, v in FLOWCHART_TYPES.items()
|
||||
}
|
||||
|
||||
def test_graph_with_sub_agents(self):
|
||||
nodes = [
|
||||
_make_node("main", "Main", sub_agents=["helper"]),
|
||||
_make_node("helper", "Helper"),
|
||||
]
|
||||
edges = [_make_edge("main", "helper")]
|
||||
draft, fmap = synthesize_draft_from_runtime(nodes, edges)
|
||||
|
||||
# Sub-agent edges should be added
|
||||
assert len(draft["edges"]) > 1
|
||||
|
||||
# Helper should be grouped under main in the flowchart map
|
||||
assert "helper" not in fmap
|
||||
assert fmap["main"] == ["main", "helper"]
|
||||
|
||||
|
||||
class TestFlowchartFilePersistence:
|
||||
"""Test save/load of flowchart.json."""
|
||||
|
||||
def test_save_and_load(self, tmp_path):
|
||||
draft = {"agent_name": "test", "nodes": [], "edges": []}
|
||||
fmap = {"n1": ["n1"]}
|
||||
|
||||
save_flowchart_file(tmp_path, draft, fmap)
|
||||
loaded_draft, loaded_map = load_flowchart_file(tmp_path)
|
||||
|
||||
assert loaded_draft == draft
|
||||
assert loaded_map == fmap
|
||||
|
||||
def test_load_missing_file(self, tmp_path):
|
||||
draft, fmap = load_flowchart_file(tmp_path)
|
||||
assert draft is None
|
||||
assert fmap is None
|
||||
|
||||
def test_load_none_path(self):
|
||||
draft, fmap = load_flowchart_file(None)
|
||||
assert draft is None
|
||||
assert fmap is None
|
||||
|
||||
def test_save_none_path(self):
|
||||
# Should not raise
|
||||
save_flowchart_file(None, {}, {})
|
||||
|
||||
|
||||
class TestGenerateFallbackFlowchart:
|
||||
"""Test the main entry point for fallback generation."""
|
||||
|
||||
def test_generates_file_when_missing(self, tmp_path):
|
||||
nodes = [
|
||||
_make_node("n1", "Start Node"),
|
||||
_make_node("n2", "End Node"),
|
||||
]
|
||||
edges = [_make_edge("n1", "n2")]
|
||||
graph = _make_graph(nodes, edges, entry_node="n1", terminal_nodes=["n2"])
|
||||
goal = _make_goal()
|
||||
|
||||
generate_fallback_flowchart(graph, goal, tmp_path)
|
||||
|
||||
flowchart_path = tmp_path / FLOWCHART_FILENAME
|
||||
assert flowchart_path.exists()
|
||||
|
||||
data = json.loads(flowchart_path.read_text())
|
||||
assert data["original_draft"]["agent_name"] == tmp_path.name
|
||||
assert data["original_draft"]["goal"] == "A test goal"
|
||||
assert data["flowchart_map"] is not None
|
||||
# Entry/terminal from GraphSpec should be used
|
||||
assert data["original_draft"]["entry_node"] == "n1"
|
||||
assert "n2" in data["original_draft"]["terminal_nodes"]
|
||||
|
||||
def test_skips_when_file_exists(self, tmp_path):
|
||||
# Pre-create a flowchart.json
|
||||
existing = {"original_draft": {"agent_name": "existing"}, "flowchart_map": {}}
|
||||
(tmp_path / FLOWCHART_FILENAME).write_text(json.dumps(existing))
|
||||
|
||||
nodes = [_make_node("n1", "Node")]
|
||||
graph = _make_graph(nodes, [], entry_node="n1")
|
||||
goal = _make_goal()
|
||||
|
||||
generate_fallback_flowchart(graph, goal, tmp_path)
|
||||
|
||||
# Should not have been overwritten
|
||||
data = json.loads((tmp_path / FLOWCHART_FILENAME).read_text())
|
||||
assert data["original_draft"]["agent_name"] == "existing"
|
||||
|
||||
def test_handles_errors_gracefully(self, tmp_path):
|
||||
# Pass an invalid path (file, not directory)
|
||||
fake_path = tmp_path / "not_a_dir.txt"
|
||||
fake_path.write_text("hello")
|
||||
|
||||
graph = _make_graph([], [])
|
||||
goal = _make_goal()
|
||||
|
||||
# Should not raise
|
||||
generate_fallback_flowchart(graph, goal, fake_path)
|
||||
|
||||
def test_enriches_with_goal_metadata(self, tmp_path):
|
||||
nodes = [_make_node("n1", "Node")]
|
||||
graph = _make_graph(nodes, [], entry_node="n1")
|
||||
goal = _make_goal(
|
||||
description="Find bugs",
|
||||
success_criteria=[SimpleNamespace(description="All bugs found")],
|
||||
constraints=[SimpleNamespace(description="No false positives")],
|
||||
)
|
||||
|
||||
generate_fallback_flowchart(graph, goal, tmp_path)
|
||||
|
||||
data = json.loads((tmp_path / FLOWCHART_FILENAME).read_text())
|
||||
assert data["original_draft"]["goal"] == "Find bugs"
|
||||
assert data["original_draft"]["success_criteria"] == ["All bugs found"]
|
||||
assert data["original_draft"]["constraints"] == ["No false positives"]
|
||||
@@ -3,12 +3,16 @@ Tests for core GraphExecutor execution paths.
|
||||
Focused on minimal success and failure scenarios.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.node import NodeResult, NodeSpec
|
||||
from framework.utils.io import atomic_write
|
||||
|
||||
|
||||
# ---- Dummy runtime (no real logging) ----
|
||||
@@ -25,6 +29,14 @@ class DummyRuntime:
|
||||
pass
|
||||
|
||||
|
||||
class DummyMemory:
|
||||
def __init__(self, data):
|
||||
self._data = data
|
||||
|
||||
def read_all(self):
|
||||
return self._data
|
||||
|
||||
|
||||
# ---- Fake node that always succeeds ----
|
||||
class SuccessNode:
|
||||
def validate_input(self, ctx):
|
||||
@@ -245,3 +257,61 @@ async def test_executor_no_events_without_event_bus():
|
||||
result = await executor.execute(graph=graph, goal=goal)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
|
||||
def test_write_progress_uses_atomic_write_and_updates_state(tmp_path, monkeypatch):
|
||||
runtime = DummyRuntime()
|
||||
executor = GraphExecutor(runtime=runtime, storage_path=tmp_path)
|
||||
state_path = tmp_path / "state.json"
|
||||
state_path.write_text(json.dumps({"entry_point": "primary"}), encoding="utf-8")
|
||||
memory = DummyMemory({"foo": "bar"})
|
||||
|
||||
called = {}
|
||||
|
||||
def recording_atomic_write(path, *args, **kwargs):
|
||||
called["path"] = path
|
||||
return atomic_write(path, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("framework.graph.executor.atomic_write", recording_atomic_write)
|
||||
|
||||
executor._write_progress(
|
||||
current_node="node-b",
|
||||
path=["node-a", "node-b"],
|
||||
memory=memory,
|
||||
node_visit_counts={"node-a": 1, "node-b": 1},
|
||||
)
|
||||
|
||||
state = json.loads(state_path.read_text(encoding="utf-8"))
|
||||
assert called["path"] == state_path
|
||||
assert state["entry_point"] == "primary"
|
||||
assert state["progress"]["current_node"] == "node-b"
|
||||
assert state["progress"]["path"] == ["node-a", "node-b"]
|
||||
assert state["progress"]["node_visit_counts"] == {"node-a": 1, "node-b": 1}
|
||||
assert state["progress"]["steps_executed"] == 2
|
||||
assert state["memory"] == {"foo": "bar"}
|
||||
assert state["memory_keys"] == ["foo"]
|
||||
assert "updated_at" in state["timestamps"]
|
||||
|
||||
|
||||
def test_write_progress_logs_warning_on_atomic_write_failure(tmp_path, monkeypatch, caplog):
|
||||
runtime = DummyRuntime()
|
||||
executor = GraphExecutor(runtime=runtime, storage_path=tmp_path)
|
||||
state_path = tmp_path / "state.json"
|
||||
state_path.write_text(json.dumps({"entry_point": "primary"}), encoding="utf-8")
|
||||
memory = DummyMemory({"foo": "bar"})
|
||||
|
||||
def failing_atomic_write(*args, **kwargs):
|
||||
raise OSError("disk full")
|
||||
|
||||
monkeypatch.setattr("framework.graph.executor.atomic_write", failing_atomic_write)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
executor._write_progress(
|
||||
current_node="node-b",
|
||||
path=["node-a", "node-b"],
|
||||
memory=memory,
|
||||
node_visit_counts={"node-a": 1, "node-b": 1},
|
||||
)
|
||||
|
||||
assert "Failed to persist progress state to" in caplog.text
|
||||
assert str(state_path) in caplog.text
|
||||
|
||||
@@ -19,7 +19,11 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from framework.llm.anthropic import AnthropicProvider
|
||||
from framework.llm.litellm import LiteLLMProvider, _compute_retry_delay
|
||||
from framework.llm.litellm import (
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE,
|
||||
LiteLLMProvider,
|
||||
_compute_retry_delay,
|
||||
)
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
|
||||
|
||||
@@ -72,6 +76,20 @@ class TestLiteLLMProviderInit:
|
||||
)
|
||||
assert provider.api_base == "https://proxy.example/v1"
|
||||
|
||||
def test_init_openrouter_defaults_api_base(self):
|
||||
"""OpenRouter should default to the official OpenAI-compatible endpoint."""
|
||||
provider = LiteLLMProvider(model="openrouter/x-ai/grok-4.20-beta", api_key="my-key")
|
||||
assert provider.api_base == "https://openrouter.ai/api/v1"
|
||||
|
||||
def test_init_openrouter_keeps_custom_api_base(self):
|
||||
"""Explicit api_base should win over OpenRouter defaults."""
|
||||
provider = LiteLLMProvider(
|
||||
model="openrouter/x-ai/grok-4.20-beta",
|
||||
api_key="my-key",
|
||||
api_base="https://proxy.example/v1",
|
||||
)
|
||||
assert provider.api_base == "https://proxy.example/v1"
|
||||
|
||||
def test_init_ollama_no_key_needed(self):
|
||||
"""Test that Ollama models don't require API key."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
@@ -192,6 +210,34 @@ class TestToolConversion:
|
||||
assert result["function"]["parameters"]["properties"]["query"]["type"] == "string"
|
||||
assert result["function"]["parameters"]["required"] == ["query"]
|
||||
|
||||
def test_parse_tool_call_arguments_repairs_truncated_json(self):
|
||||
"""Truncated JSON fragments should be repaired into valid tool inputs."""
|
||||
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
||||
|
||||
parsed = provider._parse_tool_call_arguments(
|
||||
(
|
||||
'{"question":"What story structure should the agent use?",'
|
||||
'"options":["3-act structure","Beginning-Middle-End","Random paragraph"'
|
||||
),
|
||||
"ask_user",
|
||||
)
|
||||
|
||||
assert parsed == {
|
||||
"question": "What story structure should the agent use?",
|
||||
"options": [
|
||||
"3-act structure",
|
||||
"Beginning-Middle-End",
|
||||
"Random paragraph",
|
||||
],
|
||||
}
|
||||
|
||||
def test_parse_tool_call_arguments_raises_when_unrepairable(self):
|
||||
"""Completely invalid JSON should fail fast instead of producing _raw loops."""
|
||||
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to parse tool call arguments"):
|
||||
provider._parse_tool_call_arguments('{"question": foo', "ask_user")
|
||||
|
||||
|
||||
class TestAnthropicProviderBackwardCompatibility:
|
||||
"""Test AnthropicProvider backward compatibility with LiteLLM backend."""
|
||||
@@ -682,6 +728,315 @@ class TestMiniMaxStreamFallback:
|
||||
assert not LiteLLMProvider(model="gpt-4o-mini", api_key="x")._is_minimax_model()
|
||||
|
||||
|
||||
class TestOpenRouterToolCompatFallback:
|
||||
"""OpenRouter models should fall back when native tool use is unavailable."""
|
||||
|
||||
def teardown_method(self):
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE.clear()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.acompletion")
|
||||
async def test_stream_falls_back_to_json_tool_emulation(self, mock_acompletion):
|
||||
"""OpenRouter tool-use 404s should emit synthetic ToolCallEvents instead of errors."""
|
||||
from framework.llm.stream_events import FinishEvent, ToolCallEvent
|
||||
|
||||
provider = LiteLLMProvider(
|
||||
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
||||
api_key="test-key",
|
||||
)
|
||||
tools = [
|
||||
Tool(
|
||||
name="web_search",
|
||||
description="Search the web",
|
||||
parameters={
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"num_results": {"type": "integer"},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
compat_response = MagicMock()
|
||||
compat_response.choices = [MagicMock()]
|
||||
compat_response.choices[0].message.content = (
|
||||
'{"assistant_response":"","tool_calls":['
|
||||
'{"name":"web_search","arguments":'
|
||||
'{"query":"Python 3.13 release notes","num_results":3}}'
|
||||
"]}"
|
||||
)
|
||||
compat_response.choices[0].finish_reason = "stop"
|
||||
compat_response.model = provider.model
|
||||
compat_response.usage.prompt_tokens = 18
|
||||
compat_response.usage.completion_tokens = 9
|
||||
|
||||
async def side_effect(*args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
raise RuntimeError(
|
||||
'OpenrouterException - {"error":{"message":"No endpoints found '
|
||||
"that support tool use. To learn more about provider routing, "
|
||||
'visit: https://openrouter.ai/docs/guides/routing/provider-selection",'
|
||||
'"code":404}}'
|
||||
)
|
||||
return compat_response
|
||||
|
||||
mock_acompletion.side_effect = side_effect
|
||||
|
||||
events = []
|
||||
async for event in provider.stream(
|
||||
messages=[{"role": "user", "content": "Search for the Python 3.13 release notes."}],
|
||||
system="Use tools when needed.",
|
||||
tools=tools,
|
||||
max_tokens=256,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
tool_calls = [event for event in events if isinstance(event, ToolCallEvent)]
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].tool_name == "web_search"
|
||||
assert tool_calls[0].tool_input == {
|
||||
"query": "Python 3.13 release notes",
|
||||
"num_results": 3,
|
||||
}
|
||||
assert tool_calls[0].tool_use_id.startswith("openrouter_compat_")
|
||||
|
||||
finish_events = [event for event in events if isinstance(event, FinishEvent)]
|
||||
assert len(finish_events) == 1
|
||||
assert finish_events[0].stop_reason == "tool_calls"
|
||||
assert finish_events[0].input_tokens == 18
|
||||
assert finish_events[0].output_tokens == 9
|
||||
|
||||
assert mock_acompletion.call_count == 2
|
||||
first_call = mock_acompletion.call_args_list[0].kwargs
|
||||
assert first_call["stream"] is True
|
||||
assert "tools" in first_call
|
||||
|
||||
second_call = mock_acompletion.call_args_list[1].kwargs
|
||||
assert "tools" not in second_call
|
||||
assert "Tool compatibility mode is active" in second_call["messages"][0]["content"]
|
||||
assert provider.model in OPENROUTER_TOOL_COMPAT_MODEL_CACHE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.acompletion")
|
||||
async def test_stream_tool_compat_parses_textual_tool_calls_and_uses_cache(
|
||||
self,
|
||||
mock_acompletion,
|
||||
):
|
||||
"""Textual tool-call markers should become ToolCallEvents and skip repeat probing."""
|
||||
from framework.llm.stream_events import ToolCallEvent
|
||||
|
||||
provider = LiteLLMProvider(
|
||||
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
||||
api_key="test-key",
|
||||
)
|
||||
tools = [
|
||||
Tool(
|
||||
name="ask_user_multiple",
|
||||
description="Ask the user a multiple-choice question",
|
||||
parameters={
|
||||
"properties": {
|
||||
"options": {"type": "array"},
|
||||
"question": {"type": "string"},
|
||||
"prompt": {"type": "string"},
|
||||
},
|
||||
"required": ["options", "question", "prompt"],
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
compat_response = MagicMock()
|
||||
compat_response.choices = [MagicMock()]
|
||||
compat_response.choices[0].message.content = (
|
||||
"<|tool_call_start|>"
|
||||
"[ask_user_multiple(options=['Quartet Collaborator', 'Project Advisor'], "
|
||||
"question='Who are you?', prompt='Who are you?')]"
|
||||
"<|tool_call_end|>"
|
||||
)
|
||||
compat_response.choices[0].finish_reason = "stop"
|
||||
compat_response.model = provider.model
|
||||
compat_response.usage.prompt_tokens = 10
|
||||
compat_response.usage.completion_tokens = 5
|
||||
|
||||
call_state = {"count": 0}
|
||||
|
||||
async def side_effect(*args, **kwargs):
|
||||
call_state["count"] += 1
|
||||
if kwargs.get("stream"):
|
||||
raise RuntimeError(
|
||||
'OpenrouterException - {"error":{"message":"No endpoints found '
|
||||
'that support tool use.","code":404}}'
|
||||
)
|
||||
return compat_response
|
||||
|
||||
mock_acompletion.side_effect = side_effect
|
||||
|
||||
first_events = []
|
||||
async for event in provider.stream(
|
||||
messages=[{"role": "user", "content": "Who are you?"}],
|
||||
system="Use tools when needed.",
|
||||
tools=tools,
|
||||
max_tokens=128,
|
||||
):
|
||||
first_events.append(event)
|
||||
|
||||
tool_calls = [event for event in first_events if isinstance(event, ToolCallEvent)]
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].tool_name == "ask_user_multiple"
|
||||
assert tool_calls[0].tool_input == {
|
||||
"options": ["Quartet Collaborator", "Project Advisor"],
|
||||
"question": "Who are you?",
|
||||
"prompt": "Who are you?",
|
||||
}
|
||||
|
||||
second_events = []
|
||||
async for event in provider.stream(
|
||||
messages=[{"role": "user", "content": "Who are you?"}],
|
||||
system="Use tools when needed.",
|
||||
tools=tools,
|
||||
max_tokens=128,
|
||||
):
|
||||
second_events.append(event)
|
||||
|
||||
second_tool_calls = [event for event in second_events if isinstance(event, ToolCallEvent)]
|
||||
assert len(second_tool_calls) == 1
|
||||
assert mock_acompletion.call_count == 3
|
||||
assert mock_acompletion.call_args_list[0].kwargs["stream"] is True
|
||||
assert "stream" not in mock_acompletion.call_args_list[1].kwargs
|
||||
assert "stream" not in mock_acompletion.call_args_list[2].kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.acompletion")
|
||||
async def test_stream_tool_compat_parses_plain_text_tool_call_lines(
|
||||
self,
|
||||
mock_acompletion,
|
||||
):
|
||||
"""Plain textual tool-call lines should execute as tools, not user-visible text."""
|
||||
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
|
||||
|
||||
provider = LiteLLMProvider(
|
||||
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
||||
api_key="test-key",
|
||||
)
|
||||
tools = [
|
||||
Tool(
|
||||
name="ask_user",
|
||||
description="Ask the user a single multiple-choice question",
|
||||
parameters={
|
||||
"properties": {
|
||||
"question": {"type": "string"},
|
||||
"options": {"type": "array"},
|
||||
},
|
||||
"required": ["question", "options"],
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
compat_response = MagicMock()
|
||||
compat_response.choices = [MagicMock()]
|
||||
compat_response.choices[0].message.content = (
|
||||
"Queen has been loaded. It's ready to assist with your planning needs.\n\n"
|
||||
"ask_user('What would you like to do?', ['Define a new agent', "
|
||||
"'Diagnose an existing agent', 'Explore tools'])"
|
||||
)
|
||||
compat_response.choices[0].finish_reason = "stop"
|
||||
compat_response.model = provider.model
|
||||
compat_response.usage.prompt_tokens = 11
|
||||
compat_response.usage.completion_tokens = 7
|
||||
|
||||
async def side_effect(*args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
raise RuntimeError(
|
||||
'OpenrouterException - {"error":{"message":"No endpoints found '
|
||||
'that support tool use.","code":404}}'
|
||||
)
|
||||
return compat_response
|
||||
|
||||
mock_acompletion.side_effect = side_effect
|
||||
|
||||
events = []
|
||||
async for event in provider.stream(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
system="Use tools when needed.",
|
||||
tools=tools,
|
||||
max_tokens=128,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
tool_calls = [event for event in events if isinstance(event, ToolCallEvent)]
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].tool_name == "ask_user"
|
||||
assert tool_calls[0].tool_input == {
|
||||
"question": "What would you like to do?",
|
||||
"options": ["Define a new agent", "Diagnose an existing agent", "Explore tools"],
|
||||
}
|
||||
|
||||
text_events = [event for event in events if isinstance(event, TextDeltaEvent)]
|
||||
assert len(text_events) == 1
|
||||
assert "ask_user(" not in text_events[0].snapshot
|
||||
assert text_events[0].snapshot == (
|
||||
"Queen has been loaded. It's ready to assist with your planning needs."
|
||||
)
|
||||
|
||||
finish_events = [event for event in events if isinstance(event, FinishEvent)]
|
||||
assert len(finish_events) == 1
|
||||
assert finish_events[0].stop_reason == "tool_calls"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.acompletion")
|
||||
async def test_stream_tool_compat_treats_non_json_as_plain_text(self, mock_acompletion):
|
||||
"""If fallback output is not valid JSON, preserve it as assistant text."""
|
||||
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
|
||||
|
||||
provider = LiteLLMProvider(
|
||||
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
||||
api_key="test-key",
|
||||
)
|
||||
tools = [
|
||||
Tool(
|
||||
name="web_search",
|
||||
description="Search the web",
|
||||
parameters={"properties": {"query": {"type": "string"}}, "required": ["query"]},
|
||||
)
|
||||
]
|
||||
|
||||
compat_response = MagicMock()
|
||||
compat_response.choices = [MagicMock()]
|
||||
compat_response.choices[0].message.content = "I can answer directly without tools."
|
||||
compat_response.choices[0].finish_reason = "stop"
|
||||
compat_response.model = provider.model
|
||||
compat_response.usage.prompt_tokens = 12
|
||||
compat_response.usage.completion_tokens = 6
|
||||
|
||||
async def side_effect(*args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
raise RuntimeError(
|
||||
'OpenrouterException - {"error":{"message":"No endpoints found '
|
||||
'that support tool use.","code":404}}'
|
||||
)
|
||||
return compat_response
|
||||
|
||||
mock_acompletion.side_effect = side_effect
|
||||
|
||||
events = []
|
||||
async for event in provider.stream(
|
||||
messages=[{"role": "user", "content": "Say hello."}],
|
||||
system="Be concise.",
|
||||
tools=tools,
|
||||
max_tokens=128,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
text_events = [event for event in events if isinstance(event, TextDeltaEvent)]
|
||||
assert len(text_events) == 1
|
||||
assert text_events[0].snapshot == "I can answer directly without tools."
|
||||
assert not any(isinstance(event, ToolCallEvent) for event in events)
|
||||
|
||||
finish_events = [event for event in events if isinstance(event, FinishEvent)]
|
||||
assert len(finish_events) == 1
|
||||
assert finish_events[0].stop_reason == "stop"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AgentRunner._is_local_model — parameterized tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -338,6 +338,69 @@ class TestLLMJudgeBackwardCompatibility:
|
||||
assert call_kwargs["model"] == "claude-haiku-4-5-20251001"
|
||||
assert call_kwargs["max_tokens"] == 500
|
||||
|
||||
def test_openai_fallback_uses_litellm_provider(self, monkeypatch):
|
||||
"""When OPENAI_API_KEY is set, evaluate() should use a LiteLLM-based provider."""
|
||||
# Force the OpenAI fallback path (no injected provider, no Anthropic key)
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-test-openai")
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
|
||||
# Stub LiteLLMProvider so we don't call the real API; record what judge passes through
|
||||
captured_calls: list[dict] = []
|
||||
|
||||
class DummyProvider:
|
||||
def __init__(self, model: str = "gpt-4o-mini"):
|
||||
self.model = model
|
||||
|
||||
def complete(
|
||||
self,
|
||||
messages,
|
||||
system="",
|
||||
tools=None,
|
||||
max_tokens=1024,
|
||||
response_format=None,
|
||||
json_mode=False,
|
||||
max_retries=None,
|
||||
):
|
||||
captured_calls.append(
|
||||
{
|
||||
"messages": messages,
|
||||
"system": system,
|
||||
"max_tokens": max_tokens,
|
||||
"json_mode": json_mode,
|
||||
"model": self.model,
|
||||
}
|
||||
)
|
||||
|
||||
class _Resp:
|
||||
def __init__(self, content: str):
|
||||
self.content = content
|
||||
|
||||
# Minimal response object with a content attribute
|
||||
return _Resp('{"passes": true, "explanation": "OK"}')
|
||||
|
||||
monkeypatch.setattr(
|
||||
"framework.llm.litellm.LiteLLMProvider",
|
||||
DummyProvider,
|
||||
)
|
||||
|
||||
judge = LLMJudge()
|
||||
result = judge.evaluate(
|
||||
constraint="no-hallucination",
|
||||
source_document="The sky is blue.",
|
||||
summary="The sky is blue.",
|
||||
criteria="Summary must only contain facts from source",
|
||||
)
|
||||
|
||||
# Judge should have used our stub once and returned the stub's JSON result
|
||||
assert result["passes"] is True
|
||||
assert result["explanation"] == "OK"
|
||||
assert len(captured_calls) == 1
|
||||
|
||||
call = captured_calls[0]
|
||||
assert call["model"] == "gpt-4o-mini"
|
||||
assert call["max_tokens"] == 500
|
||||
assert call["json_mode"] is True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# LLMJudge Integration Pattern Tests
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user