Compare commits

...

137 Commits

Author SHA1 Message Date
Timothy d78473ff20 chore: experiment 2026-03-24 16:02:32 -07:00
Timothy @aden 8ecb728148 Merge pull request #6784 from aden-hive/fix/pin-litellm-1.81.7
security: pin litellm==1.81.7 to block supply chain attack
2026-03-24 09:53:40 -07:00
Timothy 4a2141bce9 chore: regenerate uv.lock with litellm==1.81.7 pin
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-24 09:48:47 -07:00
Timothy 3b4d6e4602 security: pin litellm==1.81.7 to block supply chain attack
litellm>=1.82.7 contains a malicious .pth file that auto-executes at
Python startup and exfiltrates env vars, SSH keys, cloud credentials,
and CI/CD secrets to an attacker-controlled domain.

Pin to last known-safe version (currently installed). Unpin once a
verified-clean upstream release is available.

Closes #6783

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-24 09:42:42 -07:00
Timothy @aden 89ccc664bd Merge pull request #6574 from Antiarin/feat/mcp-registry-core
feat(mcp-registry): add MCPRegistry core module (#6349)
2026-03-24 07:40:35 -07:00
Bryan @ Aden 4872c01886 Merge pull request #6777 from sundaram2021/fix/missing-antigravity-option-in-windows-powershell
fix: missing antigravity and minimax plan option in powershell
2026-03-24 07:35:39 -07:00
Sundaram Kumar Jha 4620380341 fix: missing antigravity and minimax plan option in powershell 2026-03-24 09:26:07 +05:30
Richard Tang fca2deb980 chore: update documentation 2026-03-23 20:35:26 -07:00
RichardTang-Aden d7ce923ca6 Merge pull request #1586 from rhythmtaneja/improve-eventbus-logging
Improve EventBus handler error logging to include traceback
2026-03-23 20:17:51 -07:00
Richard Tang 403b47db61 chore: lint 2026-03-23 20:05:29 -07:00
Richard Tang 0d0e78579f chore: lint 2026-03-23 18:09:15 -07:00
RichardTang-Aden 447bfdfab8 Merge pull request #6763 from Leayxz/micro-fix/files_names_conflicts
micro-fix: make test filenames unique to avoid pytest import conflicts / error test_structure
2026-03-23 17:35:16 -07:00
RichardTang-Aden c77d21e393 Merge pull request #6761 from Leayxz/micro-fix/remove_obsolete_PushoverClient_tests
micro-fix: remove obsolete _PushoverClient tests
2026-03-23 17:34:49 -07:00
RichardTang-Aden 6ded508b4d Merge pull request #6774 from Leayxz/micro-fix/rename_schema_discovery
micro-fix: rename schema discovery to avoid pytest collection
2026-03-23 17:34:07 -07:00
RichardTang-Aden 75f8bf5696 Merge pull request #6743 from sundaram2021/fix/codex-oauth-stdin-select-windows
fix: windows Codex OAuth browser launch and manual fallback
2026-03-23 16:52:56 -07:00
Leandro Rodrigues 62fc02220b micro-fix: rename schema discovery to avoid pytest collection
- The file `tools/test_schema_discovery.py` was being incorrectly collected by pytest as a test module
- Since the file is actually a standalone script, this caused import errors during test collection
- Rename the file to remove the `test_` prefix so pytest no longer treats it as a test file
- Pytest test discovery no longer includes the script, eliminating the import error and restoring a clean test run
2026-03-23 20:51:18 -03:00
Richard Tang 5d4f279646 test: add real integration test for MCPRegistry → AgentRunner path 2026-03-23 15:44:54 -07:00
Bryan @ Aden 920a840756 Merge pull request #6772 from sundaram2021/fix/setup-worker-model-on-windows
fix(windows): use shared uv discovery in setup_worker_model.ps1
2026-03-23 15:44:48 -07:00
Sundaram Kumar Jha 8680a35c39 fix(powershell): use shared uv discovery in setup_worker_model 2026-03-24 03:57:07 +05:30
Leandro Rodrigues c9134cfd91 micro-fix: make test filenames unique to avoid pytest import conflicts
- multiple test files shared the same module name "test_structure.py"
- this cause pytest import mismatches during collection
- renamed test files to "test_email_reply_agent" and "test_meeting_scheduler"
- eliminated module name collisions and fixed test discovery
2026-03-23 16:13:31 -03:00
Leandro Rodrigues 55ce751385 micro-fix: remove obsolete _PushoverClient tests
- the test suite still referenced _PushoverClient, which no longer exists
- this caused import errors and failing pytest runs
- removed all tests related to _PushoverClient
- fixed pytest execution errors
- removed dead test code
- ensured test coverage reflects the current implementation
2026-03-23 15:54:50 -03:00
Timothy @aden aca2dfb536 Merge pull request #5892 from nikhilvarmakandula/feat/openmeteo-weather-tool
feat(tools): add Open-Meteo weather tool — free real-time weather, no API key required
2026-03-23 10:30:59 -07:00
Antiarin d11f539209 Merge branch 'main' into feat/mcp-registry-core 2026-03-23 11:29:47 +05:30
Antiarin 64a223353a fix: harden MCPConnectionManager with timeouts, SSE health checks, and failure handling
Add 30s transition timeouts to prevent deadlocks on stuck connections.
Split SSE from HTTP in health_check: SSE uses client.list_tools() instead
of hitting /health (SSE servers use event-stream protocol, not REST).
Add has_connection() for MCPRegistry health check integration. Handle
disconnect failures in release, reconnect, and cleanup_all. Guard
reconnect against refcount dropping to zero mid-reconnect.
2026-03-23 11:13:27 +05:30
Antiarin 2d154c2db6 test: add tests for MCPRegistry, runner integration, and load_registry_servers
Covers install/add_local/remove/enable/disable, resolve_for_agent selection
precedence, health checks with pooled connections, cache fallback (defect 1),
SSE health check (defect 2), tomllib version parsing (defect 3), JSON type
validation for mcp_registry.json fields, malformed JSON error handling,
structured log emission, and retry-on-zero-tools behavior.
2026-03-23 11:13:27 +05:30
Antiarin a00c934d9d feat: add MCPRegistry core module with framework integration
Local state management for installed MCP servers in ~/.hive/mcp_registry/.
Supports install from registry index, add_local for running servers,
resolve_for_agent with include/tags/exclude/profile/max_tools/versions
selection, health checks via MCPConnectionManager, and JSON type
validation at the mcp_registry.json boundary.

Integration points: AgentRunner, queen orchestrator, credential tester
all load mcp_registry.json with error handling. ToolRegistry gains
load_registry_servers() with retry and structured DX-4 logging.
2026-03-23 11:13:27 +05:30
Sundaram Kumar Jha 18bee9cb90 Add Codex OAuth Windows regression tests 2026-03-23 10:40:51 +05:30
Sundaram Kumar Jha c1664e47e5 Fix Windows Codex OAuth URL and stdin handling 2026-03-23 10:40:30 +05:30
Emmanuel Nwanguma 2cb972fc5a fix(runner): replace print() with logger.warning() for credential warnings (#6577)
Fixes #6484

- Replace 8 raw print() calls with logger.warning() in runner.py
- Uses lazy % formatting instead of f-strings
- Warnings about missing tokens/API keys now go through logging framework
- Visible in log files when agents run headlessly
2026-03-22 18:24:42 +08:00
Emmanuel Nwanguma 0bd841ce01 fix(credentials): replace bare except Exception clauses with specific handlers (#6592)
Fixes #6481

- credential_tester/agent.py: 4 bare excepts replaced
- credentials/setup.py: 6 bare excepts replaced
- queen_memory.py: 2 bare excepts replaced (2 already had proper logging)
- Expected errors (ImportError, OSError, KeyError) logged at DEBUG
- Unexpected errors logged at WARNING with exc_info=True
- Same two-tier pattern as PR #6153 (key_storage.py)
2026-03-22 18:16:14 +08:00
Samer Attrah 88ec4b7e64 fix: improve tool_registry error handling with stack traces and context (#6518)
* fix: improve tool_registry error handling with stack traces and context

When tool execution fails, errors now include:
- Stack traces for debugging
- Tool name, tool_use_id, and inputs in error logs
- Same behavior for both sync and async tools

Fixes #2447

* fix: use exc_info=True and truncate inputs in tool error logs

- Replace traceback.format_exc() with exc_info=True (codebase convention)
- Truncate tool inputs to 500 chars to prevent log flooding
- Add test for input truncation
2026-03-22 18:01:28 +08:00
Sundaram Kumar Jha 27d5061d97 micro-fix: quickstart dashboard auto-launch for PowerShell (#6655)
* Fix quickstart dashboard auto-launch on Windows

* chore: refresh locks

* fix: gate quickstart hive shim to Git Bash

* chore: revert unrelated frontend lockfile churn
2026-03-22 16:21:02 +08:00
Sundaram Kumar Jha a2cd96a1a7 docs: document OpenRouter and Hive LLM provider setup (#6644)
* docs(llm): document OpenRouter and Hive LLM setup

* docs(contributing): add OpenRouter and Hive LLM guidance
2026-03-22 10:12:44 +08:00
Hundao 07b82a51f6 fix(examples): use __file__ relative path for mcp_servers.json copy (#6677)
Fixes #1669
2026-03-22 08:26:13 +08:00
Timothy @aden 3e1282b31e Merge pull request #6682 from aden-hive/feat/image-capabilities
Release / Create Release (push) Waiting to run
feat: image capabilities — upload, screenshot passthrough, vision detection & fallback, aria refs
2026-03-20 21:25:37 -07:00
Timothy 736756b257 chore: fix test 2026-03-20 21:22:29 -07:00
Timothy 90efe7009d chore: lint 2026-03-20 21:13:22 -07:00
Timothy 4adb369bde chore: lint 2026-03-20 21:12:03 -07:00
Timothy d4a30eb2f3 feat: image model fallback 2026-03-20 20:18:07 -07:00
Timothy 94bb4a2984 Merge branch 'main' into feat/image-capabilities 2026-03-20 18:42:55 -07:00
Timothy 648bad26ed feat: user input image content 2026-03-20 18:40:28 -07:00
RichardTang-Aden f0c7470f3d Merge pull request #6663 from sundaram2021/fix/missing-minimax-option-on-windows
fix: minimax option in powershell quickstart
2026-03-20 17:00:11 -07:00
RichardTang-Aden fe533b72a6 Merge pull request #6648 from levxn/main
Antigravity subscription support as an LLM provider
2026-03-20 16:52:38 -07:00
Richard Tang e581767cab chore: ruff lint 2026-03-20 16:50:50 -07:00
Richard Tang 0663ee5950 feat: validate the existing credentials before auth 2026-03-20 16:45:56 -07:00
Richard Tang 4b97baa34b feat: native google oauth for antigravity support 2026-03-20 16:40:15 -07:00
levxn a89296d397 lint fix 2026-03-21 02:35:09 +05:30
Levin d568912ba2 Merge branch 'aden-hive:main' into main 2026-03-21 01:32:13 +05:30
Levin c4d7980058 Merge pull request #1 from levxn/subscription/antigravity
Subscription/antigravity
2026-03-21 01:30:27 +05:30
Timothy @aden 8549fe8238 Merge pull request #6635 from vakrahul/fix/skill-structured-errors-6366
feat: structured skill error codes and diagnostics (closes #6366)
2026-03-20 12:45:35 -07:00
levxn 2b8d85bb95 fixing tool calling issue, antigravity's model's expected thought_signature in functioncall parts, else faces 400 error stating invalid arguments 2026-03-20 23:26:50 +05:30
levxn 07f7801166 test v1 2026-03-20 22:32:30 +05:30
Levin 1f12a45151 Merge branch 'aden-hive:main' into main 2026-03-20 22:01:22 +05:30
Arshad Uzzama Shaik 936e02e8e6 fix(security): prevent symlink-based sandbox escape in get_secure_path (closes #1167) (#5635)
* fix(security): prevent symlink-based sandbox escape in get_secure_path (closes #1167)

* style: apply ruff formatting to tools to satisfy CI

---------

Co-authored-by: Arshad Shaik <arshad.shaik@violetis.ai>
2026-03-20 19:16:47 +08:00
Hundao d59fe1e109 fix(graph): remove dead check_constraint placeholder (#6660)
Never called anywhere in the codebase. Constraints are enforced
via prompt context, not runtime validation.
2026-03-20 18:44:18 +08:00
Sundaram Kumar Jha 274318d3e5 fix: minimax option in powershell quickstart 2026-03-20 15:33:26 +05:30
Anurag Kumar 0f0884c2e0 fix(tools): handle non-HTML content and add PDF URL support (#438)
* feat(tools): add URL support to pdf_read tool

Enable pdf_read to accept both local file paths and HTTP/HTTPS URLs.
Downloads PDF content to temporary file when URL is provided, validates
content-type, and cleans up automatically after extraction.

- Detect URL inputs (http:// or https://)
- Download PDF with httpx (60s timeout)
- Validate Content-Type is application/pdf
- Use temporary file for URL-based PDFs
- Automatic cleanup in finally block
- Maintains backward compatibility with local paths

Completes the workflow: web_scrape error on PDF → pdf_read from URL

* test(tools): Add test coverage for new features in web_scrape and pdf_read tools

* style: fix lint issues in pdf_read URL support

---------

Co-authored-by: Anurag <anuragkr-codes@users.noreply.github.com>
Co-authored-by: hundao <alchemy_wimp@hotmail.com>
2026-03-20 16:36:25 +08:00
Timothy @aden 764012c598 Merge pull request #6652 from aden-hive/feature/absolutely-parallel
Release / Create Release (push) Waiting to run
fix: parallel subagent execution display, session resume bugs, and GCU termination
2026-03-19 20:21:47 -07:00
Timothy fd4dc1a69a fix: google_sheets JSON parse error before credentials check
Move _get_client() before JSON deserialization so missing-credentials
errors aren't masked by input validation. Wrap json.loads in try/except
for non-JSON string inputs.
2026-03-19 20:13:18 -07:00
Timothy 377cd39c2a chore: lint 2026-03-19 20:07:42 -07:00
Timothy e92caeef24 fix: line too long in google_sheets_tool
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-19 20:06:31 -07:00
Timothy @aden b7e6226478 Update asset link in README.md 2026-03-19 19:41:19 -07:00
Timothy a995818db2 fix: subagent bubble boundary 2026-03-19 17:57:33 -07:00
Timothy 0772b4d300 feat: better subagent interleave logic 2026-03-19 16:58:34 -07:00
Timothy 684e0d8dc6 fix: no memory consolidation for worker 2026-03-19 16:58:00 -07:00
Timothy d284c5d790 feat: parallel execution display 2026-03-19 15:25:21 -07:00
Timothy 7a9b9666c4 fix: refresh system prompt with preamble 2026-03-19 15:25:04 -07:00
Timothy a852cb91bf fix: non-blocking memory consolidation 2026-03-19 15:24:30 -07:00
Timothy 2f21e9eb4b fix: session reload preamble 2026-03-19 15:24:12 -07:00
Timothy 8390ef8731 fix: google sheet tool support json string input 2026-03-19 15:23:31 -07:00
levxn 8d21479c24 fixing lint errors 2026-03-20 02:34:58 +05:30
levxn 965dec3ba1 fixing errors, finalising credential fetch (client id and secret) properly in fallback paths 2026-03-20 02:32:42 +05:30
Timothy d4b54446be Merge branch 'main' into feat/image-capabilities 2026-03-19 11:12:33 -07:00
Levin 7992b862c2 Merge branch 'aden-hive:main' into main 2026-03-19 22:16:10 +05:30
Ananya Verma 44b3e0eaa2 Configure pytest to ignore DeprecationWarning (#1727)
Add pytest configuration to ignore specific warnings.
2026-03-19 23:17:50 +08:00
levxn f480fc2b94 oauth creds for antigravity picked properly 2026-03-19 20:26:42 +05:30
vakrahul 2844dbf19f feat: structured skill error codes and diagnostics (closes #6366) 2026-03-19 13:18:18 +05:30
Timothy @aden 22b7e4b0c3 Merge pull request #6624 from aden-hive/feature/agent-skills
Release / Create Release (push) Waiting to run
feat: agent skills system and observability improvements
2026-03-18 20:28:34 -07:00
Timothy 5413833a69 fix: tool test 2026-03-18 20:20:32 -07:00
bryan 02e1a4584a fix: autolaunch gui (windows) 2026-03-18 20:15:25 -07:00
Timothy 520840b1dd fix: no immediate run digest 2026-03-18 20:14:20 -07:00
bryan ee96147336 feat: autolaunch gui (mac) 2026-03-18 20:11:03 -07:00
Timothy 705cef4dc1 fix: context window display 2026-03-18 20:05:48 -07:00
Timothy ab26e64122 Merge remote-tracking branch 'origin/main' into feature/agent-skills 2026-03-18 19:41:39 -07:00
Timothy @aden f365e219cb Merge pull request #6615 from aden-hive/feat/worker-llm
feat: support separate LLM model for worker agents
2026-03-18 19:41:06 -07:00
Timothy 01621881c2 chore: lint 2026-03-18 19:40:41 -07:00
Timothy f7639f8572 fix: realtime context display 2026-03-18 19:29:31 -07:00
Timothy fc643060ce fix: better message bubble handling 2026-03-18 17:49:55 -07:00
Timothy 9aebeb181e feat: compaction debugger 2026-03-18 17:42:10 -07:00
Timothy acbbfaaa79 feat: compaction debug 2026-03-18 17:41:22 -07:00
Timothy bf170bce10 feat: enable mcp server reuse by default 2026-03-18 17:30:31 -07:00
Timothy 0a090d058b Merge remote-tracking branch 'origin/main' into feature/agent-skills 2026-03-18 17:11:12 -07:00
Timothy @aden 47bfadaad9 Merge pull request #6622 from aden-hive/fix/resume-empty-message
Fix empty queen message bubbles on session resume
2026-03-18 16:55:50 -07:00
Timothy d968dcd44c Merge branch 'main' into feature/agent-skills 2026-03-18 16:53:42 -07:00
Timothy @aden 6fdaa9ea50 Merge pull request #6534 from VasuBansal7576/codex/mcp-connection-manager-6348-draft
feat: add shared MCP connection manager
2026-03-18 16:52:44 -07:00
Timothy @aden 4d251fbdc2 Merge pull request #6531 from VasuBansal7576/codex/mcp-transports-6347-single
feat: add unix and sse MCP transports
2026-03-18 16:38:17 -07:00
Timothy 6acceed288 feat: hive debugger 2026-03-18 16:26:55 -07:00
Richard Tang 8dd1d6e3aa chore: lint 2026-03-18 16:01:32 -07:00
Timothy 1da28644a6 Merge branch 'main' into feature/agent-skills 2026-03-18 15:38:49 -07:00
Timothy 6452fe7fef fix: discord bot 2026-03-18 15:34:08 -07:00
Richard Tang acff008bd2 fix: empty message render 2026-03-18 15:26:56 -07:00
Timothy 651d6850a1 fix: bounty tracker change 2026-03-18 14:49:21 -07:00
Timothy c7fdc92594 fix: bounty script 2026-03-18 14:27:24 -07:00
Richard Tang 43602a8801 fix: trim to remove empty message 2026-03-18 13:55:57 -07:00
Timothy @aden 3da04265a6 Merge pull request #6566 from levxn/skills/context-protection
feat(skills): AS-9 and AS-10 — skill directory allowlisting and context protection for activated skills
2026-03-18 13:51:25 -07:00
Timothy @aden 4c98f0d2d0 Merge pull request #6564 from levxn/skills/resource-loading
feat(skills): AS-6 tier 3 resource loading — base_dir in catalog XML and skill dirs wired through execution stack
2026-03-18 13:50:54 -07:00
bryan d84c3364d0 chore: update to pass make test 2026-03-18 13:20:56 -07:00
Timothy @aden ae921f6cee Merge pull request #6619 from aden-hive/fix/claude-code-subscription-support
fix(llm): restore Claude Code subscription OAuth support
2026-03-18 13:08:27 -07:00
Timothy 6b506a1c08 chore: lint 2026-03-18 13:05:00 -07:00
Timothy 0c9f4fa97e fix(llm): restore Claude Code subscription (OAuth) support after Anthropic API change
Anthropic tightened OAuth validation on 2026-03-17, requiring a
specific User-Agent header and a billing integrity system block for
subscription-authenticated requests. Without these, all OAuth calls
return HTTP 400 with a generic "Error" message.

Changes:
- Add billing integrity system block (SHA-256 hash derived from first
  user message content) prepended to system messages on OAuth requests
- Set User-Agent to claude-code/<version> for OAuth sessions
- Fix OAuth header patch to detect tokens in x-api-key (not just
  Authorization) and add required beta/browser-access headers
- Set litellm.drop_params=True to prevent unsupported params like
  stream_options from leaking to Anthropic (causes 400)
- Skip stream_options entirely for Anthropic models
- Honour LITELLM_LOG env var for debug logging instead of hardcoding
  LiteLLM logger to WARNING
2026-03-18 13:02:24 -07:00
Richard Tang 95e30bc607 chore: remove old queen history endpoint 2026-03-18 12:43:30 -07:00
bryan 0f1f0090b0 chore: linter update 2026-03-18 12:41:01 -07:00
bryan c0da3bec02 feat: strip image content for non-vision models 2026-03-18 12:40:30 -07:00
bryan 9dadb5264d feat: add screenshot image passthrough to LLM 2026-03-18 12:40:18 -07:00
bryan e39e6a75cc feat: add ref system for aria snapshots 2026-03-18 12:36:51 -07:00
Richard Tang 23c66d1059 feat: worker model loading 2026-03-18 12:14:02 -07:00
Richard Tang b9d529d94e feat: support separate worker llm setup 2026-03-18 11:19:44 -07:00
levxn b799789dbe fixing lint 2026-03-18 02:15:58 +05:30
levxn 2cd73dfccc implements AS-9 and AS-10 2026-03-18 02:06:51 +05:30
levxn 57d77d5479 fixing lint 2026-03-18 01:32:24 +05:30
levxn 5814021773 skills trust gate merged properly into resource loading branch 2026-03-18 01:18:20 +05:30
levxn 4f4cc9c8ce halfway done commit 2026-03-18 00:59:35 +05:30
Timothy d9c840eee5 chore: resolve merge conflicts with feature/agent-skills
Integrate SkillsManager refactor from base branch. Trust gating (AS-13)
is now wired into SkillsManager._do_load() instead of inline in runner.py,
with the interactive flag passed through SkillsManagerConfig.
2026-03-17 11:55:11 -07:00
levxn 88253883a3 tier 3 resource loading 2026-03-17 03:30:58 +05:30
levxn 6ed6e5b286 lint fixes 2026-03-17 00:32:14 +05:30
Vasu Bansal 30bb0ad5d8 style: format MCP connection manager 2026-03-16 23:46:44 +05:30
Vasu Bansal cb0845f5ba fix: wrap MCP manager cleanup condition 2026-03-16 23:41:36 +05:30
Levin ce2525b59c Merge branch 'aden-hive:main' into skills/trust-gating 2026-03-16 23:39:27 +05:30
levxn 1f77ec3831 fixed bug introduced with change in executor.py, AS-13 along with upstream's AS-1,2,3,4,5 2026-03-16 23:38:45 +05:30
Vasu Bansal 6ab5aa8004 style: format mcp client
Apply ruff formatting to satisfy CI on the MCP transport changes.
2026-03-16 23:19:49 +05:30
Vasu Bansal 4449cd8ee8 feat: add shared MCP connection manager 2026-03-16 23:10:26 +05:30
Vasu Bansal 8b60c03a0a feat: add unix and sse MCP transports
Implements unix socket and SSE MCP transports, adds reconnect-once retry for unix/SSE, and adds focused unit coverage.
2026-03-16 23:03:44 +05:30
Levin 0e98023e40 Merge branch 'aden-hive:main' into skills/trust-gating 2026-03-16 22:23:57 +05:30
levxn 48a54b4ee2 implements AS-13, trusted gating for project level skills 2026-03-16 17:45:33 +05:30
nikhilvarmakandula 151fbd7b00 feat(tools): add Open-Meteo weather tool with no API key required 2026-03-06 00:46:18 +05:30
rhythmtaneja f88483f964 chore: trigger PR revalidation 2026-01-28 09:52:31 +05:30
rhythmtaneja b61ec8c94d Improve EventBus handler error logging by using logger.exception to include traceback 2026-01-28 00:46:23 +05:30
128 changed files with 16196 additions and 6007 deletions
+14 -4
View File
@@ -2,14 +2,22 @@ name: Bounty completed
description: Awards points and notifies Discord when a bounty PR is merged
on:
pull_request:
pull_request_target:
types: [closed]
workflow_dispatch:
inputs:
pr_number:
description: "PR number to process (for missed bounties)"
required: true
type: number
jobs:
bounty-notify:
if: >
github.event.pull_request.merged == true &&
contains(join(github.event.pull_request.labels.*.name, ','), 'bounty:')
github.event_name == 'workflow_dispatch' ||
(github.event.pull_request.merged == true &&
contains(join(github.event.pull_request.labels.*.name, ','), 'bounty:'))
runs-on: ubuntu-latest
timeout-minutes: 5
permissions:
@@ -32,6 +40,8 @@ jobs:
GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
GITHUB_REPOSITORY_NAME: ${{ github.event.repository.name }}
DISCORD_WEBHOOK_URL: ${{ secrets.DISCORD_BOUNTY_WEBHOOK_URL }}
BOT_API_URL: ${{ secrets.BOT_API_URL }}
BOT_API_KEY: ${{ secrets.BOT_API_KEY }}
LURKR_API_KEY: ${{ secrets.LURKR_API_KEY }}
LURKR_GUILD_ID: ${{ secrets.LURKR_GUILD_ID }}
PR_NUMBER: ${{ github.event.pull_request.number }}
PR_NUMBER: ${{ inputs.pr_number || github.event.pull_request.number }}
-126
View File
@@ -1,126 +0,0 @@
name: Link Discord account
description: Auto-creates a PR to add contributor to contributors.yml when a link-discord issue is opened
on:
issues:
types: [opened]
jobs:
link-discord:
if: contains(github.event.issue.labels.*.name, 'link-discord')
runs-on: ubuntu-latest
timeout-minutes: 2
permissions:
contents: write
issues: write
pull-requests: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Parse issue and update contributors.yml
uses: actions/github-script@v7
with:
script: |
const fs = require('fs');
const issue = context.payload.issue;
const githubUsername = issue.user.login;
// Parse the issue body for form fields
const body = issue.body || '';
// Extract Discord ID — look for the numeric value after the "Discord User ID" heading
const discordMatch = body.match(/### Discord User ID\s*\n\s*(\d{17,20})/);
if (!discordMatch) {
await github.rest.issues.createComment({
...context.repo,
issue_number: issue.number,
body: `Could not find a valid Discord ID in the issue body. Please make sure you entered a numeric ID (17-20 digits), not a username.\n\nExample: \`123456789012345678\``
});
await github.rest.issues.update({
...context.repo,
issue_number: issue.number,
state: 'closed',
state_reason: 'not_planned'
});
return;
}
const discordId = discordMatch[1];
// Extract display name (optional)
const nameMatch = body.match(/### Display Name \(optional\)\s*\n\s*(.+)/);
const displayName = nameMatch ? nameMatch[1].trim() : '';
// Check if user already exists
const yml = fs.readFileSync('contributors.yml', 'utf-8');
if (yml.includes(`github: ${githubUsername}`)) {
await github.rest.issues.createComment({
...context.repo,
issue_number: issue.number,
body: `@${githubUsername} is already in \`contributors.yml\`. If you need to update your Discord ID, please edit the file directly via PR.`
});
await github.rest.issues.update({
...context.repo,
issue_number: issue.number,
state: 'closed',
state_reason: 'completed'
});
return;
}
// Append entry to contributors.yml
let entry = ` - github: ${githubUsername}\n discord: "${discordId}"`;
if (displayName && displayName !== '_No response_') {
entry += `\n name: ${displayName}`;
}
entry += '\n';
const updated = yml.trimEnd() + '\n' + entry;
fs.writeFileSync('contributors.yml', updated);
// Set outputs for commit step
core.exportVariable('GITHUB_USERNAME', githubUsername);
core.exportVariable('DISCORD_ID', discordId);
core.exportVariable('ISSUE_NUMBER', issue.number.toString());
- name: Create PR
run: |
# Check if there are changes
if git diff --quiet contributors.yml; then
echo "No changes to contributors.yml"
exit 0
fi
BRANCH="docs/link-discord-${GITHUB_USERNAME}"
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
git checkout -b "$BRANCH"
git add contributors.yml
git commit -m "docs: link @${GITHUB_USERNAME} to Discord"
git push origin "$BRANCH"
gh pr create \
--title "docs: link @${GITHUB_USERNAME} to Discord" \
--body "Adds @${GITHUB_USERNAME} (Discord \`${DISCORD_ID}\`) to \`contributors.yml\` for bounty XP tracking.
Closes #${ISSUE_NUMBER}" \
--base main \
--head "$BRANCH" \
--label "link-discord"
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Notify on issue
uses: actions/github-script@v7
with:
script: |
const username = process.env.GITHUB_USERNAME;
const issueNumber = parseInt(process.env.ISSUE_NUMBER);
await github.rest.issues.createComment({
...context.repo,
issue_number: issueNumber,
body: `A PR has been created to link your account. A maintainer will merge it shortly — once merged, you'll receive XP and Discord pings when your bounty PRs are merged.`
});
+2
View File
@@ -35,6 +35,8 @@ jobs:
GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
GITHUB_REPOSITORY_NAME: ${{ github.event.repository.name }}
DISCORD_WEBHOOK_URL: ${{ secrets.DISCORD_BOUNTY_WEBHOOK_URL }}
BOT_API_URL: ${{ secrets.BOT_API_URL }}
BOT_API_KEY: ${{ secrets.BOT_API_KEY }}
LURKR_API_KEY: ${{ secrets.LURKR_API_KEY }}
LURKR_GUILD_ID: ${{ secrets.LURKR_GUILD_ID }}
SINCE_DATE: ${{ github.event.inputs.since_date || '' }}
+8 -2
View File
@@ -4,7 +4,7 @@
Welcome to Aden Hive, an open-source AI agent framework built for developers who demand production-grade reliability, cross-platform support, and real-world performance. This guide will help you contribute effectively, whether you're fixing bugs, adding features, improving documentation, or building new tools.
Thank you for your interest in contributing! We're especially looking for help building tools, integrations ([check #2805](https://github.com/adenhq/hive/issues/2805)), and example agents for the framework.
Thank you for your interest in contributing! We're especially looking for help building tools, integrations ([check #2805](https://github.com/aden-hive/hive/issues/2805)), and example agents for the framework.
---
@@ -390,6 +390,8 @@ Aden Hive supports **100+ LLM providers** via LiteLLM, giving users maximum flex
|----------|--------|-------|
| **Anthropic** | Claude 3.5 Sonnet, Haiku, Opus | Default provider, best for reasoning |
| **OpenAI** | GPT-4, GPT-4 Turbo, GPT-4o | Function calling, vision |
| **OpenRouter** | Any OpenRouter catalog model | Uses `OPENROUTER_API_KEY` and `https://openrouter.ai/api/v1` |
| **Hive LLM** | `queen`, `kimi-2.5`, `GLM-5` | Uses `HIVE_API_KEY` and the Hive-managed endpoint |
| **Google** | Gemini 1.5 Pro, Flash | Long context windows |
| **DeepSeek** | DeepSeek V3 | Cost-effective, strong reasoning |
| **Mistral** | Mistral Large, Medium, Small | Open weights, EU hosting |
@@ -415,6 +417,10 @@ DEFAULT_MODEL = "claude-haiku-4-5-20251001"
- **Cost**: DeepSeek or Gemini Flash (budget-conscious)
- **Privacy**: Ollama with local models (no data leaves server)
**Provider-Specific Notes**
- **OpenRouter**: store `provider` as `openrouter`, use the raw OpenRouter model ID in `model` (for example `x-ai/grok-4.20-beta`), and use `OPENROUTER_API_KEY`
- **Hive LLM**: store `provider` as `hive`, use Hive model names such as `queen`, `kimi-2.5`, or `GLM-5`, and use `HIVE_API_KEY`
**For Development**
- Use cheaper/faster models (Haiku, GPT-4o-mini)
- Test with multiple providers to catch provider-specific issues
@@ -426,7 +432,7 @@ DEFAULT_MODEL = "claude-haiku-4-5-20251001"
2. **Add credential handling** in `core/framework/credentials/`
3. **Add provider-specific configuration** in `core/framework/llm/`
4. **Write tests** in `core/tests/test_llm_provider.py`
5. **Update documentation** in `docs/llm_providers.md`
5. **Update documentation** in `README.md`, `docs/configuration.md`, and any setup guides that mention provider configuration
**Example: Testing LLM Integration**
+6 -5
View File
@@ -41,7 +41,8 @@ Generate a swarm of worker agents with a coding agent(queen) that control them.
Visit [adenhq.com](https://adenhq.com) for complete documentation, examples, and guides.
https://github.com/user-attachments/assets/aad3a035-e7b3-4cac-b13d-4a83c7002c30
https://github.com/user-attachments/assets/bf10edc3-06ba-48b6-98ba-d069b15fb69d
## Who Is Hive For?
@@ -74,7 +75,7 @@ Use Hive when you need:
- **[Self-Hosting Guide](https://docs.adenhq.com/getting-started/quickstart)** - Deploy Hive on your infrastructure
- **[Changelog](https://github.com/aden-hive/hive/releases)** - Latest updates and releases
- **[Roadmap](docs/roadmap.md)** - Upcoming features and plans
- **[Report Issues](https://github.com/adenhq/hive/issues)** - Bug reports and feature requests
- **[Report Issues](https://github.com/aden-hive/hive/issues)** - Bug reports and feature requests
- **[Contributing](CONTRIBUTING.md)** - How to contribute and submit PRs
## Quick Start
@@ -109,7 +110,7 @@ This sets up:
- **framework** - Core agent runtime and graph executor (in `core/.venv`)
- **aden_tools** - MCP tools for agent capabilities (in `tools/.venv`)
- **credential store** - Encrypted API key storage (`~/.hive/credentials`)
- **LLM provider** - Interactive default model configuration
- **LLM provider** - Interactive default model configuration, including Hive LLM and OpenRouter
- All required Python dependencies with `uv`
- Finally, it will open the Hive interface in your browser
@@ -148,7 +149,7 @@ Now you can run an agent by selecting the agent (either an existing agent or exa
<a href="https://github.com/aden-hive/hive/tree/main/tools/src/aden_tools/tools"><img width="100%" alt="Integration" src="https://github.com/user-attachments/assets/a1573f93-cf02-4bb8-b3d5-b305b05b1e51" /></a>
Hive is built to be model-agnostic and system-agnostic.
- **LLM flexibility** - Hive Framework is designed to support various types of LLMs, including hosted and local models through LiteLLM-compatible providers.
- **LLM flexibility** - Hive Framework supports Anthropic, OpenAI, OpenRouter, Hive LLM, and other hosted or local models through LiteLLM-compatible providers.
- **Business system connectivity** - Hive Framework is designed to connect to all kinds of business systems as tools, such as CRM, support, messaging, data, file, and internal APIs via MCP.
## Why Aden
@@ -376,7 +377,7 @@ This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENS
**Q: What LLM providers does Hive support?**
Hive supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, DeepSeek, Mistral, Groq, and many more. Simply set the appropriate API key environment variable and specify the model name. We recommend using Claude, GLM and Gemini as they have the best performance.
Hive supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, DeepSeek, Mistral, Groq, OpenRouter, and Hive LLM. Simply set the appropriate API key environment variable and specify the model name. See [docs/configuration.md](docs/configuration.md) for provider-specific configuration examples.
**Q: Can I use Hive with local AI models like Ollama?**
-27
View File
@@ -1,27 +0,0 @@
# Identity mapping: GitHub username -> Discord ID
#
# This file links GitHub accounts to Discord accounts for the
# Integration Bounty Program. When a bounty PR is merged, the
# GitHub Action uses this file to ping the contributor on Discord.
#
# HOW TO ADD YOURSELF:
# Open a "Link Discord Account" issue:
# https://github.com/aden-hive/hive/issues/new?template=link-discord.yml
# A GitHub Action will automatically add your entry here.
#
# To find your Discord ID:
# 1. Open Discord Settings > Advanced > Enable Developer Mode
# 2. Right-click your name > Copy User ID
#
# Format:
# - github: your-github-username
# discord: "your-discord-id" # quotes required (it's a number)
# name: Your Display Name # optional
contributors:
# - github: example-user
# discord: "123456789012345678"
# name: Example User
- github: TimothyZhang7
discord: "408460790061072384"
name: Timothy@Aden
+583
View File
@@ -0,0 +1,583 @@
#!/usr/bin/env python3
"""Antigravity authentication CLI.
Implements OAuth2 flow for Google's Antigravity Code Assist gateway.
Credentials are stored in ~/.hive/antigravity-accounts.json.
Usage:
python -m antigravity_auth auth account add
python -m antigravity_auth auth account list
python -m antigravity_auth auth account remove <email>
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import secrets
import socket
import sys
import time
import urllib.parse
import urllib.request
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Any
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)
# OAuth endpoints
_OAUTH_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
_OAUTH_TOKEN_URL = "https://oauth2.googleapis.com/token"
# Scopes for Antigravity/Cloud Code Assist
_OAUTH_SCOPES = [
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
]
# Credentials file path in ~/.hive/
_ACCOUNTS_FILE = Path.home() / ".hive" / "antigravity-accounts.json"
# Default project ID
_DEFAULT_PROJECT_ID = "rising-fact-p41fc"
_DEFAULT_REDIRECT_PORT = 51121
# OAuth credentials fetched from the opencode-antigravity-auth project.
# This project reverse-engineered and published the public OAuth credentials
# for Google's Antigravity/Cloud Code Assist API.
# Source: https://github.com/NoeFabris/opencode-antigravity-auth
_CREDENTIALS_URL = (
"https://raw.githubusercontent.com/NoeFabris/opencode-antigravity-auth/dev/src/constants.ts"
)
# Cached credentials fetched from public source
_cached_client_id: str | None = None
_cached_client_secret: str | None = None
def _fetch_credentials_from_public_source() -> tuple[str | None, str | None]:
"""Fetch OAuth client ID and secret from the public npm package source on GitHub."""
global _cached_client_id, _cached_client_secret
if _cached_client_id and _cached_client_secret:
return _cached_client_id, _cached_client_secret
try:
req = urllib.request.Request(
_CREDENTIALS_URL, headers={"User-Agent": "Hive-Antigravity-Auth/1.0"}
)
with urllib.request.urlopen(req, timeout=10) as resp:
content = resp.read().decode("utf-8")
import re
id_match = re.search(r'ANTIGRAVITY_CLIENT_ID\s*=\s*"([^"]+)"', content)
secret_match = re.search(r'ANTIGRAVITY_CLIENT_SECRET\s*=\s*"([^"]+)"', content)
if id_match:
_cached_client_id = id_match.group(1)
if secret_match:
_cached_client_secret = secret_match.group(1)
return _cached_client_id, _cached_client_secret
except Exception as e:
logger.debug(f"Failed to fetch credentials from public source: {e}")
return None, None
def get_client_id() -> str:
"""Get OAuth client ID from env, config, or public source."""
env_id = os.environ.get("ANTIGRAVITY_CLIENT_ID")
if env_id:
return env_id
# Try hive config
hive_cfg = Path.home() / ".hive" / "configuration.json"
if hive_cfg.exists():
try:
with open(hive_cfg) as f:
cfg = json.load(f)
cfg_id = cfg.get("llm", {}).get("antigravity_client_id")
if cfg_id:
return cfg_id
except Exception:
pass
# Fetch from public source
client_id, _ = _fetch_credentials_from_public_source()
if client_id:
return client_id
raise RuntimeError("Could not obtain Antigravity OAuth client ID")
def get_client_secret() -> str | None:
"""Get OAuth client secret from env, config, or public source."""
secret = os.environ.get("ANTIGRAVITY_CLIENT_SECRET")
if secret:
return secret
# Try to read from hive config
hive_cfg = Path.home() / ".hive" / "configuration.json"
if hive_cfg.exists():
try:
with open(hive_cfg) as f:
cfg = json.load(f)
secret = cfg.get("llm", {}).get("antigravity_client_secret")
if secret:
return secret
except Exception:
pass
# Fetch from public source (npm package on GitHub)
_, secret = _fetch_credentials_from_public_source()
return secret
def find_free_port() -> int:
"""Find an available local port."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
s.listen(1)
return s.getsockname()[1]
class OAuthCallbackHandler(BaseHTTPRequestHandler):
"""Handle OAuth callback from browser."""
auth_code: str | None = None
state: str | None = None
error: str | None = None
def log_message(self, format: str, *args: Any) -> None:
pass # Suppress default logging
def do_GET(self) -> None:
parsed = urllib.parse.urlparse(self.path)
if parsed.path == "/oauth-callback":
query = urllib.parse.parse_qs(parsed.query)
if "error" in query:
self.error = query["error"][0]
self._send_response("Authentication failed. You can close this window.")
return
if "code" in query and "state" in query:
OAuthCallbackHandler.auth_code = query["code"][0]
OAuthCallbackHandler.state = query["state"][0]
self._send_response(
"Authentication successful! You can close this window "
"and return to the terminal."
)
return
self._send_response("Waiting for authentication...")
def _send_response(self, message: str) -> None:
self.send_response(200)
self.send_header("Content-Type", "text/html")
self.end_headers()
html = f"""<!DOCTYPE html>
<html>
<head><title>Antigravity Auth</title></head>
<body style="font-family: system-ui; display: flex; align-items: center;
justify-content: center; height: 100vh; margin: 0; background: #1a1a2e;
color: #eee;">
<div style="text-align: center;">
<h2>{message}</h2>
</div>
</body>
</html>"""
self.wfile.write(html.encode())
def wait_for_callback(port: int, timeout: int = 300) -> tuple[str | None, str | None, str | None]:
"""Start local server and wait for OAuth callback."""
server = HTTPServer(("localhost", port), OAuthCallbackHandler)
server.timeout = 1
start = time.time()
while time.time() - start < timeout:
if OAuthCallbackHandler.auth_code:
return (
OAuthCallbackHandler.auth_code,
OAuthCallbackHandler.state,
OAuthCallbackHandler.error,
)
server.handle_request()
return None, None, "timeout"
def exchange_code_for_tokens(
code: str, redirect_uri: str, client_id: str, client_secret: str | None
) -> dict[str, Any] | None:
"""Exchange authorization code for tokens."""
data = {
"code": code,
"client_id": client_id,
"redirect_uri": redirect_uri,
"grant_type": "authorization_code",
}
if client_secret:
data["client_secret"] = client_secret
body = urllib.parse.urlencode(data).encode()
req = urllib.request.Request(
_OAUTH_TOKEN_URL,
data=body,
headers={"Content-Type": "application/x-www-form-urlencoded"},
method="POST",
)
try:
with urllib.request.urlopen(req, timeout=30) as resp:
return json.loads(resp.read())
except Exception as e:
logger.error(f"Token exchange failed: {e}")
return None
def get_user_email(access_token: str) -> str | None:
"""Get user email from Google API."""
req = urllib.request.Request(
"https://www.googleapis.com/oauth2/v2/userinfo",
headers={"Authorization": f"Bearer {access_token}"},
)
try:
with urllib.request.urlopen(req, timeout=10) as resp:
data = json.loads(resp.read())
return data.get("email")
except Exception:
return None
def load_accounts() -> dict[str, Any]:
"""Load existing accounts from file."""
if not _ACCOUNTS_FILE.exists():
return {"schemaVersion": 4, "accounts": []}
try:
with open(_ACCOUNTS_FILE) as f:
return json.load(f)
except Exception:
return {"schemaVersion": 4, "accounts": []}
def save_accounts(data: dict[str, Any]) -> None:
"""Save accounts to file."""
_ACCOUNTS_FILE.parent.mkdir(parents=True, exist_ok=True)
with open(_ACCOUNTS_FILE, "w") as f:
json.dump(data, f, indent=2)
logger.info(f"Saved credentials to {_ACCOUNTS_FILE}")
def validate_credentials(access_token: str, project_id: str = _DEFAULT_PROJECT_ID) -> bool:
"""Test if credentials work by making a simple API call to Antigravity.
Returns True if credentials are valid, False otherwise.
"""
endpoint = "https://daily-cloudcode-pa.sandbox.googleapis.com"
body = {
"project": project_id,
"model": "gemini-3-flash",
"request": {
"contents": [{"role": "user", "parts": [{"text": "hi"}]}],
"generationConfig": {"maxOutputTokens": 10},
},
"requestType": "agent",
"userAgent": "antigravity",
"requestId": "validation-test",
}
headers = {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
"User-Agent": (
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
"AppleWebKit/537.36 (KHTML, like Gecko) Antigravity/1.18.3"
),
"X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
}
try:
req = urllib.request.Request(
f"{endpoint}/v1internal:generateContent",
data=json.dumps(body).encode("utf-8"),
headers=headers,
method="POST",
)
with urllib.request.urlopen(req, timeout=30) as resp:
json.loads(resp.read())
return True
except Exception:
return False
def refresh_access_token(
refresh_token: str, client_id: str, client_secret: str | None
) -> dict | None:
"""Refresh the access token using the refresh token."""
data = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": client_id,
}
if client_secret:
data["client_secret"] = client_secret
body = urllib.parse.urlencode(data).encode()
req = urllib.request.Request(
_OAUTH_TOKEN_URL,
data=body,
headers={"Content-Type": "application/x-www-form-urlencoded"},
method="POST",
)
try:
with urllib.request.urlopen(req, timeout=30) as resp:
return json.loads(resp.read())
except Exception as e:
logger.debug(f"Token refresh failed: {e}")
return None
def cmd_account_add(args: argparse.Namespace) -> int:
"""Add a new Antigravity account via OAuth2.
First checks if valid credentials already exist. If so, validates them
and skips OAuth if they work. Otherwise, proceeds with OAuth flow.
"""
client_id = get_client_id()
client_secret = get_client_secret()
# Check if credentials already exist
accounts_data = load_accounts()
accounts = accounts_data.get("accounts", [])
if accounts:
account = next((a for a in accounts if a.get("enabled", True) is not False), accounts[0])
access_token = account.get("access")
refresh_token_str = account.get("refresh", "")
refresh_token = refresh_token_str.split("|")[0] if refresh_token_str else None
project_id = (
refresh_token_str.split("|")[1] if "|" in refresh_token_str else _DEFAULT_PROJECT_ID
)
email = account.get("email", "unknown")
expires_ms = account.get("expires", 0)
expires_at = expires_ms / 1000.0 if expires_ms else 0.0
# Check if token is expired or near expiry
if access_token and expires_at and time.time() < expires_at - 60:
# Token still valid, test it
logger.info(f"Found existing credentials for: {email}")
logger.info("Validating existing credentials...")
if validate_credentials(access_token, project_id):
logger.info("✓ Credentials valid! Skipping OAuth.")
return 0
else:
logger.info("Credentials failed validation, refreshing...")
elif refresh_token:
logger.info(f"Found expired credentials for: {email}")
logger.info("Attempting token refresh...")
tokens = refresh_access_token(refresh_token, client_id, client_secret)
if tokens:
new_access = tokens.get("access_token")
expires_in = tokens.get("expires_in", 3600)
if new_access:
# Update the account
account["access"] = new_access
account["expires"] = int((time.time() + expires_in) * 1000)
accounts_data["last_refresh"] = time.strftime(
"%Y-%m-%dT%H:%M:%SZ", time.gmtime()
)
save_accounts(accounts_data)
# Validate the refreshed token
logger.info("Validating refreshed credentials...")
if validate_credentials(new_access, project_id):
logger.info("✓ Credentials refreshed and validated!")
return 0
else:
logger.info("Refreshed token failed validation, proceeding with OAuth...")
else:
logger.info("Token refresh failed, proceeding with OAuth...")
# No valid credentials, proceed with OAuth
if not client_secret:
logger.warning(
"No client secret configured. Token refresh may fail.\n"
"Set ANTIGRAVITY_CLIENT_SECRET env var or add "
"'antigravity_client_secret' to ~/.hive/configuration.json"
)
# Use fixed port and path matching Google's expected OAuth redirect URI
port = _DEFAULT_REDIRECT_PORT
redirect_uri = f"http://localhost:{port}/oauth-callback"
# Generate state for CSRF protection
state = secrets.token_urlsafe(16)
# Build authorization URL
params = {
"client_id": client_id,
"redirect_uri": redirect_uri,
"response_type": "code",
"scope": " ".join(_OAUTH_SCOPES),
"state": state,
"access_type": "offline",
"prompt": "consent",
}
auth_url = f"{_OAUTH_AUTH_URL}?{urllib.parse.urlencode(params)}"
logger.info("Opening browser for authentication...")
logger.info(f"If the browser doesn't open, visit: {auth_url}\n")
# Open browser
webbrowser.open(auth_url)
# Wait for callback
logger.info(f"Listening for callback on port {port}...")
code, received_state, error = wait_for_callback(port)
if error:
logger.error(f"Authentication failed: {error}")
return 1
if not code:
logger.error("No authorization code received")
return 1
if received_state != state:
logger.error("State mismatch - possible CSRF attack")
return 1
# Exchange code for tokens
logger.info("Exchanging authorization code for tokens...")
tokens = exchange_code_for_tokens(code, redirect_uri, client_id, client_secret)
if not tokens:
return 1
access_token = tokens.get("access_token")
refresh_token = tokens.get("refresh_token")
expires_in = tokens.get("expires_in", 3600)
if not access_token:
logger.error("No access token in response")
return 1
# Get user email
email = get_user_email(access_token)
if email:
logger.info(f"Authenticated as: {email}")
# Load existing accounts and add/update
accounts_data = load_accounts()
accounts = accounts_data.get("accounts", [])
# Build new account entry (V4 schema)
expires_ms = int((time.time() + expires_in) * 1000)
refresh_entry = f"{refresh_token}|{_DEFAULT_PROJECT_ID}"
new_account = {
"access": access_token,
"refresh": refresh_entry,
"expires": expires_ms,
"email": email,
"enabled": True,
}
# Update existing account or add new one
existing_idx = next((i for i, a in enumerate(accounts) if a.get("email") == email), None)
if existing_idx is not None:
accounts[existing_idx] = new_account
logger.info(f"Updated existing account: {email}")
else:
accounts.append(new_account)
logger.info(f"Added new account: {email}")
accounts_data["accounts"] = accounts
accounts_data["schemaVersion"] = 4
accounts_data["last_refresh"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
save_accounts(accounts_data)
logger.info("\n✓ Authentication complete!")
return 0
def cmd_account_list(args: argparse.Namespace) -> int:
"""List all stored accounts."""
data = load_accounts()
accounts = data.get("accounts", [])
if not accounts:
logger.info("No accounts configured.")
logger.info("Run 'antigravity auth account add' to add one.")
return 0
logger.info("Configured accounts:\n")
for i, account in enumerate(accounts, 1):
email = account.get("email", "unknown")
enabled = "enabled" if account.get("enabled", True) else "disabled"
logger.info(f" {i}. {email} ({enabled})")
return 0
def cmd_account_remove(args: argparse.Namespace) -> int:
"""Remove an account by email."""
email = args.email
data = load_accounts()
accounts = data.get("accounts", [])
original_len = len(accounts)
accounts = [a for a in accounts if a.get("email") != email]
if len(accounts) == original_len:
logger.error(f"No account found with email: {email}")
return 1
data["accounts"] = accounts
save_accounts(data)
logger.info(f"Removed account: {email}")
return 0
def main() -> int:
parser = argparse.ArgumentParser(
description="Antigravity authentication CLI",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
subparsers = parser.add_subparsers(dest="command", help="Commands")
# auth account add
auth_parser = subparsers.add_parser("auth", help="Authentication commands")
auth_subparsers = auth_parser.add_subparsers(dest="auth_command")
account_parser = auth_subparsers.add_parser("account", help="Account management")
account_subparsers = account_parser.add_subparsers(dest="account_command")
add_parser = account_subparsers.add_parser("add", help="Add a new account via OAuth2")
add_parser.set_defaults(func=cmd_account_add)
list_parser = account_subparsers.add_parser("list", help="List configured accounts")
list_parser.set_defaults(func=cmd_account_list)
remove_parser = account_subparsers.add_parser("remove", help="Remove an account")
remove_parser.add_argument("email", help="Email of account to remove")
remove_parser.set_defaults(func=cmd_account_remove)
args = parser.parse_args()
if hasattr(args, "func"):
return args.func(args)
parser.print_help()
return 0
if __name__ == "__main__":
sys.exit(main())
+81 -27
View File
@@ -17,6 +17,7 @@ import http.server
import json
import os
import platform
import queue
import secrets
import subprocess
import sys
@@ -27,6 +28,7 @@ import urllib.parse
import urllib.request
from datetime import UTC, datetime
from pathlib import Path
from typing import TextIO
# OAuth constants (from the Codex CLI binary)
CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
@@ -165,11 +167,11 @@ def open_browser(url: str) -> bool:
if system == "Darwin":
subprocess.Popen(["open", url], stdout=devnull, stderr=devnull)
elif system == "Windows":
subprocess.Popen(["cmd", "/c", "start", url], stdout=devnull, stderr=devnull)
os.startfile(url) # type: ignore[attr-defined]
else:
subprocess.Popen(["xdg-open", url], stdout=devnull, stderr=devnull)
return True
except OSError:
except (AttributeError, OSError):
return False
@@ -266,6 +268,71 @@ def parse_manual_input(value: str, expected_state: str) -> str | None:
return None
def _read_manual_input_lines(
manual_inputs: queue.Queue[str],
stop_event: threading.Event,
stdin: TextIO | None = None,
) -> None:
stream = sys.stdin if stdin is None else stdin
while not stop_event.is_set():
try:
manual = stream.readline()
except (EOFError, OSError):
return
if not manual:
return
if manual.strip():
manual_inputs.put(manual)
def wait_for_code_from_callback_or_stdin(
expected_state: str,
callback_result: list[str | None],
callback_done: threading.Event,
timeout_secs: float = 120,
poll_interval: float = 0.1,
stdin: TextIO | None = None,
) -> str | None:
manual_inputs: queue.Queue[str] = queue.Queue()
stop_event = threading.Event()
# Read stdin on a daemon thread so manual paste works on platforms where
# select() cannot poll console handles, including Windows terminals.
threading.Thread(
target=_read_manual_input_lines,
args=(manual_inputs, stop_event, stdin),
daemon=True,
).start()
deadline = time.time() + timeout_secs
try:
while time.time() < deadline:
if callback_result[0]:
return callback_result[0]
while True:
try:
manual = manual_inputs.get_nowait()
except queue.Empty:
break
code = parse_manual_input(manual, expected_state)
if code:
return code
if callback_done.is_set():
return callback_result[0]
time.sleep(poll_interval)
return callback_result[0]
finally:
stop_event.set()
def main() -> int:
# Generate PKCE and state
verifier, challenge = generate_pkce()
@@ -315,41 +382,28 @@ def main() -> int:
# Start callback server in background
callback_result: list[str | None] = [None]
callback_done = threading.Event()
def run_server() -> None:
callback_result[0] = wait_for_callback(state, timeout_secs=120)
try:
callback_result[0] = wait_for_callback(state, timeout_secs=120)
finally:
callback_done.set()
server_thread = threading.Thread(target=run_server)
server_thread.daemon = True
server_thread.start()
# Also accept manual input in parallel
# We poll for both the server result and stdin
try:
import select
while server_thread.is_alive():
# Check if stdin has data (non-blocking on unix)
if hasattr(select, "select"):
ready, _, _ = select.select([sys.stdin], [], [], 0.5)
if ready:
manual = sys.stdin.readline()
if manual.strip():
code = parse_manual_input(manual, state)
if code:
break
else:
time.sleep(0.5)
if callback_result[0]:
code = callback_result[0]
break
except (KeyboardInterrupt, EOFError):
code = wait_for_code_from_callback_or_stdin(
state,
callback_result,
callback_done,
timeout_secs=120,
)
except KeyboardInterrupt:
print("\n\033[0;31mCancelled.\033[0m")
return 1
if not code:
code = callback_result[0]
else:
# Manual paste mode
try:
+1 -1
View File
@@ -79,7 +79,7 @@ async def example_3_config_file():
# Copy example config (in practice, you'd place this in your agent folder)
import shutil
shutil.copy("examples/mcp_servers.json", test_agent_path / "mcp_servers.json")
shutil.copy(Path(__file__).parent / "mcp_servers.json", test_agent_path / "mcp_servers.json")
# Load agent - MCP servers will be auto-discovered
runner = AgentRunner.load(test_agent_path)
@@ -16,6 +16,7 @@ after the user picks an account programmatically.
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
@@ -25,6 +26,7 @@ from framework.graph.checkpoint_config import CheckpointConfig
from framework.graph.edge import GraphSpec
from framework.graph.executor import ExecutionResult
from framework.llm import LiteLLMProvider
from framework.runner.mcp_registry import MCPRegistry
from framework.runner.tool_registry import ToolRegistry
from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime
from framework.runtime.execution_stream import EntryPointSpec
@@ -32,9 +34,13 @@ from framework.runtime.execution_stream import EntryPointSpec
from .config import default_config
from .nodes import build_tester_node
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from framework.runner import AgentRunner
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Goal
# ---------------------------------------------------------------------------
@@ -107,7 +113,11 @@ def _list_aden_accounts() -> list[dict]:
for c in integrations
if c.status == "active"
]
except (ImportError, OSError) as exc:
logger.debug("Could not list Aden accounts: %s", exc)
return []
except Exception:
logger.warning("Unexpected error listing Aden accounts", exc_info=True)
return []
@@ -119,7 +129,11 @@ def _list_local_accounts() -> list[dict]:
return [
info.to_account_dict() for info in LocalCredentialRegistry.default().list_accounts()
]
except ImportError as exc:
logger.debug("Local credential registry unavailable: %s", exc)
return []
except Exception:
logger.warning("Unexpected error listing local accounts", exc_info=True)
return []
@@ -140,7 +154,11 @@ def _list_env_fallback_accounts() -> list[dict]:
from framework.credentials.storage import EncryptedFileStorage
encrypted_ids: set[str] = set(EncryptedFileStorage().list_all())
except (ImportError, OSError) as exc:
logger.debug("Could not read encrypted store: %s", exc)
encrypted_ids = set()
except Exception:
logger.warning("Unexpected error reading encrypted store", exc_info=True)
encrypted_ids = set()
def _is_configured(cred_name: str, spec) -> bool:
@@ -300,8 +318,10 @@ def _activate_local_account(credential_id: str, alias: str) -> None:
if key:
os.environ[spec.env_var] = key
except (ImportError, KeyError, OSError) as exc:
logger.debug("Could not inject credentials: %s", exc)
except Exception:
pass
logger.warning("Unexpected error injecting credentials", exc_info=True)
def _configure_aden_node(
@@ -563,6 +583,15 @@ class CredentialTesterAgent:
if mcp_config_path.exists():
self._tool_registry.load_mcp_config(mcp_config_path)
try:
registry = MCPRegistry()
registry.initialize()
registry_configs = registry.load_agent_selection(Path(__file__).parent)
if registry_configs:
self._tool_registry.load_registry_servers(registry_configs)
except Exception:
logger.warning("MCP registry config failed to load", exc_info=True)
extra_kwargs = getattr(self.config, "extra_kwargs", {}) or {}
llm = LiteLLMProvider(
model=self.config.model,
@@ -702,6 +702,15 @@ stop_worker() to return to STAGING phase.
_queen_behavior_always = """
# Behavior
## Images attached by the user
Users can attach images directly to their chat messages. When you see an \
image in the conversation, analyze it using your native vision capability \
do NOT say you cannot see images or that you lack access to files. The image \
is embedded in the message; no tool call is needed to view it. Describe what \
you see, answer questions about it, and use the visual content to inform your \
response just as you would text.
## CRITICAL RULE — ask_user / ask_user_multiple
Every response that ends with a question, a prompt, or expects user \
+107 -3
View File
@@ -137,6 +137,32 @@ def append_episodic_entry(content: str) -> None:
with ep_path.open("a", encoding="utf-8") as f:
f.write(block)
# Immediately create a bare index entry (no enrichment — that happens at
# consolidation time). Wrapped so any indexing failure never interrupts
# the diary write.
try:
_post_append_index_hook(today.strftime("%Y-%m-%d"), timestamp, content.strip())
except Exception:
logger.warning("queen_memory: index hook failed on diary append", exc_info=True)
def _post_append_index_hook(date_str: str, timestamp: str, prose: str) -> None:
"""Create a bare MemoryEntry in the index for a freshly-appended diary section."""
from framework.agents.queen.queen_memory_index import (
get_entry,
index_entry_from_diary_section,
load_index,
put_entry,
save_index,
)
index = load_index()
entry_id = f"{date_str}:{timestamp}"
if get_entry(index, entry_id) is None:
entry = index_entry_from_diary_section(date_str, timestamp, prose)
put_entry(index, entry)
save_index(index)
def seed_if_missing() -> None:
"""Create MEMORY.md with a blank template if it doesn't exist yet."""
@@ -226,7 +252,11 @@ def read_session_context(session_dir: Path, max_messages: int = 80) -> str:
elif content:
label = "user" if role == "user" else "queen"
lines.append(f"[{label}]: {content[:600]}")
except (KeyError, TypeError) as exc:
logger.debug("Skipping malformed conversation message: %s", exc)
continue
except Exception:
logger.warning("Unexpected error parsing conversation message", exc_info=True)
continue
if lines:
parts.append("## Conversation\n\n" + "\n".join(lines))
@@ -307,9 +337,10 @@ async def consolidate_queen_memory(
llm: LLMProvider instance (must support acomplete()).
"""
try:
logger.info("queen_memory: consolidation triggered for session %s", session_id)
session_context = read_session_context(session_dir)
if not session_context:
logger.debug("queen_memory: no session context, skipping consolidation")
logger.info("queen_memory: no session context found, skipping")
return
logger.info("queen_memory: consolidating memory for session %s ...", session_id)
@@ -384,6 +415,14 @@ async def consolidate_queen_memory(
len(diary_entry),
)
# Update the memory index for today's entries: enrich, embed, link,
# and optionally evolve neighbour metadata. Wrapped so failures never
# block or disrupt the main consolidation path.
try:
await _update_index_after_consolidation(today.strftime("%Y-%m-%d"), llm)
except Exception:
logger.warning("queen_memory: index update failed after consolidation", exc_info=True)
except Exception:
tb = traceback.format_exc()
logger.exception("queen_memory: consolidation failed")
@@ -395,5 +434,70 @@ async def consolidate_queen_memory(
f"session: {session_id}\ntime: {datetime.now().isoformat()}\n\n{tb}",
encoding="utf-8",
)
except Exception:
pass
except OSError:
pass # Cannot write error file; original exception already logged
async def _update_index_after_consolidation(date_str: str, llm: object) -> None:
"""Enrich, embed, link, and evolve today's memory index entries.
Called after the main semantic/diary LLM writes complete. All failures
are silently logged this function must never propagate exceptions.
"""
from framework.agents.queen.queen_memory_index import (
embed_text,
embeddings_enabled,
get_embed_model,
link_entry,
load_index,
maybe_evolve_neighbors,
put_entry,
rebuild_index_for_date,
save_index,
)
# Phase 1 — ensure all diary sections are in the index and enriched
await rebuild_index_for_date(date_str, llm=llm)
if not embeddings_enabled():
logger.debug("queen_memory: embeddings not configured, skipping embed/link/evolve")
return # Phases 2-5 require embeddings
logger.info("queen_memory: running embed/link/evolve for %s", date_str)
# Phases 2-5 — embed, link, evolve any entries still missing vectors
index = load_index()
entries = index.get("entries", {})
newly_embedded: list[str] = []
for entry_id, raw in entries.items():
if not entry_id.startswith(date_str):
continue
if raw.get("embedding") is not None:
continue
prose = raw.get("summary", "")
if not prose:
continue
vec = await embed_text(prose)
if vec is not None:
raw["embedding"] = vec
index["embed_model"] = get_embed_model()
index["embed_dim"] = len(vec)
newly_embedded.append(entry_id)
if newly_embedded:
save_index(index)
# Phase 3 — cross-reference linking for newly embedded entries
for entry_id in newly_embedded:
linked = link_entry(index, entry_id)
# Phase 5 — memory evolution for top neighbours
if linked:
await maybe_evolve_neighbors(entry_id, linked, index, llm)
if newly_embedded:
save_index(index)
logger.debug(
"queen_memory: indexed %d new embedding(s) for %s",
len(newly_embedded),
date_str,
)
@@ -0,0 +1,788 @@
"""Structured index for queen episodic memory entries.
Attaches rich metadata, embedding vectors, cross-reference links, and
retrieval counts to every diary entry. The index lives at:
~/.hive/queen/memories/index.json
It is a *sidecar* to the existing markdown diary files those files are
never modified by this module.
Configuration
-------------
Set ``HIVE_EMBED_MODEL`` to an embedding model name supported by litellm
(e.g. ``text-embedding-3-small``) to enable semantic search. When unset
the system degrades gracefully: enrichment (keywords/tags/category) still
works via the consolidation LLM, and recall_diary falls back to substring
matching.
Phases implemented
------------------
Phase 1 - Index I/O + semantic enrichment (keywords, category, tags)
Phase 2 - Embedding storage + semantic search via cosine similarity
Phase 3 - Cross-reference linking (bidirectional related[] links)
Phase 4 - Importance tracking (retrieval counts + recency decay)
Phase 5 - Memory evolution (LLM-driven neighbour metadata refinement)
"""
from __future__ import annotations
import json
import logging
import math
import re
from dataclasses import asdict, dataclass, field
from datetime import date, datetime, timedelta
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Category vocabulary — fixed to prevent unbounded drift
# ---------------------------------------------------------------------------
_CATEGORIES = [
"agent_build",
"infrastructure",
"user_preference",
"communication_style",
"diagnostic_learning",
"milestone",
"pipeline",
"data_processing",
"other",
]
# ---------------------------------------------------------------------------
# MemoryEntry dataclass
# ---------------------------------------------------------------------------
@dataclass
class MemoryEntry:
"""Rich metadata record for a single diary section (one ### HH:MM block)."""
# Identity — "YYYY-MM-DD:HH:MM" matches the diary ### timestamp
id: str
date: str # "YYYY-MM-DD"
timestamp: str # "HH:MM"
# Content preview (not full prose — just enough for search result context)
summary: str # first 300 chars of the section's prose
# Phase 1 — semantic enrichment
keywords: list[str] = field(default_factory=list)
category: str = "other"
tags: list[str] = field(default_factory=list)
# Phase 3 — cross-reference links
related: list[str] = field(default_factory=list)
# Phase 4 — importance tracking
retrieval_count: int = 0
last_retrieved: str | None = None # ISO-format datetime string
# Phase 2 — embedding vector (None when HIVE_EMBED_MODEL is unset)
embedding: list[float] | None = None
# Whether enrichment has been applied (used to skip re-enrichment)
enriched: bool = False
# ---------------------------------------------------------------------------
# Index I/O
# ---------------------------------------------------------------------------
_EMPTY_INDEX: dict[str, Any] = {
"version": 1,
"embed_model": None,
"embed_dim": None,
"entries": {},
}
def _queen_memories_dir() -> Path:
return Path.home() / ".hive" / "queen" / "memories"
def index_path() -> Path:
return _queen_memories_dir() / "index.json"
def load_index() -> dict[str, Any]:
"""Load the index from disk. Returns a fresh empty index on any error."""
p = index_path()
if not p.exists():
return {**_EMPTY_INDEX, "entries": {}}
try:
data = json.loads(p.read_text(encoding="utf-8"))
if not isinstance(data, dict) or "entries" not in data:
raise ValueError("Malformed index")
return data
except Exception as exc:
logger.warning("queen_memory_index: index.json unreadable (%s), starting fresh", exc)
return {**_EMPTY_INDEX, "entries": {}}
def save_index(index: dict[str, Any]) -> None:
"""Atomically write the index to disk (tmp file → rename)."""
p = index_path()
p.parent.mkdir(parents=True, exist_ok=True)
tmp = p.with_suffix(".json.tmp")
tmp.write_text(json.dumps(index, ensure_ascii=False), encoding="utf-8")
tmp.replace(p)
def get_entry(index: dict[str, Any], entry_id: str) -> MemoryEntry | None:
"""Deserialise one entry from the index dict, or None if missing."""
raw = index.get("entries", {}).get(entry_id)
if raw is None:
return None
try:
return MemoryEntry(**{k: raw[k] for k in MemoryEntry.__dataclass_fields__ if k in raw})
except Exception as exc:
logger.warning("queen_memory_index: failed to deserialise entry %s: %s", entry_id, exc)
return None
def put_entry(index: dict[str, Any], entry: MemoryEntry) -> None:
"""Serialise and insert/overwrite one entry in the index dict (mutates in place)."""
index.setdefault("entries", {})[entry.id] = asdict(entry)
# ---------------------------------------------------------------------------
# Configuration helpers
# ---------------------------------------------------------------------------
def get_embed_model() -> str | None:
"""Return the configured embedding model (e.g. 'openai/text-embedding-3-small').
Reads from the ``embedding`` section of ~/.hive/configuration.json.
Falls back to the ``HIVE_EMBED_MODEL`` env var for backward compatibility.
"""
from framework.config import get_embed_model as _cfg_get_embed_model
return _cfg_get_embed_model()
def embeddings_enabled() -> bool:
return bool(get_embed_model())
def _detect_model_change(index: dict[str, Any]) -> bool:
"""Return True if the stored embed model differs from the current env var."""
current = get_embed_model()
stored = index.get("embed_model")
return current != stored
def _clear_embeddings(index: dict[str, Any]) -> None:
"""Clear all stored vectors when the embedding model has changed."""
for raw in index.get("entries", {}).values():
raw["embedding"] = None
index["embed_model"] = get_embed_model()
index["embed_dim"] = None
logger.info("queen_memory_index: embedding model changed — cleared cached vectors")
# ---------------------------------------------------------------------------
# Embedding calls (Phase 2)
# ---------------------------------------------------------------------------
def _embed_kwargs() -> dict[str, Any]:
"""Build the kwargs dict for litellm.aembedding() from configuration."""
from framework.config import get_embed_api_base, get_embed_api_key
kwargs: dict[str, Any] = {}
api_key = get_embed_api_key()
if api_key:
kwargs["api_key"] = api_key
api_base = get_embed_api_base()
if api_base:
kwargs["api_base"] = api_base
return kwargs
async def embed_text(text: str) -> list[float] | None:
"""Embed *text* via litellm.aembedding().
Returns None (with a WARNING log) on any failure or when no embedding
model is configured.
"""
model = get_embed_model()
if not model:
return None
try:
import litellm # already a project dependency
logger.info("queen_memory_index: embedding text (%d chars) via %s", len(text), model)
resp = await litellm.aembedding(model=model, input=[text], **_embed_kwargs())
vec: list[float] = resp.data[0]["embedding"]
logger.info("queen_memory_index: embedding complete (dim=%d)", len(vec))
return vec
except Exception as exc:
logger.warning("queen_memory_index: embed_text failed (%s)", exc)
return None
async def embed_batch(texts: list[str]) -> list[list[float] | None]:
"""Embed a list of texts, returning a parallel list of vectors (or None)."""
model = get_embed_model()
if not model:
return [None] * len(texts)
try:
import litellm
logger.info(
"queen_memory_index: batch embedding %d text(s) via %s", len(texts), model
)
resp = await litellm.aembedding(model=model, input=texts, **_embed_kwargs())
vecs = [item["embedding"] for item in resp.data]
logger.info(
"queen_memory_index: batch embedding complete (dim=%d)", len(vecs[0]) if vecs else 0
)
return vecs
except Exception as exc:
logger.warning("queen_memory_index: embed_batch failed (%s), retrying individually", exc)
# Fall back to individual calls
results: list[list[float] | None] = []
for t in texts:
results.append(await embed_text(t))
return results
# ---------------------------------------------------------------------------
# Vector math (Phase 2)
# ---------------------------------------------------------------------------
def cosine_similarity(a: list[float] | None, b: list[float] | None) -> float:
"""Return cosine similarity in [0, 1]. Returns 0.0 on null or zero-norm inputs."""
if not a or not b:
return 0.0
try:
import numpy as np # already a project dependency
va = np.array(a, dtype=np.float32)
vb = np.array(b, dtype=np.float32)
norm_a = float(np.linalg.norm(va))
norm_b = float(np.linalg.norm(vb))
if norm_a == 0.0 or norm_b == 0.0:
return 0.0
return float(np.dot(va, vb) / (norm_a * norm_b))
except Exception:
return 0.0
def find_knn(
query_vec: list[float],
index: dict[str, Any],
k: int = 5,
exclude_id: str | None = None,
) -> list[tuple[str, float]]:
"""Return up to *k* nearest neighbours as (entry_id, similarity) pairs, descending."""
scores: list[tuple[str, float]] = []
for entry_id, raw in index.get("entries", {}).items():
if entry_id == exclude_id:
continue
vec = raw.get("embedding")
if not vec:
continue
sim = cosine_similarity(query_vec, vec)
scores.append((entry_id, sim))
scores.sort(key=lambda x: x[1], reverse=True)
return scores[:k]
# ---------------------------------------------------------------------------
# Semantic search (Phase 2)
# ---------------------------------------------------------------------------
async def semantic_search(
query: str,
index: dict[str, Any],
*,
k: int = 20,
date_range: tuple[str, str] | None = None,
) -> list[tuple[str, float]]:
"""Embed *query* and return top-k (entry_id, score) pairs.
Returns [] if embeddings are disabled or the embed call fails.
date_range is an inclusive (YYYY-MM-DD, YYYY-MM-DD) filter applied
before ranking.
"""
if not embeddings_enabled():
return []
query_vec = await embed_text(query)
if query_vec is None:
return []
candidates: list[tuple[str, float]] = []
for entry_id, raw in index.get("entries", {}).items():
if date_range:
d = raw.get("date", "")
if d < date_range[0] or d > date_range[1]:
continue
vec = raw.get("embedding")
if not vec:
continue
sim = cosine_similarity(query_vec, vec)
candidates.append((entry_id, sim))
candidates.sort(key=lambda x: x[1], reverse=True)
return candidates[:k]
# ---------------------------------------------------------------------------
# Importance tracking (Phase 4)
# ---------------------------------------------------------------------------
def importance_score(entry: MemoryEntry, now: datetime | None = None) -> float:
"""Composite importance: log1p(count) * recency decay (half-life 30 days).
Returns 0.0 for entries that have never been retrieved.
"""
if entry.retrieval_count == 0:
return 0.0
count_score = math.log1p(entry.retrieval_count)
if entry.last_retrieved:
try:
last = datetime.fromisoformat(entry.last_retrieved)
days_since = ((now or datetime.now()) - last).total_seconds() / 86400
decay = math.exp(-days_since / 30)
except ValueError:
decay = 0.0
else:
decay = 0.0
return count_score * decay
def record_retrieval(
index: dict[str, Any],
entry_ids: list[str],
*,
auto_save: bool = True,
) -> None:
"""Increment retrieval_count and update last_retrieved for each entry_id."""
now_str = datetime.now().isoformat()
entries = index.get("entries", {})
for eid in entry_ids:
if eid in entries:
entries[eid]["retrieval_count"] = entries[eid].get("retrieval_count", 0) + 1
entries[eid]["last_retrieved"] = now_str
if auto_save:
try:
save_index(index)
except Exception as exc:
logger.warning("queen_memory_index: failed to save index after retrieval: %s", exc)
# ---------------------------------------------------------------------------
# Hybrid re-ranking (Phase 4)
# ---------------------------------------------------------------------------
def hybrid_search(
query: str,
index: dict[str, Any],
candidate_ids: list[str],
semantic_scores: dict[str, float],
*,
keyword_weight: float = 0.3,
semantic_weight: float = 0.7,
) -> list[tuple[str, float]]:
"""Re-rank candidates combining semantic cosine, keyword overlap, and importance.
Combined score = semantic_weight * cosine
+ keyword_weight * keyword_overlap
+ 0.1 * normalised_importance
keyword_overlap = |query_terms entry.keywords| / max(1, |entry.keywords|)
normalised_importance is scaled to [0, 1] relative to the highest importance
in the candidate set.
"""
query_terms = set(re.findall(r"\w+", query.lower()))
now = datetime.now()
raw_scores: list[tuple[str, float]] = []
imp_values: list[float] = []
for eid in candidate_ids:
entry = get_entry(index, eid)
if entry is None:
continue
sem = semantic_scores.get(eid, 0.0)
kw_list = [k.lower() for k in entry.keywords]
overlap = len(query_terms & set(kw_list)) / max(1, len(kw_list))
imp = importance_score(entry, now)
imp_values.append(imp)
raw_scores.append((eid, sem, overlap, imp))
# Normalise importance to [0, 1]
max_imp = max(imp_values) if imp_values else 1.0
if max_imp == 0.0:
max_imp = 1.0
ranked: list[tuple[str, float]] = []
for eid, sem, overlap, imp in raw_scores:
score = (
semantic_weight * sem
+ keyword_weight * overlap
+ 0.1 * (imp / max_imp)
)
ranked.append((eid, score))
ranked.sort(key=lambda x: x[1], reverse=True)
return ranked
# ---------------------------------------------------------------------------
# Cross-reference linking (Phase 3)
# ---------------------------------------------------------------------------
def link_entry(
index: dict[str, Any],
entry_id: str,
similarity_threshold: float = 0.85,
) -> list[str]:
"""Discover k-NN above threshold and add bidirectional related[] links.
Mutates the index dict in place. Returns the list of newly linked
neighbour ids (may be empty).
"""
entries = index.get("entries", {})
raw = entries.get(entry_id)
if not raw or not raw.get("embedding"):
return []
neighbours = find_knn(raw["embedding"], index, k=10, exclude_id=entry_id)
linked: list[str] = []
for nid, sim in neighbours:
if sim < similarity_threshold:
break # sorted descending, so we can stop early
linked.append(nid)
# Update entry
if nid not in raw.setdefault("related", []):
raw["related"].append(nid)
# Update neighbour
neighbour = entries.get(nid)
if neighbour is not None and entry_id not in neighbour.setdefault("related", []):
neighbour["related"].append(entry_id)
return linked
# ---------------------------------------------------------------------------
# Prompt constants for LLM calls
# ---------------------------------------------------------------------------
_ENRICHMENT_SYSTEM = """\
Analyse the following diary entry from an AI assistant's episodic memory.
Extract structured metadata and return it as a JSON object with exactly these keys:
"keywords": list of 5-8 important terms (nouns, verbs, proper names)
"category": exactly one string from this list: agent_build, infrastructure,
user_preference, communication_style, diagnostic_learning, milestone,
pipeline, data_processing, other
"tags": list of 3-5 freeform topic labels (short phrases)
Return ONLY the JSON object. No explanation, no code fences.
"""
_EVOLUTION_SYSTEM = """\
You are refining the metadata of an older memory entry based on a newly discovered
related memory entry.
Given the TWO entries below, decide if the OLDER entry's tags or category should be
updated to better reflect the thematic connection.
Rules:
- Only suggest changes if the connection reveals a clearly missing tag or a category
correction. When in doubt, return {}.
- You may only modify "tags" and "category" never the prose, never keywords.
- Return a JSON object with only the keys you are changing: {"tags": [...], "category": "..."}
or {} if no change is warranted.
Return ONLY the JSON object. No explanation, no code fences.
"""
# ---------------------------------------------------------------------------
# Phase 1 — enrichment helpers
# ---------------------------------------------------------------------------
def _parse_diary_sections(content: str) -> list[tuple[str, str]]:
"""Return (timestamp, prose) pairs from a diary file's ### HH:MM blocks.
The date heading (# ...) is stripped. Non-timestamped content before the
first ### block is ignored.
"""
sections: list[tuple[str, str]] = []
# Split on ### HH:MM markers
parts = re.split(r"###\s*(\d{2}:\d{2})\b", content)
# parts = [pre_text, ts1, prose1, ts2, prose2, ...]
i = 1
while i + 1 < len(parts):
ts = parts[i].strip()
prose = parts[i + 1].strip()
if prose:
sections.append((ts, prose))
i += 2
return sections
def index_entry_from_diary_section(
date_str: str,
timestamp: str,
prose: str,
) -> MemoryEntry:
"""Construct a bare MemoryEntry (no enrichment, no embedding) from a diary section."""
entry_id = f"{date_str}:{timestamp}"
summary = prose[:300].replace("\n", " ")
return MemoryEntry(
id=entry_id,
date=date_str,
timestamp=timestamp,
summary=summary,
)
async def enrich_entry(
entry_text: str,
llm: object,
) -> tuple[list[str], str, list[str]]:
"""Call the consolidation LLM to extract keywords, category, and tags.
Returns ([], "other", []) on any failure so the caller can continue.
"""
try:
resp = await llm.acomplete(
messages=[{"role": "user", "content": entry_text}],
system=_ENRICHMENT_SYSTEM,
max_tokens=256,
json_mode=True,
)
data = json.loads(resp.content)
keywords = [str(k) for k in data.get("keywords", [])][:8]
raw_cat = str(data.get("category", "other"))
category = raw_cat if raw_cat in _CATEGORIES else "other"
tags = [str(t) for t in data.get("tags", [])][:5]
return keywords, category, tags
except Exception as exc:
logger.warning("queen_memory_index: enrich_entry failed (%s)", exc)
return [], "other", []
# ---------------------------------------------------------------------------
# Phase 5 — memory evolution
# ---------------------------------------------------------------------------
async def maybe_evolve_neighbors(
new_entry_id: str,
neighbor_ids: list[str],
index: dict[str, Any],
llm: object,
*,
max_neighbors_to_evolve: int = 2,
) -> None:
"""Potentially refine the tags/category of neighbour entries.
Only mutates metadata (tags, category) never prose, never embeddings.
Failures are logged and silently skipped.
"""
if not neighbor_ids:
return
new_raw = index.get("entries", {}).get(new_entry_id)
if not new_raw:
return
for nid in neighbor_ids[:max_neighbors_to_evolve]:
neighbor_raw = index.get("entries", {}).get(nid)
if not neighbor_raw:
continue
try:
prompt = (
f"NEWER entry ({new_entry_id}):\n"
f"Summary: {new_raw.get('summary', '')}\n"
f"Keywords: {', '.join(new_raw.get('keywords', []))}\n"
f"Tags: {', '.join(new_raw.get('tags', []))}\n\n"
f"OLDER entry ({nid}):\n"
f"Summary: {neighbor_raw.get('summary', '')}\n"
f"Keywords: {', '.join(neighbor_raw.get('keywords', []))}\n"
f"Tags: {', '.join(neighbor_raw.get('tags', []))}\n"
f"Category: {neighbor_raw.get('category', 'other')}"
)
resp = await llm.acomplete(
messages=[{"role": "user", "content": prompt}],
system=_EVOLUTION_SYSTEM,
max_tokens=128,
json_mode=True,
)
updates = json.loads(resp.content)
if not updates:
continue
if "tags" in updates and isinstance(updates["tags"], list):
neighbor_raw["tags"] = [str(t) for t in updates["tags"]][:5]
if "category" in updates:
raw_cat = str(updates["category"])
neighbor_raw["category"] = raw_cat if raw_cat in _CATEGORIES else "other"
logger.debug("queen_memory_index: evolved metadata for entry %s", nid)
except Exception as exc:
logger.warning("queen_memory_index: evolution failed for %s: %s", nid, exc)
# ---------------------------------------------------------------------------
# Index rebuild / backfill
# ---------------------------------------------------------------------------
async def rebuild_index_for_date(
date_str: str,
llm: object | None = None,
) -> int:
"""Parse today's diary file and index any sections not yet in the index.
Optionally enriches new entries via LLM if *llm* is provided.
Returns the count of new entries added.
"""
from framework.agents.queen.queen_memory import episodic_memory_path
from datetime import date as _date
try:
year, month, day = map(int, date_str.split("-"))
d = _date(year, month, day)
except ValueError:
logger.warning("queen_memory_index: invalid date_str %r", date_str)
return 0
ep_path = episodic_memory_path(d)
if not ep_path.exists():
return 0
content = ep_path.read_text(encoding="utf-8")
sections = _parse_diary_sections(content)
if not sections:
return 0
index = load_index()
# Detect embedding model change and clear stale vectors
if embeddings_enabled() and _detect_model_change(index):
_clear_embeddings(index)
added = 0
for ts, prose in sections:
entry_id = f"{date_str}:{ts}"
existing = get_entry(index, entry_id)
if existing is None:
entry = index_entry_from_diary_section(date_str, ts, prose)
elif existing.enriched:
# Already fully processed; update embedding only if missing
entry = existing
else:
entry = existing
# Enrich if LLM provided and not yet enriched
if llm is not None and not entry.enriched:
keywords, category, tags = await enrich_entry(prose, llm)
entry.keywords = keywords
entry.category = category
entry.tags = tags
entry.enriched = True
# Embed if model is configured and vector is missing
if embeddings_enabled() and entry.embedding is None:
vec = await embed_text(prose[:1500]) # cap input length
if vec is not None:
entry.embedding = vec
index["embed_model"] = get_embed_model()
index["embed_dim"] = len(vec)
put_entry(index, entry)
if existing is None:
added += 1
save_index(index)
logger.debug(
"queen_memory_index: indexed %d section(s) for %s, %d new", len(sections), date_str, added
)
return added
async def backfill_index(
llm: object | None = None,
embed: bool = True,
) -> dict[str, int]:
"""Walk all MEMORY-YYYY-MM-DD.md files and index unindexed entries.
This is a one-shot utility call it once after initial deployment to
catch up historical diary files. Not called automatically.
Usage:
uv run python -c "
import asyncio
from framework.agents.queen.queen_memory_index import backfill_index
print(asyncio.run(backfill_index()))
"
"""
memories_dir = _queen_memories_dir()
if not memories_dir.exists():
return {"dates_processed": 0, "entries_added": 0}
total_added = 0
dates_processed = 0
for md_file in sorted(memories_dir.glob("MEMORY-????-??-??.md")):
date_str = md_file.stem.removeprefix("MEMORY-")
if not re.fullmatch(r"\d{4}-\d{2}-\d{2}", date_str):
continue
added = await rebuild_index_for_date(date_str, llm=llm)
total_added += added
dates_processed += 1
logger.info(
"queen_memory_index: backfill complete — %d dates, %d entries added",
dates_processed,
total_added,
)
return {"dates_processed": dates_processed, "entries_added": total_added}
# ---------------------------------------------------------------------------
# Resolve full prose from diary file by entry_id
# ---------------------------------------------------------------------------
def resolve_prose(entry_id: str) -> str:
"""Read the source diary file and return the full prose for *entry_id*.
Returns the summary from the index as a fallback if the file section
cannot be found.
"""
from framework.agents.queen.queen_memory import episodic_memory_path
from datetime import date as _date
try:
date_str, ts = entry_id.split(":", 1)
year, month, day = map(int, date_str.split("-"))
d = _date(year, month, day)
except ValueError:
return ""
ep_path = episodic_memory_path(d)
if not ep_path.exists():
return ""
content = ep_path.read_text(encoding="utf-8")
sections = _parse_diary_sections(content)
for section_ts, prose in sections:
if section_ts == ts:
return prose
return ""
@@ -150,7 +150,7 @@ Call all three subagents in a single response to run them in parallel:
## GCU Anti-Patterns
- Using `browser_screenshot` to read text (use `browser_snapshot`)
- Using `browser_screenshot` to read text (use `browser_snapshot` instead; screenshots are for visual context only)
- Re-navigating after scrolling (resets scroll position)
- Attempting login on auth walls
- Forgetting `target_id` in multi-tab scenarios
@@ -0,0 +1,656 @@
"""Unit tests for queen_memory_index.py.
All tests run without HIVE_EMBED_MODEL set. Embedding behaviour is tested
via a lightweight mock that injects deterministic fixed vectors.
"""
from __future__ import annotations
import json
import math
from dataclasses import asdict
from datetime import datetime
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from framework.agents.queen.queen_memory_index import (
MemoryEntry,
_CATEGORIES,
_parse_diary_sections,
backfill_index,
cosine_similarity,
embed_text,
embeddings_enabled,
enrich_entry,
find_knn,
get_embed_model,
get_entry,
hybrid_search,
importance_score,
index_entry_from_diary_section,
index_path,
link_entry,
load_index,
maybe_evolve_neighbors,
put_entry,
rebuild_index_for_date,
record_retrieval,
resolve_prose,
save_index,
semantic_search,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_index(*entries: MemoryEntry) -> dict:
idx = {"version": 1, "embed_model": None, "embed_dim": None, "entries": {}}
for e in entries:
put_entry(idx, e)
return idx
def _entry(
date_str: str = "2026-03-01",
ts: str = "10:00",
summary: str = "test summary",
keywords: list[str] | None = None,
tags: list[str] | None = None,
category: str = "other",
embedding: list[float] | None = None,
retrieval_count: int = 0,
last_retrieved: str | None = None,
related: list[str] | None = None,
) -> MemoryEntry:
return MemoryEntry(
id=f"{date_str}:{ts}",
date=date_str,
timestamp=ts,
summary=summary,
keywords=keywords or [],
tags=tags or [],
category=category,
embedding=embedding,
retrieval_count=retrieval_count,
last_retrieved=last_retrieved,
related=related or [],
)
# ---------------------------------------------------------------------------
# cosine_similarity
# ---------------------------------------------------------------------------
class TestCosineSimilarity:
def test_identical_vectors(self):
v = [1.0, 0.0, 0.0]
assert cosine_similarity(v, v) == pytest.approx(1.0)
def test_orthogonal_vectors(self):
assert cosine_similarity([1.0, 0.0], [0.0, 1.0]) == pytest.approx(0.0)
def test_opposite_vectors(self):
# cosine of 180° = -1, but our vectors are floats so it can be -1
result = cosine_similarity([1.0, 0.0], [-1.0, 0.0])
assert result == pytest.approx(-1.0)
def test_none_inputs(self):
assert cosine_similarity(None, [1.0]) == 0.0
assert cosine_similarity([1.0], None) == 0.0
assert cosine_similarity(None, None) == 0.0
def test_zero_vector(self):
assert cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0
def test_known_similarity(self):
# [1, 1] vs [1, 0] → cos(45°) ≈ 0.707
result = cosine_similarity([1.0, 1.0], [1.0, 0.0])
assert result == pytest.approx(math.sqrt(2) / 2, abs=1e-4)
# ---------------------------------------------------------------------------
# find_knn
# ---------------------------------------------------------------------------
class TestFindKnn:
def test_returns_sorted_descending(self):
e1 = _entry("2026-03-01", "09:00", embedding=[1.0, 0.0])
e2 = _entry("2026-03-01", "10:00", embedding=[0.9, 0.1])
e3 = _entry("2026-03-01", "11:00", embedding=[0.0, 1.0])
idx = _make_index(e1, e2, e3)
results = find_knn([1.0, 0.0], idx, k=3)
ids = [r[0] for r in results]
scores = [r[1] for r in results]
assert ids[0] == "2026-03-01:09:00" # exact match
assert scores[0] == pytest.approx(1.0)
assert all(scores[i] >= scores[i + 1] for i in range(len(scores) - 1))
def test_excludes_self(self):
e1 = _entry("2026-03-01", "09:00", embedding=[1.0, 0.0])
idx = _make_index(e1)
results = find_knn([1.0, 0.0], idx, k=5, exclude_id="2026-03-01:09:00")
assert results == []
def test_skips_null_embeddings(self):
e1 = _entry("2026-03-01", "09:00", embedding=None)
e2 = _entry("2026-03-01", "10:00", embedding=[1.0, 0.0])
idx = _make_index(e1, e2)
results = find_knn([1.0, 0.0], idx, k=5)
ids = [r[0] for r in results]
assert "2026-03-01:09:00" not in ids
assert "2026-03-01:10:00" in ids
def test_respects_k(self):
entries = [_entry("2026-03-01", f"0{i}:00", embedding=[float(i), 0.0]) for i in range(5)]
idx = _make_index(*entries)
results = find_knn([1.0, 0.0], idx, k=2)
assert len(results) <= 2
# ---------------------------------------------------------------------------
# load_index / save_index (round-trip and atomic write)
# ---------------------------------------------------------------------------
class TestIndexIO:
def test_round_trip(self, tmp_path, monkeypatch):
monkeypatch.setattr(
"framework.agents.queen.queen_memory_index._queen_memories_dir",
lambda: tmp_path,
)
idx = _make_index(_entry())
idx["embed_model"] = "test-model"
save_index(idx)
loaded = load_index()
assert loaded["embed_model"] == "test-model"
assert "2026-03-01:10:00" in loaded["entries"]
def test_missing_file_returns_empty(self, tmp_path, monkeypatch):
monkeypatch.setattr(
"framework.agents.queen.queen_memory_index._queen_memories_dir",
lambda: tmp_path,
)
idx = load_index()
assert idx["entries"] == {}
assert idx["version"] == 1
def test_corrupt_file_returns_empty(self, tmp_path, monkeypatch):
monkeypatch.setattr(
"framework.agents.queen.queen_memory_index._queen_memories_dir",
lambda: tmp_path,
)
(tmp_path / "index.json").write_text("not json at all", encoding="utf-8")
idx = load_index()
assert idx["entries"] == {}
def test_atomic_write_uses_tmp_then_rename(self, tmp_path, monkeypatch):
monkeypatch.setattr(
"framework.agents.queen.queen_memory_index._queen_memories_dir",
lambda: tmp_path,
)
idx = _make_index()
save_index(idx)
# tmp file should be gone after rename
assert not (tmp_path / "index.json.tmp").exists()
assert (tmp_path / "index.json").exists()
# ---------------------------------------------------------------------------
# get_entry / put_entry
# ---------------------------------------------------------------------------
class TestGetPutEntry:
def test_put_and_get_roundtrip(self):
e = _entry(keywords=["foo", "bar"], tags=["t1"], category="milestone")
idx = _make_index()
put_entry(idx, e)
loaded = get_entry(idx, e.id)
assert loaded is not None
assert loaded.keywords == ["foo", "bar"]
assert loaded.category == "milestone"
def test_get_missing_returns_none(self):
idx = _make_index()
assert get_entry(idx, "no-such-id") is None
def test_put_overwrites_existing(self):
e = _entry(summary="original")
idx = _make_index(e)
e2 = _entry(summary="updated")
put_entry(idx, e2)
loaded = get_entry(idx, e.id)
assert loaded.summary == "updated"
# ---------------------------------------------------------------------------
# index_entry_from_diary_section
# ---------------------------------------------------------------------------
class TestIndexEntryFromDiarySection:
def test_id_format(self):
e = index_entry_from_diary_section("2026-03-01", "14:30", "Some prose here.")
assert e.id == "2026-03-01:14:30"
assert e.date == "2026-03-01"
assert e.timestamp == "14:30"
def test_summary_truncated_to_300(self):
prose = "x" * 500
e = index_entry_from_diary_section("2026-03-01", "14:30", prose)
assert len(e.summary) == 300
def test_defaults_empty_enrichment(self):
e = index_entry_from_diary_section("2026-03-01", "14:30", "text")
assert e.keywords == []
assert e.tags == []
assert e.category == "other"
assert e.embedding is None
assert not e.enriched
# ---------------------------------------------------------------------------
# _parse_diary_sections
# ---------------------------------------------------------------------------
class TestParseDiarySections:
def test_parses_two_sections(self):
content = "# March 1, 2026\n\n### 09:00\n\nFirst entry.\n\n### 14:30\n\nSecond entry."
sections = _parse_diary_sections(content)
assert len(sections) == 2
assert sections[0] == ("09:00", "First entry.")
assert sections[1] == ("14:30", "Second entry.")
def test_ignores_content_before_first_timestamp(self):
content = "# Heading\n\nIntro text.\n\n### 10:00\n\nEntry."
sections = _parse_diary_sections(content)
assert len(sections) == 1
assert sections[0][0] == "10:00"
def test_empty_content(self):
assert _parse_diary_sections("") == []
def test_no_timestamp_sections(self):
assert _parse_diary_sections("# Just a heading\n\nSome text.") == []
# ---------------------------------------------------------------------------
# record_retrieval
# ---------------------------------------------------------------------------
class TestRecordRetrieval:
def test_increments_count(self, tmp_path, monkeypatch):
monkeypatch.setattr(
"framework.agents.queen.queen_memory_index._queen_memories_dir",
lambda: tmp_path,
)
e = _entry(retrieval_count=2)
idx = _make_index(e)
record_retrieval(idx, [e.id], auto_save=False)
assert idx["entries"][e.id]["retrieval_count"] == 3
def test_sets_last_retrieved(self, tmp_path, monkeypatch):
monkeypatch.setattr(
"framework.agents.queen.queen_memory_index._queen_memories_dir",
lambda: tmp_path,
)
e = _entry()
idx = _make_index(e)
record_retrieval(idx, [e.id], auto_save=False)
assert idx["entries"][e.id]["last_retrieved"] is not None
def test_ignores_missing_ids(self, tmp_path, monkeypatch):
monkeypatch.setattr(
"framework.agents.queen.queen_memory_index._queen_memories_dir",
lambda: tmp_path,
)
idx = _make_index()
# Should not raise
record_retrieval(idx, ["nonexistent:00:00"], auto_save=False)
# ---------------------------------------------------------------------------
# importance_score
# ---------------------------------------------------------------------------
class TestImportanceScore:
def test_zero_for_never_retrieved(self):
e = _entry(retrieval_count=0)
assert importance_score(e) == 0.0
def test_positive_for_retrieved_recently(self):
now = datetime.now()
e = _entry(retrieval_count=5, last_retrieved=now.isoformat())
score = importance_score(e, now=now)
assert score > 0.0
def test_decays_over_time(self):
from datetime import timedelta
now = datetime.now()
recent = _entry("2026-03-01", "10:00", retrieval_count=5,
last_retrieved=now.isoformat())
old = _entry("2026-03-01", "11:00", retrieval_count=5,
last_retrieved=(now - timedelta(days=60)).isoformat())
assert importance_score(recent, now=now) > importance_score(old, now=now)
def test_higher_count_higher_score(self):
now = datetime.now()
low = _entry("2026-03-01", "10:00", retrieval_count=1,
last_retrieved=now.isoformat())
high = _entry("2026-03-01", "11:00", retrieval_count=10,
last_retrieved=now.isoformat())
assert importance_score(high, now=now) > importance_score(low, now=now)
# ---------------------------------------------------------------------------
# link_entry (Phase 3)
# ---------------------------------------------------------------------------
class TestLinkEntry:
def test_links_above_threshold(self):
# Two nearly identical vectors should be linked
e1 = _entry("2026-03-01", "09:00", embedding=[1.0, 0.0, 0.0])
e2 = _entry("2026-03-01", "10:00", embedding=[0.99, 0.01, 0.0])
idx = _make_index(e1, e2)
linked = link_entry(idx, e1.id, similarity_threshold=0.90)
assert e2.id in linked
def test_bidirectional_links(self):
e1 = _entry("2026-03-01", "09:00", embedding=[1.0, 0.0])
e2 = _entry("2026-03-01", "10:00", embedding=[1.0, 0.0])
idx = _make_index(e1, e2)
link_entry(idx, e1.id, similarity_threshold=0.90)
assert e2.id in idx["entries"][e1.id]["related"]
assert e1.id in idx["entries"][e2.id]["related"]
def test_does_not_link_below_threshold(self):
e1 = _entry("2026-03-01", "09:00", embedding=[1.0, 0.0])
e2 = _entry("2026-03-01", "10:00", embedding=[0.0, 1.0])
idx = _make_index(e1, e2)
linked = link_entry(idx, e1.id, similarity_threshold=0.90)
assert linked == []
def test_skips_entry_without_embedding(self):
e1 = _entry("2026-03-01", "09:00", embedding=None)
idx = _make_index(e1)
linked = link_entry(idx, e1.id)
assert linked == []
# ---------------------------------------------------------------------------
# hybrid_search (Phase 4)
# ---------------------------------------------------------------------------
class TestHybridSearch:
def test_semantic_score_dominates(self):
e_high = _entry("2026-03-01", "09:00", keywords=["unrelated"])
e_low = _entry("2026-03-01", "10:00", keywords=["pipeline", "agent"])
idx = _make_index(e_high, e_low)
sem_scores = {e_high.id: 0.95, e_low.id: 0.40}
ranked = hybrid_search("pipeline", idx, [e_high.id, e_low.id], sem_scores)
# e_high has much higher semantic score, should still rank first
assert ranked[0][0] == e_high.id
def test_keyword_overlap_breaks_tie(self):
e_kw = _entry("2026-03-01", "09:00", keywords=["pipeline", "agent", "workflow"])
e_no_kw = _entry("2026-03-01", "10:00", keywords=["unrelated", "other"])
idx = _make_index(e_kw, e_no_kw)
# Equal semantic scores
sem_scores = {e_kw.id: 0.80, e_no_kw.id: 0.80}
ranked = hybrid_search("pipeline agent", idx, [e_kw.id, e_no_kw.id], sem_scores)
assert ranked[0][0] == e_kw.id
def test_returns_sorted_descending(self):
entries = [_entry("2026-03-01", f"0{i}:00") for i in range(3)]
idx = _make_index(*entries)
sem_scores = {e.id: float(i) / 10 for i, e in enumerate(entries)}
ids = [e.id for e in entries]
ranked = hybrid_search("query", idx, ids, sem_scores)
scores = [s for _, s in ranked]
assert all(scores[i] >= scores[i + 1] for i in range(len(scores) - 1))
# ---------------------------------------------------------------------------
# embeddings_enabled / get_embed_model
# ---------------------------------------------------------------------------
class TestEmbeddingsEnabled:
def test_disabled_when_env_unset(self, monkeypatch):
monkeypatch.delenv("HIVE_EMBED_MODEL", raising=False)
assert not embeddings_enabled()
assert get_embed_model() is None
def test_enabled_when_env_set(self, monkeypatch):
monkeypatch.setenv("HIVE_EMBED_MODEL", "text-embedding-3-small")
assert embeddings_enabled()
assert get_embed_model() == "text-embedding-3-small"
# ---------------------------------------------------------------------------
# embed_text — mocked
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestEmbedText:
async def test_returns_none_when_disabled(self, monkeypatch):
monkeypatch.delenv("HIVE_EMBED_MODEL", raising=False)
result = await embed_text("hello")
assert result is None
async def test_returns_vector_when_enabled(self, monkeypatch):
monkeypatch.setenv("HIVE_EMBED_MODEL", "text-embedding-3-small")
fake_vec = [0.1, 0.2, 0.3]
mock_resp = MagicMock()
mock_resp.data = [{"embedding": fake_vec}]
with patch("litellm.aembedding", new=AsyncMock(return_value=mock_resp)):
result = await embed_text("hello world")
assert result == fake_vec
async def test_returns_none_on_api_failure(self, monkeypatch):
monkeypatch.setenv("HIVE_EMBED_MODEL", "text-embedding-3-small")
with patch("litellm.aembedding", new=AsyncMock(side_effect=RuntimeError("API down"))):
result = await embed_text("hello")
assert result is None
# ---------------------------------------------------------------------------
# semantic_search — mocked embeddings
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestSemanticSearch:
async def test_returns_empty_when_disabled(self, monkeypatch):
monkeypatch.delenv("HIVE_EMBED_MODEL", raising=False)
idx = _make_index(_entry(embedding=[1.0, 0.0]))
results = await semantic_search("query", idx)
assert results == []
async def test_finds_nearest_neighbours(self, monkeypatch):
monkeypatch.setenv("HIVE_EMBED_MODEL", "text-embedding-3-small")
e1 = _entry("2026-03-01", "09:00", embedding=[1.0, 0.0])
e2 = _entry("2026-03-01", "10:00", embedding=[0.0, 1.0])
idx = _make_index(e1, e2)
query_vec = [1.0, 0.0]
mock_resp = MagicMock()
mock_resp.data = [{"embedding": query_vec}]
with patch("litellm.aembedding", new=AsyncMock(return_value=mock_resp)):
results = await semantic_search("query", idx, k=2)
assert results[0][0] == e1.id # closest to [1.0, 0.0]
async def test_date_range_filter(self, monkeypatch):
monkeypatch.setenv("HIVE_EMBED_MODEL", "text-embedding-3-small")
e_in = _entry("2026-03-15", "09:00", embedding=[1.0, 0.0])
e_out = _entry("2026-02-01", "09:00", embedding=[1.0, 0.0])
idx = _make_index(e_in, e_out)
mock_resp = MagicMock()
mock_resp.data = [{"embedding": [1.0, 0.0]}]
with patch("litellm.aembedding", new=AsyncMock(return_value=mock_resp)):
results = await semantic_search(
"query", idx, k=10, date_range=("2026-03-01", "2026-03-31")
)
ids = [r[0] for r in results]
assert e_in.id in ids
assert e_out.id not in ids
# ---------------------------------------------------------------------------
# enrich_entry — mocked LLM
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestEnrichEntry:
async def test_parses_llm_response(self):
mock_llm = MagicMock()
mock_resp = MagicMock()
mock_resp.content = json.dumps(
{"keywords": ["pipeline", "agent"], "category": "pipeline", "tags": ["build", "test"]}
)
mock_llm.acomplete = AsyncMock(return_value=mock_resp)
kw, cat, tags = await enrich_entry("Some diary text", mock_llm)
assert "pipeline" in kw
assert cat == "pipeline"
assert "build" in tags
async def test_rejects_invalid_category(self):
mock_llm = MagicMock()
mock_resp = MagicMock()
mock_resp.content = json.dumps(
{"keywords": [], "category": "invented_category", "tags": []}
)
mock_llm.acomplete = AsyncMock(return_value=mock_resp)
_, cat, _ = await enrich_entry("text", mock_llm)
assert cat == "other"
async def test_returns_defaults_on_failure(self):
mock_llm = MagicMock()
mock_llm.acomplete = AsyncMock(side_effect=RuntimeError("LLM down"))
kw, cat, tags = await enrich_entry("text", mock_llm)
assert kw == []
assert cat == "other"
assert tags == []
# ---------------------------------------------------------------------------
# maybe_evolve_neighbors — mocked LLM
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestMaybeEvolveNeighbors:
async def test_updates_tags_on_non_empty_response(self):
mock_llm = MagicMock()
mock_resp = MagicMock()
mock_resp.content = json.dumps({"tags": ["new_tag", "updated"]})
mock_llm.acomplete = AsyncMock(return_value=mock_resp)
new_e = _entry("2026-03-01", "10:00", keywords=["new"], tags=["tag_a"])
old_e = _entry("2026-03-01", "09:00", keywords=["old"], tags=["old_tag"])
idx = _make_index(new_e, old_e)
await maybe_evolve_neighbors(new_e.id, [old_e.id], idx, mock_llm)
assert "new_tag" in idx["entries"][old_e.id]["tags"]
async def test_no_op_on_empty_response(self):
mock_llm = MagicMock()
mock_resp = MagicMock()
mock_resp.content = json.dumps({})
mock_llm.acomplete = AsyncMock(return_value=mock_resp)
new_e = _entry("2026-03-01", "10:00")
old_e = _entry("2026-03-01", "09:00", tags=["original"])
idx = _make_index(new_e, old_e)
await maybe_evolve_neighbors(new_e.id, [old_e.id], idx, mock_llm)
# Tags unchanged
assert idx["entries"][old_e.id]["tags"] == ["original"]
async def test_silently_handles_llm_failure(self):
mock_llm = MagicMock()
mock_llm.acomplete = AsyncMock(side_effect=RuntimeError("down"))
new_e = _entry("2026-03-01", "10:00")
old_e = _entry("2026-03-01", "09:00")
idx = _make_index(new_e, old_e)
# Must not raise
await maybe_evolve_neighbors(new_e.id, [old_e.id], idx, mock_llm)
async def test_respects_max_neighbors_cap(self):
mock_llm = MagicMock()
mock_resp = MagicMock()
mock_resp.content = json.dumps({})
mock_llm.acomplete = AsyncMock(return_value=mock_resp)
new_e = _entry("2026-03-01", "10:00")
neighbors = [_entry("2026-03-01", f"0{i}:00") for i in range(5)]
idx = _make_index(new_e, *neighbors)
await maybe_evolve_neighbors(
new_e.id, [n.id for n in neighbors], idx, mock_llm, max_neighbors_to_evolve=2
)
assert mock_llm.acomplete.call_count == 2
# ---------------------------------------------------------------------------
# recall_diary — semantic path and fallback (integration-style)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestRecallDiary:
async def test_substring_fallback_when_embeddings_disabled(
self, tmp_path, monkeypatch
):
"""When HIVE_EMBED_MODEL is not set, recall_diary uses substring matching."""
monkeypatch.delenv("HIVE_EMBED_MODEL", raising=False)
# Write a fake diary file
memories_dir = tmp_path / ".hive" / "queen" / "memories"
memories_dir.mkdir(parents=True)
today_str = "2026-03-24"
(memories_dir / f"MEMORY-{today_str}.md").write_text(
"# March 24, 2026\n\n### 09:00\n\nWorked on the pipeline agent today.\n",
encoding="utf-8",
)
# Patch the path functions
import framework.agents.queen.queen_memory as qm
monkeypatch.setattr(qm, "episodic_memory_path", lambda d=None: memories_dir / f"MEMORY-{today_str}.md")
from framework.tools.queen_memory_tools import recall_diary
result = await recall_diary(query="pipeline", days_back=1)
assert "pipeline agent" in result
async def test_no_results_message(self, monkeypatch):
"""Returns a helpful message when nothing matches."""
monkeypatch.delenv("HIVE_EMBED_MODEL", raising=False)
import framework.agents.queen.queen_memory as qm
# Point to a non-existent path
monkeypatch.setattr(
qm, "episodic_memory_path", lambda d=None: Path("/nonexistent/MEMORY.md")
)
from framework.tools.queen_memory_tools import recall_diary
result = await recall_diary(query="nonexistent topic", days_back=1)
assert "No diary entries" in result
+10
View File
@@ -89,6 +89,16 @@ def main():
register_testing_commands(subparsers)
# Register skill commands (skill list, skill trust, ...)
from framework.skills.cli import register_skill_commands
register_skill_commands(subparsers)
# Register debugger commands (debugger)
from framework.debugger.cli import register_debugger_commands
register_debugger_commands(subparsers)
args = parser.parse_args()
if hasattr(args, "func"):
+281 -1
View File
@@ -61,6 +61,150 @@ def get_preferred_model() -> str:
return "anthropic/claude-sonnet-4-20250514"
def get_preferred_worker_model() -> str | None:
"""Return the user's preferred worker LLM model, or None if not configured.
Reads from the ``worker_llm`` section of ~/.hive/configuration.json.
Returns None when no worker-specific model is set, so callers can
fall back to the default (queen) model via ``get_preferred_model()``.
"""
worker_llm = get_hive_config().get("worker_llm", {})
if worker_llm.get("provider") and worker_llm.get("model"):
provider = str(worker_llm["provider"])
model = str(worker_llm["model"]).strip()
if provider.lower() == "openrouter" and model.lower().startswith("openrouter/"):
model = model[len("openrouter/") :]
if model:
return f"{provider}/{model}"
return None
def get_worker_api_key() -> str | None:
"""Return the API key for the worker LLM, falling back to the default key."""
worker_llm = get_hive_config().get("worker_llm", {})
if not worker_llm:
return get_api_key()
# Worker-specific subscription / env var
if worker_llm.get("use_claude_code_subscription"):
try:
from framework.runner.runner import get_claude_code_token
token = get_claude_code_token()
if token:
return token
except ImportError:
pass
if worker_llm.get("use_codex_subscription"):
try:
from framework.runner.runner import get_codex_token
token = get_codex_token()
if token:
return token
except ImportError:
pass
if worker_llm.get("use_kimi_code_subscription"):
try:
from framework.runner.runner import get_kimi_code_token
token = get_kimi_code_token()
if token:
return token
except ImportError:
pass
if worker_llm.get("use_antigravity_subscription"):
try:
from framework.runner.runner import get_antigravity_token
token = get_antigravity_token()
if token:
return token
except ImportError:
pass
api_key_env_var = worker_llm.get("api_key_env_var")
if api_key_env_var:
return os.environ.get(api_key_env_var)
# Fall back to default key
return get_api_key()
def get_worker_api_base() -> str | None:
"""Return the api_base for the worker LLM, falling back to the default."""
worker_llm = get_hive_config().get("worker_llm", {})
if not worker_llm:
return get_api_base()
if worker_llm.get("use_codex_subscription"):
return "https://chatgpt.com/backend-api/codex"
if worker_llm.get("use_kimi_code_subscription"):
return "https://api.kimi.com/coding"
if worker_llm.get("use_antigravity_subscription"):
# Antigravity uses AntigravityProvider directly — no api_base needed.
return None
if worker_llm.get("api_base"):
return worker_llm["api_base"]
if str(worker_llm.get("provider", "")).lower() == "openrouter":
return OPENROUTER_API_BASE
return None
def get_worker_llm_extra_kwargs() -> dict[str, Any]:
"""Return extra kwargs for the worker LLM provider."""
worker_llm = get_hive_config().get("worker_llm", {})
if not worker_llm:
return get_llm_extra_kwargs()
if worker_llm.get("use_claude_code_subscription"):
api_key = get_worker_api_key()
if api_key:
return {
"extra_headers": {"authorization": f"Bearer {api_key}"},
}
if worker_llm.get("use_codex_subscription"):
api_key = get_worker_api_key()
if api_key:
headers: dict[str, str] = {
"Authorization": f"Bearer {api_key}",
"User-Agent": "CodexBar",
}
try:
from framework.runner.runner import get_codex_account_id
account_id = get_codex_account_id()
if account_id:
headers["ChatGPT-Account-Id"] = account_id
except ImportError:
pass
return {
"extra_headers": headers,
"store": False,
"allowed_openai_params": ["store"],
}
return {}
def get_worker_max_tokens() -> int:
"""Return max_tokens for the worker LLM, falling back to default."""
worker_llm = get_hive_config().get("worker_llm", {})
if worker_llm and "max_tokens" in worker_llm:
return worker_llm["max_tokens"]
return get_max_tokens()
def get_worker_max_context_tokens() -> int:
"""Return max_context_tokens for the worker LLM, falling back to default."""
worker_llm = get_hive_config().get("worker_llm", {})
if worker_llm and "max_context_tokens" in worker_llm:
return worker_llm["max_context_tokens"]
return get_max_context_tokens()
def get_max_tokens() -> int:
"""Return the configured max_tokens, falling back to DEFAULT_MAX_TOKENS."""
return get_hive_config().get("llm", {}).get("max_tokens", DEFAULT_MAX_TOKENS)
@@ -120,6 +264,17 @@ def get_api_key() -> str | None:
except ImportError:
pass
# Antigravity subscription: read OAuth token from accounts JSON
if llm.get("use_antigravity_subscription"):
try:
from framework.runner.runner import get_antigravity_token
token = get_antigravity_token()
if token:
return token
except ImportError:
pass
# Standard env-var path (covers ZAI Code and all API-key providers)
api_key_env_var = llm.get("api_key_env_var")
if api_key_env_var:
@@ -127,6 +282,128 @@ def get_api_key() -> str | None:
return None
# OAuth credentials for Antigravity are fetched from the opencode-antigravity-auth project.
# This project reverse-engineered and published the public OAuth credentials
# for Google's Antigravity/Cloud Code Assist API.
# Source: https://github.com/NoeFabris/opencode-antigravity-auth
_ANTIGRAVITY_CREDENTIALS_URL = (
"https://raw.githubusercontent.com/NoeFabris/opencode-antigravity-auth/dev/src/constants.ts"
)
_antigravity_credentials_cache: tuple[str | None, str | None] = (None, None)
def _fetch_antigravity_credentials() -> tuple[str | None, str | None]:
"""Fetch OAuth client ID and secret from the public npm package source on GitHub."""
global _antigravity_credentials_cache
if _antigravity_credentials_cache[0] and _antigravity_credentials_cache[1]:
return _antigravity_credentials_cache
import re
import urllib.request
try:
req = urllib.request.Request(
_ANTIGRAVITY_CREDENTIALS_URL, headers={"User-Agent": "Hive/1.0"}
)
with urllib.request.urlopen(req, timeout=10) as resp:
content = resp.read().decode("utf-8")
id_match = re.search(r'ANTIGRAVITY_CLIENT_ID\s*=\s*"([^"]+)"', content)
secret_match = re.search(r'ANTIGRAVITY_CLIENT_SECRET\s*=\s*"([^"]+)"', content)
client_id = id_match.group(1) if id_match else None
client_secret = secret_match.group(1) if secret_match else None
if client_id and client_secret:
_antigravity_credentials_cache = (client_id, client_secret)
return client_id, client_secret
except Exception as e:
logger.debug("Failed to fetch Antigravity credentials from public source: %s", e)
return None, None
def get_antigravity_client_id() -> str:
"""Return the Antigravity OAuth application client ID.
Checked in order:
1. ``ANTIGRAVITY_CLIENT_ID`` environment variable
2. ``llm.antigravity_client_id`` in ~/.hive/configuration.json
3. Fetch from public source (opencode-antigravity-auth project on GitHub)
"""
env = os.environ.get("ANTIGRAVITY_CLIENT_ID")
if env:
return env
cfg_val = get_hive_config().get("llm", {}).get("antigravity_client_id")
if cfg_val:
return cfg_val
# Fetch from public source
client_id, _ = _fetch_antigravity_credentials()
if client_id:
return client_id
raise RuntimeError("Could not obtain Antigravity OAuth client ID")
def get_antigravity_client_secret() -> str | None:
"""Return the Antigravity OAuth client secret.
Checked in order:
1. ``ANTIGRAVITY_CLIENT_SECRET`` environment variable
2. ``llm.antigravity_client_secret`` in ~/.hive/configuration.json
3. Fetch from public source (opencode-antigravity-auth project on GitHub)
Returns None when not found token refresh will be skipped and
the caller must use whatever access token is already available.
"""
env = os.environ.get("ANTIGRAVITY_CLIENT_SECRET")
if env:
return env
cfg_val = get_hive_config().get("llm", {}).get("antigravity_client_secret") or None
if cfg_val:
return cfg_val
# Fetch from public source
_, secret = _fetch_antigravity_credentials()
return secret
def get_embed_model() -> str | None:
"""Return the configured embedding model string, or None if not set.
Reads from the ``embedding`` section of ~/.hive/configuration.json:
{
"embedding": {
"provider": "openai",
"model": "text-embedding-3-small",
"api_key_env_var": "OPENAI_API_KEY"
}
}
Returns a litellm-compatible ``"provider/model"`` string, e.g.
``"openai/text-embedding-3-small"``.
Falls back to the ``HIVE_EMBED_MODEL`` environment variable for
backward compatibility.
"""
embed = get_hive_config().get("embedding", {})
if embed.get("provider") and embed.get("model"):
provider = str(embed["provider"]).strip()
model = str(embed["model"]).strip()
if provider and model:
return f"{provider}/{model}"
return os.environ.get("HIVE_EMBED_MODEL") or None
def get_embed_api_key() -> str | None:
"""Return the API key for the embedding provider, or None if not set."""
embed = get_hive_config().get("embedding", {})
api_key_env_var = embed.get("api_key_env_var")
if api_key_env_var:
return os.environ.get(api_key_env_var)
return None
def get_embed_api_base() -> str | None:
"""Return a custom api_base for the embedding provider, or None."""
embed = get_hive_config().get("embedding", {})
return embed.get("api_base") or None
def get_gcu_enabled() -> bool:
"""Return whether GCU (browser automation) is enabled in user config."""
return get_hive_config().get("gcu_enabled", True)
@@ -149,6 +426,9 @@ def get_api_base() -> str | None:
if llm.get("use_kimi_code_subscription"):
# Kimi Code uses an Anthropic-compatible endpoint (no /v1 suffix).
return "https://api.kimi.com/coding"
if llm.get("use_antigravity_subscription"):
# Antigravity uses AntigravityProvider directly — no api_base needed.
return None
if llm.get("api_base"):
return llm["api_base"]
if str(llm.get("provider", "")).lower() == "openrouter":
@@ -198,7 +478,7 @@ def get_llm_extra_kwargs() -> dict[str, Any]:
# ---------------------------------------------------------------------------
# RuntimeConfig shared across agent templates
# RuntimeConfig - shared across agent templates
# ---------------------------------------------------------------------------
+26 -3
View File
@@ -27,6 +27,7 @@ from __future__ import annotations
import getpass
import json
import logging
import os
import sys
from collections.abc import Callable
@@ -37,6 +38,8 @@ from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from framework.graph import NodeSpec
logger = logging.getLogger(__name__)
# ANSI colors for terminal output
class Colors:
@@ -365,8 +368,11 @@ class CredentialSetupSession:
self._print("")
try:
api_key = self.password_fn(f"Paste your {cred.env_var}: ").strip()
except (EOFError, OSError) as exc:
logger.debug("Password input unavailable, falling back to plain input: %s", exc)
api_key = self._input(f"Paste your {cred.env_var}: ").strip()
except Exception:
# Fallback to regular input if password input fails
logger.warning("Unexpected error reading password input", exc_info=True)
api_key = self._input(f"Paste your {cred.env_var}: ").strip()
if not api_key:
@@ -403,7 +409,11 @@ class CredentialSetupSession:
try:
aden_key = self.password_fn("Paste your ADEN_API_KEY: ").strip()
except (EOFError, OSError) as exc:
logger.debug("Password input unavailable for ADEN_API_KEY: %s", exc)
aden_key = self._input("Paste your ADEN_API_KEY: ").strip()
except Exception:
logger.warning("Unexpected error reading ADEN_API_KEY input", exc_info=True)
aden_key = self._input("Paste your ADEN_API_KEY: ").strip()
if not aden_key:
@@ -433,8 +443,10 @@ class CredentialSetupSession:
value = store.get_key(cred_id, cred.credential_key)
if value:
os.environ[cred.env_var] = value
except (KeyError, OSError) as exc:
logger.debug("Could not export credential to env: %s", exc)
except Exception:
pass
logger.warning("Unexpected error exporting credential to env", exc_info=True)
return True
else:
self._print(
@@ -457,9 +469,12 @@ class CredentialSetupSession:
"message": result.message,
"details": result.details,
}
except Exception:
except ImportError:
# No health checker available
return None
except Exception:
logger.warning("Health check failed for %s", cred.credential_name, exc_info=True)
return None
def _store_credential(self, cred: MissingCredential, value: str) -> None:
"""Store credential in encrypted store and export to env."""
@@ -561,7 +576,11 @@ def _load_nodes_from_python_agent(agent_path: Path) -> list:
sys.modules[spec.name] = module
spec.loader.exec_module(module)
return getattr(module, "nodes", [])
except (ImportError, OSError) as exc:
logger.debug("Could not load agent module: %s", exc)
return []
except Exception:
logger.warning("Unexpected error loading agent module", exc_info=True)
return []
@@ -588,7 +607,11 @@ def _load_nodes_from_json_agent(agent_json: Path) -> list:
)
)
return nodes
except (json.JSONDecodeError, KeyError, OSError) as exc:
logger.debug("Could not load JSON agent: %s", exc)
return []
except Exception:
logger.warning("Unexpected error loading JSON agent", exc_info=True)
return []
View File
+76
View File
@@ -0,0 +1,76 @@
"""CLI command for the LLM debug log viewer."""
import argparse
import subprocess
import sys
from pathlib import Path
_SCRIPT = Path(__file__).resolve().parents[3] / "scripts" / "llm_debug_log_visualizer.py"
def register_debugger_commands(subparsers: argparse._SubParsersAction) -> None:
"""Register the ``hive debugger`` command."""
parser = subparsers.add_parser(
"debugger",
help="Open the LLM debug log viewer",
description=(
"Start a local server that lets you browse LLM debug sessions "
"recorded in ~/.hive/llm_logs. Sessions are loaded on demand so "
"the browser stays responsive."
),
)
parser.add_argument(
"--session",
help="Execution ID to select initially.",
)
parser.add_argument(
"--port",
type=int,
default=0,
help="Port for the local server (0 = auto-pick a free port).",
)
parser.add_argument(
"--logs-dir",
help="Directory containing JSONL log files (default: ~/.hive/llm_logs).",
)
parser.add_argument(
"--limit-files",
type=int,
default=None,
help="Maximum number of newest log files to scan (default: 200).",
)
parser.add_argument(
"--output",
help="Write a static HTML file instead of starting a server.",
)
parser.add_argument(
"--no-open",
action="store_true",
help="Start the server but do not open a browser.",
)
parser.add_argument(
"--include-tests",
action="store_true",
help="Show test/mock sessions (hidden by default).",
)
parser.set_defaults(func=cmd_debugger)
def cmd_debugger(args: argparse.Namespace) -> int:
"""Launch the LLM debug log visualizer."""
cmd: list[str] = [sys.executable, str(_SCRIPT)]
if args.session:
cmd += ["--session", args.session]
if args.port:
cmd += ["--port", str(args.port)]
if args.logs_dir:
cmd += ["--logs-dir", args.logs_dir]
if args.limit_files is not None:
cmd += ["--limit-files", str(args.limit_files)]
if args.output:
cmd += ["--output", args.output]
if args.no_open:
cmd.append("--no-open")
if args.include_tests:
cmd.append("--include-tests")
return subprocess.call(cmd)
+30
View File
@@ -33,10 +33,20 @@ class Message:
is_transition_marker: bool = False
# True when this message is real human input (from /chat), not a system prompt
is_client_input: bool = False
# Optional image content blocks (e.g. from browser_screenshot)
image_content: list[dict[str, Any]] | None = None
# True when message contains an activated skill body (AS-10: never prune)
is_skill_content: bool = False
def to_llm_dict(self) -> dict[str, Any]:
"""Convert to OpenAI-format message dict."""
if self.role == "user":
if self.image_content:
blocks: list[dict[str, Any]] = []
if self.content:
blocks.append({"type": "text", "text": self.content})
blocks.extend(self.image_content)
return {"role": "user", "content": blocks}
return {"role": "user", "content": self.content}
if self.role == "assistant":
@@ -47,6 +57,15 @@ class Message:
# role == "tool"
content = f"ERROR: {self.content}" if self.is_error else self.content
if self.image_content:
# Multimodal tool result: text + image content blocks
blocks: list[dict[str, Any]] = [{"type": "text", "text": content}]
blocks.extend(self.image_content)
return {
"role": "tool",
"tool_call_id": self.tool_use_id,
"content": blocks,
}
return {
"role": "tool",
"tool_call_id": self.tool_use_id,
@@ -72,6 +91,8 @@ class Message:
d["is_transition_marker"] = self.is_transition_marker
if self.is_client_input:
d["is_client_input"] = self.is_client_input
if self.image_content is not None:
d["image_content"] = self.image_content
return d
@classmethod
@@ -87,6 +108,7 @@ class Message:
phase_id=data.get("phase_id"),
is_transition_marker=data.get("is_transition_marker", False),
is_client_input=data.get("is_client_input", False),
image_content=data.get("image_content"),
)
@@ -373,6 +395,7 @@ class NodeConversation:
*,
is_transition_marker: bool = False,
is_client_input: bool = False,
image_content: list[dict[str, Any]] | None = None,
) -> Message:
msg = Message(
seq=self._next_seq,
@@ -381,6 +404,7 @@ class NodeConversation:
phase_id=self._current_phase,
is_transition_marker=is_transition_marker,
is_client_input=is_client_input,
image_content=image_content,
)
self._messages.append(msg)
self._next_seq += 1
@@ -409,6 +433,8 @@ class NodeConversation:
tool_use_id: str,
content: str,
is_error: bool = False,
image_content: list[dict[str, Any]] | None = None,
is_skill_content: bool = False,
) -> Message:
msg = Message(
seq=self._next_seq,
@@ -417,6 +443,8 @@ class NodeConversation:
tool_use_id=tool_use_id,
is_error=is_error,
phase_id=self._current_phase,
image_content=image_content,
is_skill_content=is_skill_content,
)
self._messages.append(msg)
self._next_seq += 1
@@ -610,6 +638,8 @@ class NodeConversation:
continue
if msg.is_error:
continue # never prune errors
if msg.is_skill_content:
continue # never prune activated skill instructions (AS-10)
if msg.content.startswith("[Pruned tool result"):
continue # already pruned
# Tiny results (set_output acks, confirmations) — pruning
+403 -32
View File
@@ -14,6 +14,7 @@ from __future__ import annotations
import asyncio
import json
import logging
import os
import re
import time
from collections.abc import Awaitable, Callable
@@ -24,6 +25,7 @@ from typing import Any, Literal, Protocol, runtime_checkable
from framework.graph.conversation import ConversationStore, NodeConversation
from framework.graph.node import NodeContext, NodeProtocol, NodeResult
from framework.llm.capabilities import supports_image_tool_results
from framework.llm.provider import Tool, ToolResult, ToolUse
from framework.llm.stream_events import (
FinishEvent,
@@ -37,6 +39,56 @@ from framework.runtime.llm_debug_logger import log_llm_turn
logger = logging.getLogger(__name__)
async def _describe_images_as_text(image_content: list[dict[str, Any]]) -> str | None:
"""Describe images using the best available vision model.
Called when the queen's model lacks vision support. Tries vision-capable
models in priority order based on available API keys and returns a bracketed
description to inject into the message text, or None if no vision model is
reachable.
"""
import litellm
# Build content blocks: prompt + all images
blocks: list[dict[str, Any]] = [
{
"type": "text",
"text": (
"Describe the following image(s) concisely but with enough detail "
"that a text-only AI assistant can understand the content and context."
),
}
]
blocks.extend(image_content)
# Ordered candidates based on available env vars
candidates: list[str] = []
if os.environ.get("OPENAI_API_KEY"):
candidates.append("gpt-4o-mini")
if os.environ.get("ANTHROPIC_API_KEY"):
candidates.append("claude-3-haiku-20240307")
if os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY"):
candidates.append("gemini/gemini-1.5-flash")
for model in candidates:
try:
response = await litellm.acompletion(
model=model,
messages=[{"role": "user", "content": blocks}],
max_tokens=512,
)
description = (response.choices[0].message.content or "").strip()
if description:
count = len(image_content)
label = "image" if count == 1 else f"{count} images"
return f"[{label} attached — description: {description}]"
except Exception as exc:
logger.debug("Vision fallback model '%s' failed: %s", model, exc)
continue
return None
@dataclass
class TriggerEvent:
"""A framework-level trigger signal (timer tick or webhook hit).
@@ -90,7 +142,13 @@ class _EscalationReceiver:
self._response: str | None = None
self._awaiting_input = True # So inject_worker_message() can prefer us
async def inject_event(self, content: str, *, is_client_input: bool = False) -> None:
async def inject_event(
self,
content: str,
*,
is_client_input: bool = False,
image_content: list[dict] | None = None,
) -> None:
"""Called by ExecutionStream.inject_input() when the user responds."""
self._response = content
self._event.set()
@@ -426,7 +484,9 @@ class EventLoopNode(NodeProtocol):
self._config = config or LoopConfig()
self._tool_executor = tool_executor
self._conversation_store = conversation_store
self._injection_queue: asyncio.Queue[tuple[str, bool]] = asyncio.Queue()
self._injection_queue: asyncio.Queue[tuple[str, bool, list[dict[str, Any]] | None]] = (
asyncio.Queue()
)
self._trigger_queue: asyncio.Queue[TriggerEvent] = asyncio.Queue()
# Client-facing input blocking state
self._input_ready = asyncio.Event()
@@ -467,6 +527,8 @@ class EventLoopNode(NodeProtocol):
stream_id = ctx.stream_id or ctx.node_id
node_id = ctx.node_id
execution_id = ctx.execution_id or ""
# Store skill dirs for AS-9 file-read interception in _execute_tool
self._skill_dirs: list[str] = ctx.skill_dirs
# Verdict counters for runtime logging
_accept_count = _retry_count = _escalate_count = _continue_count = 0
@@ -531,12 +593,28 @@ class EventLoopNode(NodeProtocol):
_restored_recent_responses = restored.recent_responses
_restored_tool_fingerprints = restored.recent_tool_fingerprints
# Refresh the system prompt with full 3-layer composition.
# The stored prompt may be stale after code changes or when
# runtime-injected context (e.g. worker identity) has changed.
# On resume, we rebuild identity + narrative + focus so the LLM
# understands the session history, not just the node directive.
from framework.graph.prompt_composer import compose_system_prompt
# Refresh the system prompt with full composition including
# execution preamble and node-type preamble. The stored
# prompt may be stale after code changes or when runtime-
# injected context (e.g. worker identity) has changed.
from framework.graph.prompt_composer import (
EXECUTION_SCOPE_PREAMBLE,
compose_system_prompt,
)
_exec_preamble = None
if (
not ctx.is_subagent_mode
and ctx.node_spec.node_type in ("event_loop", "gcu")
and ctx.node_spec.output_keys
):
_exec_preamble = EXECUTION_SCOPE_PREAMBLE
_node_type_preamble = None
if ctx.node_spec.node_type == "gcu":
from framework.graph.gcu import GCU_BROWSER_SYSTEM_PROMPT
_node_type_preamble = GCU_BROWSER_SYSTEM_PROMPT
_current_prompt = compose_system_prompt(
identity_prompt=ctx.identity_prompt or None,
@@ -545,6 +623,8 @@ class EventLoopNode(NodeProtocol):
accounts_prompt=ctx.accounts_prompt or None,
skills_catalog_prompt=ctx.skills_catalog_prompt or None,
protocols_prompt=ctx.protocols_prompt or None,
execution_preamble=_exec_preamble,
node_type_preamble=_node_type_preamble,
)
if conversation.system_prompt != _current_prompt:
conversation.update_system_prompt(_current_prompt)
@@ -764,7 +844,7 @@ class EventLoopNode(NodeProtocol):
)
# 6b. Drain injection queue
await self._drain_injection_queue(conversation)
await self._drain_injection_queue(conversation, ctx)
# 6b1. Drain trigger queue (framework-level signals)
await self._drain_trigger_queue(conversation)
@@ -806,6 +886,13 @@ class EventLoopNode(NodeProtocol):
execution_id,
extra_data=_iter_meta,
)
# Sync max_context_tokens from live config so mid-session model
# switches are reflected in compaction decisions and the UI bar.
from framework.config import get_max_context_tokens as _live_mct
conversation._max_context_tokens = _live_mct()
await self._publish_context_usage(ctx, conversation, "iteration_start")
# 6d. Pre-turn compaction check (tiered)
_compacted_this_iter = False
@@ -1883,7 +1970,13 @@ class EventLoopNode(NodeProtocol):
conversation=conversation if _is_continuous else None,
)
async def inject_event(self, content: str, *, is_client_input: bool = False) -> None:
async def inject_event(
self,
content: str,
*,
is_client_input: bool = False,
image_content: list[dict[str, Any]] | None = None,
) -> None:
"""Inject an external event or user input into the running loop.
The content becomes a user message prepended to the next iteration.
@@ -1899,8 +1992,10 @@ class EventLoopNode(NodeProtocol):
human user (e.g. /chat endpoint), False for external events
(e.g. worker question forwarded by the frontend). Controls
message formatting in _drain_injection_queue, not wake behavior.
image_content: Optional list of image content blocks (OpenAI
image_url format) to include alongside the text.
"""
await self._injection_queue.put((content, is_client_input))
await self._injection_queue.put((content, is_client_input, image_content))
self._input_ready.set()
async def inject_trigger(self, trigger: TriggerEvent) -> None:
@@ -2074,6 +2169,24 @@ class EventLoopNode(NodeProtocol):
messages = conversation.to_llm_messages()
# Debug: log whether the last user message contains image blocks
for _m in reversed(messages):
if _m.get("role") == "user":
_content = _m.get("content")
if isinstance(_content, list):
_img_count = sum(
1
for _b in _content
if isinstance(_b, dict) and _b.get("type") == "image_url"
)
if _img_count:
logger.info(
"[%s] LLM call: last user message has %d image block(s)",
node_id,
_img_count,
)
break
# Defensive guard: ensure messages don't end with an assistant
# message. The Anthropic API rejects "assistant message prefill"
# (conversations must end with a user or tool message). This can
@@ -2477,6 +2590,27 @@ class EventLoopNode(NodeProtocol):
results_by_id[tc.tool_use_id] = result
elif tc.tool_name == "delegate_to_sub_agent":
# Guard: in continuous mode the LLM may see delegate
# calls from a previous node's conversation history and
# attempt to re-use the tool on a node that doesn't own
# it. Only accept if the tool was actually offered.
if not any(t.name == "delegate_to_sub_agent" for t in tools):
logger.warning(
"[%s] LLM called delegate_to_sub_agent but tool "
"was not offered to this node — rejecting",
node_id,
)
result = ToolResult(
tool_use_id=tc.tool_use_id,
content=(
"ERROR: delegate_to_sub_agent is not available "
"on this node. This tool belongs to a different "
"node in the workflow."
),
is_error=True,
)
results_by_id[tc.tool_use_id] = result
continue
# --- Framework-level subagent delegation ---
# Queue for parallel execution in Phase 2
logger.info(
@@ -2722,10 +2856,22 @@ class EventLoopNode(NodeProtocol):
real_tool_results.append(tool_entry)
logged_tool_calls.append(tool_entry)
# Strip image content for models that can't handle it
image_content = result.image_content
if image_content and ctx.llm and not supports_image_tool_results(ctx.llm.model):
logger.info(
"Stripping image_content from tool result — model '%s' "
"does not support images in tool results",
ctx.llm.model,
)
image_content = None
await conversation.add_tool_result(
tool_use_id=tc.tool_use_id,
content=result.content,
is_error=result.is_error,
image_content=image_content,
is_skill_content=result.is_skill_content,
)
if (
tc.tool_name in ("ask_user", "ask_user_multiple")
@@ -2834,6 +2980,8 @@ class EventLoopNode(NodeProtocol):
conversation.usage_ratio() * 100,
)
await self._publish_context_usage(ctx, conversation, "post_tool_results")
# If the turn requested external input (ask_user or queen handoff),
# return immediately so the outer loop can block before judge eval.
if user_input_requested or queen_input_requested:
@@ -3549,6 +3697,33 @@ class EventLoopNode(NodeProtocol):
content=f"No tool executor configured for '{tc.tool_name}'",
is_error=True,
)
# AS-9: Intercept file-read tools for skill directories — bypass session sandbox
_SKILL_READ_TOOLS = {"view_file", "load_data", "read_file"}
skill_dirs = getattr(self, "_skill_dirs", [])
if tc.tool_name in _SKILL_READ_TOOLS and skill_dirs:
_path = tc.tool_input.get("path", "")
if _path:
import os
from pathlib import Path as _Path
_resolved = os.path.realpath(os.path.abspath(_path))
if any(_resolved.startswith(os.path.realpath(d)) for d in skill_dirs):
try:
_content = _Path(_resolved).read_text(encoding="utf-8")
_is_skill_md = _resolved.endswith("SKILL.md")
return ToolResult(
tool_use_id=tc.tool_use_id,
content=_content,
is_skill_content=_is_skill_md, # AS-10: protect SKILL.md reads
)
except Exception as _exc:
return ToolResult(
tool_use_id=tc.tool_use_id,
content=f"Could not read skill resource '{_path}': {_exc}",
is_error=True,
)
tool_use = ToolUse(id=tc.tool_use_id, name=tc.tool_name, input=tc.tool_input)
timeout = self._config.tool_call_timeout_seconds
@@ -3836,6 +4011,7 @@ class EventLoopNode(NodeProtocol):
tool_use_id=result.tool_use_id,
content=truncated,
is_error=False,
image_content=result.image_content,
)
spill_dir = self._config.spillover_dir
@@ -3908,6 +4084,7 @@ class EventLoopNode(NodeProtocol):
tool_use_id=result.tool_use_id,
content=content,
is_error=False,
image_content=result.image_content,
)
# No spillover_dir — truncate in-place if needed
@@ -3950,6 +4127,7 @@ class EventLoopNode(NodeProtocol):
tool_use_id=result.tool_use_id,
content=truncated,
is_error=False,
image_content=result.image_content,
)
return result
@@ -3980,6 +4158,12 @@ class EventLoopNode(NodeProtocol):
ratio_before = conversation.usage_ratio()
phase_grad = getattr(ctx, "continuous_mode", False)
# Capture pre-compaction message inventory when over budget,
# since compaction mutates the conversation in place.
pre_inventory: list[dict[str, Any]] | None = None
if ratio_before >= 1.0:
pre_inventory = self._build_message_inventory(conversation)
# --- Step 1: Prune old tool results (free, no LLM) ---
protect = max(2000, self._config.max_context_tokens // 12)
pruned = await conversation.prune_old_tool_results(
@@ -3994,7 +4178,7 @@ class EventLoopNode(NodeProtocol):
conversation.usage_ratio() * 100,
)
if not conversation.needs_compaction():
await self._log_compaction(ctx, conversation, ratio_before)
await self._log_compaction(ctx, conversation, ratio_before, pre_inventory)
return
# --- Step 2: Standard structure-preserving compaction (free, no LLM) ---
@@ -4007,7 +4191,7 @@ class EventLoopNode(NodeProtocol):
phase_graduated=phase_grad,
)
if not conversation.needs_compaction():
await self._log_compaction(ctx, conversation, ratio_before)
await self._log_compaction(ctx, conversation, ratio_before, pre_inventory)
return
# --- Step 3: LLM summary compaction ---
@@ -4034,7 +4218,7 @@ class EventLoopNode(NodeProtocol):
logger.warning("LLM compaction failed: %s", e)
if not conversation.needs_compaction():
await self._log_compaction(ctx, conversation, ratio_before)
await self._log_compaction(ctx, conversation, ratio_before, pre_inventory)
return
# --- Step 4: Emergency deterministic summary (LLM failed/unavailable) ---
@@ -4048,7 +4232,7 @@ class EventLoopNode(NodeProtocol):
keep_recent=1,
phase_graduated=phase_grad,
)
await self._log_compaction(ctx, conversation, ratio_before)
await self._log_compaction(ctx, conversation, ratio_before, pre_inventory)
# --- LLM compaction with binary-search splitting ----------------------
@@ -4210,13 +4394,59 @@ class EventLoopNode(NodeProtocol):
"re-doing work.\n"
)
@staticmethod
def _build_message_inventory(
conversation: NodeConversation,
) -> list[dict[str, Any]]:
"""Build a per-message size inventory for debug logging."""
inventory: list[dict[str, Any]] = []
for m in conversation.messages:
content_chars = len(m.content)
tc_chars = 0
tool_name = None
if m.tool_calls:
for tc in m.tool_calls:
args = tc.get("function", {}).get("arguments", "")
tc_chars += len(args) if isinstance(args, str) else len(json.dumps(args))
names = [tc.get("function", {}).get("name", "?") for tc in m.tool_calls]
tool_name = ", ".join(names)
elif m.role == "tool" and m.tool_use_id:
for prev in conversation.messages:
if prev.tool_calls:
for tc in prev.tool_calls:
if tc.get("id") == m.tool_use_id:
tool_name = tc.get("function", {}).get("name", "?")
break
if tool_name:
break
entry: dict[str, Any] = {
"seq": m.seq,
"role": m.role,
"content_chars": content_chars,
}
if tc_chars:
entry["tool_call_args_chars"] = tc_chars
if tool_name:
entry["tool"] = tool_name
if m.is_error:
entry["is_error"] = True
if m.phase_id:
entry["phase"] = m.phase_id
if content_chars > 2000:
entry["preview"] = m.content[:200] + ""
inventory.append(entry)
return inventory
async def _log_compaction(
self,
ctx: NodeContext,
conversation: NodeConversation,
ratio_before: float,
pre_inventory: list[dict[str, Any]] | None = None,
) -> None:
"""Log compaction result to runtime logger and event bus."""
"""Log compaction result to runtime logger, event bus, and debug file."""
import os as _os
ratio_after = conversation.usage_ratio()
before_pct = round(ratio_before * 100)
after_pct = round(ratio_after * 100)
@@ -4249,19 +4479,103 @@ class EventLoopNode(NodeProtocol):
if self._event_bus:
from framework.runtime.event_bus import AgentEvent, EventType
event_data: dict[str, Any] = {
"level": level,
"usage_before": before_pct,
"usage_after": after_pct,
}
if pre_inventory is not None:
event_data["message_inventory"] = pre_inventory
await self._event_bus.publish(
AgentEvent(
type=EventType.CONTEXT_COMPACTED,
stream_id=ctx.stream_id or ctx.node_id,
node_id=ctx.node_id,
data={
"level": level,
"usage_before": before_pct,
"usage_after": after_pct,
},
data=event_data,
)
)
# Emit post-compaction usage update
await self._publish_context_usage(ctx, conversation, "post_compaction")
# Write detailed debug log to ~/.hive/compaction_log/ when enabled
if _os.environ.get("HIVE_COMPACTION_DEBUG"):
self._write_compaction_debug_log(ctx, before_pct, after_pct, level, pre_inventory)
@staticmethod
def _write_compaction_debug_log(
ctx: NodeContext,
before_pct: int,
after_pct: int,
level: str,
inventory: list[dict[str, Any]] | None,
) -> None:
"""Write detailed compaction analysis to ~/.hive/compaction_log/."""
log_dir = Path.home() / ".hive" / "compaction_log"
log_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now(UTC).strftime("%Y%m%dT%H%M%S_%f")
node_label = ctx.node_id.replace("/", "_")
log_path = log_dir / f"{ts}_{node_label}.md"
lines: list[str] = [
f"# Compaction Debug — {ctx.node_id}",
f"**Time:** {datetime.now(UTC).isoformat()}",
f"**Node:** {ctx.node_spec.name} (`{ctx.node_id}`)",
]
if ctx.stream_id:
lines.append(f"**Stream:** {ctx.stream_id}")
lines.append(f"**Level:** {level}")
lines.append(f"**Usage:** {before_pct}% → {after_pct}%")
lines.append("")
if inventory:
total_chars = sum(
e.get("content_chars", 0) + e.get("tool_call_args_chars", 0) for e in inventory
)
lines.append(
f"## Pre-Compaction Message Inventory "
f"({len(inventory)} messages, {total_chars:,} total chars)"
)
lines.append("")
ranked = sorted(
inventory,
key=lambda e: e.get("content_chars", 0) + e.get("tool_call_args_chars", 0),
reverse=True,
)
lines.append("| # | seq | role | tool | chars | % of total | flags |")
lines.append("|---|-----|------|------|------:|------------|-------|")
for i, entry in enumerate(ranked, 1):
chars = entry.get("content_chars", 0) + entry.get("tool_call_args_chars", 0)
pct = (chars / total_chars * 100) if total_chars else 0
tool = entry.get("tool", "")
flags = []
if entry.get("is_error"):
flags.append("error")
if entry.get("phase"):
flags.append(f"phase={entry['phase']}")
lines.append(
f"| {i} | {entry['seq']} | {entry['role']} | {tool} "
f"| {chars:,} | {pct:.1f}% | {', '.join(flags)} |"
)
large = [e for e in ranked if e.get("preview")]
if large:
lines.append("")
lines.append("### Large message previews")
for entry in large:
lines.append(
f"\n**seq={entry['seq']}** ({entry['role']}, {entry.get('tool', '')}):"
)
lines.append(f"```\n{entry['preview']}\n```")
lines.append("")
try:
log_path.write_text("\n".join(lines), encoding="utf-8")
logger.debug("Compaction debug log written to %s", log_path)
except OSError:
logger.debug("Failed to write compaction debug log to %s", log_path)
def _build_emergency_summary(
self,
ctx: NodeContext,
@@ -4484,20 +4798,37 @@ class EventLoopNode(NodeProtocol):
]
await self._conversation_store.write_cursor(cursor)
async def _drain_injection_queue(self, conversation: NodeConversation) -> int:
async def _drain_injection_queue(self, conversation: NodeConversation, ctx: NodeContext) -> int:
"""Drain all pending injected events as user messages. Returns count."""
count = 0
while not self._injection_queue.empty():
try:
content, is_client_input = self._injection_queue.get_nowait()
content, is_client_input, image_content = self._injection_queue.get_nowait()
logger.info(
"[drain] injected message (client_input=%s): %s",
"[drain] injected message (client_input=%s, images=%d): %s",
is_client_input,
len(image_content) if image_content else 0,
content[:200] if content else "(empty)",
)
# For models that don't support images, fall back to a text description
if image_content and ctx.llm:
if not supports_image_tool_results(ctx.llm.model):
logger.info(
"Model '%s' does not support images — attempting vision fallback",
ctx.llm.model,
)
description = await _describe_images_as_text(image_content)
if description:
content = f"{content}\n\n{description}" if content else description
logger.info("[drain] image described as text via vision fallback")
else:
logger.info("[drain] no vision fallback available — images dropped")
image_content = None
# Real user input is stored as-is; external events get a prefix
if is_client_input:
await conversation.add_user_message(content, is_client_input=True)
await conversation.add_user_message(
content, is_client_input=True, image_content=image_content
)
else:
await conversation.add_user_message(f"[External event]: {content}")
count += 1
@@ -4666,6 +4997,36 @@ class EventLoopNode(NodeProtocol):
if result.inject:
await conversation.add_user_message(result.inject)
async def _publish_context_usage(
self,
ctx: NodeContext,
conversation: NodeConversation,
trigger: str,
) -> None:
"""Emit a CONTEXT_USAGE_UPDATED event with current context window state."""
if not self._event_bus:
return
from framework.runtime.event_bus import AgentEvent, EventType
estimated = conversation.estimate_tokens()
max_tokens = conversation._max_context_tokens
ratio = estimated / max_tokens if max_tokens > 0 else 0.0
await self._event_bus.publish(
AgentEvent(
type=EventType.CONTEXT_USAGE_UPDATED,
stream_id=ctx.stream_id or ctx.node_id,
node_id=ctx.node_id,
data={
"usage_ratio": round(ratio, 4),
"usage_pct": round(ratio * 100),
"message_count": conversation.message_count,
"estimated_tokens": estimated,
"max_context_tokens": max_tokens,
"trigger": trigger,
},
)
)
async def _publish_iteration(
self,
stream_id: str,
@@ -4950,7 +5311,20 @@ class EventLoopNode(NodeProtocol):
write_keys=[], # Read-only!
)
# 2b. Set up report callback (one-way channel to parent / event bus)
# 2b. Compute instance counter early so node_id is available for the
# report callback and the NodeContext. Each delegation to the same
# agent_id gets a unique suffix (instance 1 has no suffix for backward
# compat; instance 2+ appends ":N").
self._subagent_instance_counter.setdefault(agent_id, 0)
self._subagent_instance_counter[agent_id] += 1
_sa_instance = self._subagent_instance_counter[agent_id]
if _sa_instance > 1:
sa_node_id = f"{ctx.node_id}:subagent:{agent_id}:{_sa_instance}"
else:
sa_node_id = f"{ctx.node_id}:subagent:{agent_id}"
subagent_instance = str(_sa_instance)
# 2c. Set up report callback (one-way channel to parent / event bus)
subagent_reports: list[dict] = []
async def _report_callback(
@@ -4963,7 +5337,7 @@ class EventLoopNode(NodeProtocol):
if self._event_bus:
await self._event_bus.emit_subagent_report(
stream_id=ctx.node_id,
node_id=f"{ctx.node_id}:subagent:{agent_id}",
node_id=sa_node_id,
subagent_id=agent_id,
message=message,
data=data,
@@ -5053,7 +5427,7 @@ class EventLoopNode(NodeProtocol):
max_iter = min(self._config.max_iterations, 10)
subagent_ctx = NodeContext(
runtime=ctx.runtime,
node_id=f"{ctx.node_id}:subagent:{agent_id}",
node_id=sa_node_id,
node_spec=subagent_spec,
memory=scoped_memory,
input_data={"task": task, **parent_data},
@@ -5081,10 +5455,7 @@ class EventLoopNode(NodeProtocol):
# Derive a conversation store for the subagent from the parent's store.
# Each invocation gets a unique path so that repeated delegate calls
# (e.g. one per profile) don't restore a stale completed conversation.
self._subagent_instance_counter.setdefault(agent_id, 0)
self._subagent_instance_counter[agent_id] += 1
subagent_instance = str(self._subagent_instance_counter[agent_id])
# (Instance counter was computed earlier in step 2b.)
subagent_conv_store = None
if self._conversation_store is not None:
from framework.storage.conversation_store import FileConversationStore
+7
View File
@@ -154,6 +154,7 @@ class GraphExecutor:
iteration_metadata_provider: Callable | None = None,
skills_catalog_prompt: str = "",
protocols_prompt: str = "",
skill_dirs: list[str] | None = None,
):
"""
Initialize the executor.
@@ -181,6 +182,7 @@ class GraphExecutor:
system prompt (for phase switching)
skills_catalog_prompt: Available skills catalog for system prompt
protocols_prompt: Default skill operational protocols for system prompt
skill_dirs: Skill base directories for Tier 3 resource access
"""
self.runtime = runtime
self.llm = llm
@@ -204,6 +206,7 @@ class GraphExecutor:
self.iteration_metadata_provider = iteration_metadata_provider
self.skills_catalog_prompt = skills_catalog_prompt
self.protocols_prompt = protocols_prompt
self.skill_dirs: list[str] = skill_dirs or []
if protocols_prompt:
self.logger.info(
@@ -1845,6 +1848,9 @@ class GraphExecutor:
existing_underscore = [k for k in memory._data if k.startswith("_")]
extra_keys = set(_skill_keys) | set(existing_underscore)
# Only inject into read_keys when it was already non-empty — an empty
# read_keys means "allow all reads" and injecting skill keys would
# inadvertently restrict reads to skill keys only.
for k in extra_keys:
if read_keys and k not in read_keys:
read_keys.append(k)
@@ -1899,6 +1905,7 @@ class GraphExecutor:
iteration_metadata_provider=self.iteration_metadata_provider,
skills_catalog_prompt=self.skills_catalog_prompt,
protocols_prompt=self.protocols_prompt,
skill_dirs=self.skill_dirs,
)
VALID_NODE_TYPES = {
+5 -2
View File
@@ -43,8 +43,11 @@ Follow these rules for reliable, efficient browser interaction.
`browser_snapshot` separately after every action.
Only call `browser_snapshot` when you need a fresh view without
performing an action, or after setting `auto_snapshot=false`.
- Do NOT use `browser_screenshot` for reading text content
it produces huge base64 images with no searchable text.
- Do NOT use `browser_screenshot` to read text use
`browser_snapshot` for that (compact, searchable, fast).
- DO use `browser_screenshot` when you need visual context:
charts, images, canvas elements, layout verification, or when
the snapshot doesn't capture what you need.
- Only fall back to `browser_get_text` for extracting specific
small elements by CSS selector.
-8
View File
@@ -167,14 +167,6 @@ class Goal(BaseModel):
return met_weight >= total_weight * 0.9 # 90% threshold
def check_constraint(self, constraint_id: str, value: Any) -> bool:
"""Check if a specific constraint is satisfied."""
for c in self.constraints:
if c.id == constraint_id:
# This would be expanded with actual evaluation logic
return True
return True
def to_prompt_context(self) -> str:
"""Generate context string for LLM prompts.
+1
View File
@@ -568,6 +568,7 @@ class NodeContext:
# Skill system prompts — injected by the skill discovery pipeline
skills_catalog_prompt: str = "" # Available skills XML catalog
protocols_prompt: str = "" # Default skill operational protocols
skill_dirs: list[str] = field(default_factory=list) # Skill base dirs for resource access
# Per-iteration metadata provider — when set, EventLoopNode merges
# the returned dict into node_loop_iteration event data. Used by
+15
View File
@@ -152,6 +152,8 @@ def compose_system_prompt(
accounts_prompt: str | None = None,
skills_catalog_prompt: str | None = None,
protocols_prompt: str | None = None,
execution_preamble: str | None = None,
node_type_preamble: str | None = None,
) -> str:
"""Compose the multi-layer system prompt.
@@ -162,6 +164,10 @@ def compose_system_prompt(
accounts_prompt: Connected accounts block (sits between identity and narrative).
skills_catalog_prompt: Available skills catalog XML (Agent Skills standard).
protocols_prompt: Default skill operational protocols section.
execution_preamble: EXECUTION_SCOPE_PREAMBLE for worker nodes
(prepended before focus so the LLM knows its pipeline scope).
node_type_preamble: Node-type-specific preamble, e.g. GCU browser
best-practices prompt (prepended before focus).
Returns:
Composed system prompt with all layers present, plus current datetime.
@@ -188,6 +194,15 @@ def compose_system_prompt(
if narrative:
parts.append(f"\n--- Context (what has happened so far) ---\n{narrative}")
# Execution scope preamble (worker nodes — tells the LLM it is one
# step in a multi-node pipeline and should not overreach)
if execution_preamble:
parts.append(f"\n{execution_preamble}")
# Node-type preamble (e.g. GCU browser best-practices)
if node_type_preamble:
parts.append(f"\n{node_type_preamble}")
# Layer 3: Focus (current phase directive)
if focus_prompt:
parts.append(f"\n--- Current Focus ---\n{focus_prompt}")
+706
View File
@@ -0,0 +1,706 @@
"""Antigravity (Google internal Cloud Code Assist) LLM provider.
Antigravity is Google's unified gateway API that routes requests to Gemini,
Claude, and GPT-OSS models through a single Gemini-style interface. It is
NOT the public ``generativelanguage.googleapis.com`` API.
Authentication uses Google OAuth2. Token refresh is done directly with the
OAuth client secret no local proxy required.
Credential sources (checked in order):
1. ``~/.hive/antigravity-accounts.json`` (native OAuth implementation)
2. Antigravity IDE SQLite state DB (macOS / Linux)
"""
from __future__ import annotations
import json
import logging
import re
import time
import uuid
from collections.abc import AsyncIterator, Callable, Iterator
from pathlib import Path
from typing import Any
from framework.llm.provider import LLMProvider, LLMResponse, Tool
from framework.llm.stream_events import (
FinishEvent,
StreamErrorEvent,
StreamEvent,
TextDeltaEvent,
TextEndEvent,
ToolCallEvent,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_TOKEN_URL = "https://oauth2.googleapis.com/token"
# Fallback order: daily sandbox → autopush sandbox → production
_ENDPOINTS = [
"https://daily-cloudcode-pa.sandbox.googleapis.com",
"https://autopush-cloudcode-pa.sandbox.googleapis.com",
"https://cloudcode-pa.googleapis.com",
]
_DEFAULT_PROJECT_ID = "rising-fact-p41fc"
_TOKEN_REFRESH_BUFFER_SECS = 60
# Credentials file in ~/.hive/ (native implementation)
_ACCOUNTS_FILE = Path.home() / ".hive" / "antigravity-accounts.json"
_IDE_STATE_DB_MAC = (
Path.home()
/ "Library"
/ "Application Support"
/ "Antigravity"
/ "User"
/ "globalStorage"
/ "state.vscdb"
)
_IDE_STATE_DB_LINUX = (
Path.home() / ".config" / "Antigravity" / "User" / "globalStorage" / "state.vscdb"
)
_IDE_STATE_DB_KEY = "antigravityUnifiedStateSync.oauthToken"
_BASE_HEADERS: dict[str, str] = {
# Mimic the Antigravity Electron app so the API accepts the request.
"User-Agent": (
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 "
"(KHTML, like Gecko) Antigravity/1.18.3 Chrome/138.0.7204.235 "
"Electron/37.3.1 Safari/537.36"
),
"X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
"Client-Metadata": '{"ideType":"ANTIGRAVITY","platform":"MACOS","pluginType":"GEMINI"}',
}
# ---------------------------------------------------------------------------
# Credential loading helpers
# ---------------------------------------------------------------------------
def _load_from_json_file() -> tuple[str | None, str | None, str, float]:
"""Read credentials from JSON accounts file.
Reads from ~/.hive/antigravity-accounts.json.
Returns ``(access_token | None, refresh_token | None, project_id, expires_at)``.
``expires_at`` is a Unix timestamp (seconds); 0.0 means unknown.
"""
if not _ACCOUNTS_FILE.exists():
return None, None, _DEFAULT_PROJECT_ID, 0.0
try:
with open(_ACCOUNTS_FILE, encoding="utf-8") as fh:
data = json.load(fh)
except (OSError, json.JSONDecodeError) as exc:
logger.debug("Failed to read Antigravity accounts file: %s", exc)
return None, None, _DEFAULT_PROJECT_ID, 0.0
accounts = data.get("accounts", [])
if not accounts:
return None, None, _DEFAULT_PROJECT_ID, 0.0
account = next((a for a in accounts if a.get("enabled", True) is not False), accounts[0])
schema_version = data.get("schemaVersion", 1)
if schema_version >= 4:
# V4 schema: refresh = "refreshToken|projectId[|managedProjectId]"
refresh_str = account.get("refresh", "")
parts = refresh_str.split("|") if refresh_str else []
refresh_token: str | None = parts[0] if parts else None
project_id = parts[1] if len(parts) >= 2 and parts[1] else _DEFAULT_PROJECT_ID
access_token: str | None = account.get("access")
expires_ms: int = account.get("expires", 0)
expires_at = float(expires_ms) / 1000.0 if expires_ms else 0.0
# Treat near-expiry tokens as absent so _ensure_token() triggers a refresh.
if access_token and expires_at and time.time() >= expires_at - _TOKEN_REFRESH_BUFFER_SECS:
access_token = None
expires_at = 0.0
return access_token, refresh_token, project_id, expires_at
else:
# V1V3 schema: plain accessToken / refreshToken fields
access_token = account.get("accessToken")
refresh_token = account.get("refreshToken")
# Estimate expiry from last_refresh + 1 h
last_refresh_str: str | None = data.get("last_refresh")
expires_at = 0.0
if last_refresh_str:
try:
from datetime import datetime # noqa: PLC0415
ts = datetime.fromisoformat(last_refresh_str.replace("Z", "+00:00")).timestamp()
expires_at = ts + 3600.0
if time.time() >= expires_at - _TOKEN_REFRESH_BUFFER_SECS:
access_token = None
except (ValueError, TypeError):
pass
return access_token, refresh_token, _DEFAULT_PROJECT_ID, expires_at
def _load_from_ide_db() -> tuple[str | None, str | None, float]:
"""Extract ``(access_token, refresh_token, expires_at)`` from the IDE SQLite DB."""
import base64 # noqa: PLC0415
import sqlite3 # noqa: PLC0415
for db_path in (_IDE_STATE_DB_MAC, _IDE_STATE_DB_LINUX):
if not db_path.exists():
continue
try:
con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
try:
row = con.execute(
"SELECT value FROM ItemTable WHERE key = ?",
(_IDE_STATE_DB_KEY,),
).fetchone()
finally:
con.close()
if not row:
continue
blob = base64.b64decode(row[0])
candidates = re.findall(rb"[A-Za-z0-9+/=_\-]{40,}", blob)
access_token: str | None = None
refresh_token: str | None = None
for candidate in candidates:
try:
padded = candidate + b"=" * (-len(candidate) % 4)
inner = base64.urlsafe_b64decode(padded)
except Exception:
continue
if not access_token:
m = re.search(rb"ya29\.[A-Za-z0-9_\-\.]+", inner)
if m:
access_token = m.group(0).decode("ascii")
if not refresh_token:
m = re.search(rb"1//[A-Za-z0-9_\-\.]+", inner)
if m:
refresh_token = m.group(0).decode("ascii")
if access_token and refresh_token:
break
if access_token:
# Estimate expiry from DB mtime (IDE refreshes while running)
mtime = db_path.stat().st_mtime
expires_at = mtime + 3600.0
return access_token, refresh_token, expires_at
except Exception as exc:
logger.debug("Failed to read Antigravity IDE state DB: %s", exc)
continue
return None, None, 0.0
def _do_token_refresh(refresh_token: str) -> tuple[str, float] | None:
"""POST to Google OAuth endpoint and return ``(new_access_token, expires_at)``.
The client secret is sourced via ``get_antigravity_client_secret()`` (env var,
config file, or npm package fallback). When unavailable the refresh is attempted
without it Google will reject it for web-app clients, but the npm fallback in
``get_antigravity_client_secret()`` should ensure the secret is found at runtime.
Returns None when the HTTP request fails.
"""
from framework.config import get_antigravity_client_secret # noqa: PLC0415
client_secret = get_antigravity_client_secret()
if not client_secret:
logger.debug(
"Antigravity client secret not configured — attempting refresh without it. "
"Set ANTIGRAVITY_CLIENT_SECRET or run quickstart to configure."
)
import urllib.error # noqa: PLC0415
import urllib.parse # noqa: PLC0415
import urllib.request # noqa: PLC0415
from framework.config import get_antigravity_client_id # noqa: PLC0415
params: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": get_antigravity_client_id(),
}
if client_secret:
params["client_secret"] = client_secret
body = urllib.parse.urlencode(params).encode("utf-8")
req = urllib.request.Request(
_TOKEN_URL,
data=body,
headers={"Content-Type": "application/x-www-form-urlencoded"},
method="POST",
)
try:
with urllib.request.urlopen(req, timeout=15) as resp: # noqa: S310
payload = json.loads(resp.read())
access_token: str = payload["access_token"]
expires_in: int = payload.get("expires_in", 3600)
logger.debug("Antigravity token refreshed successfully")
return access_token, time.time() + expires_in
except Exception as exc:
logger.debug("Antigravity token refresh failed: %s", exc)
return None
# ---------------------------------------------------------------------------
# Message conversion helpers
# ---------------------------------------------------------------------------
def _clean_tool_name(name: str) -> str:
"""Sanitize a tool name for the Antigravity function-calling schema."""
name = re.sub(r"[/\s]", "_", name)
if name and not (name[0].isalpha() or name[0] == "_"):
name = "_" + name
return name[:64]
def _to_gemini_contents(
messages: list[dict[str, Any]],
thought_sigs: dict[str, str] | None = None,
) -> list[dict[str, Any]]:
"""Convert OpenAI-format messages to Gemini-style ``contents`` array."""
# Pre-build a map tool_call_id → function_name from assistant messages.
# Tool result messages (role="tool") only carry tool_call_id, not the name,
# but Gemini requires functionResponse.name to match the functionCall.name.
tc_id_to_name: dict[str, str] = {}
for msg in messages:
if msg.get("role") == "assistant":
for tc in msg.get("tool_calls") or []:
tc_id = tc.get("id")
fn_name = tc.get("function", {}).get("name", "")
if tc_id and fn_name:
tc_id_to_name[tc_id] = fn_name
contents: list[dict[str, Any]] = []
# Consecutive tool-result messages must be batched into one user turn.
pending_tool_parts: list[dict[str, Any]] = []
def _flush_tool_results() -> None:
if pending_tool_parts:
contents.append({"role": "user", "parts": list(pending_tool_parts)})
pending_tool_parts.clear()
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content")
if role == "system":
continue # Handled via systemInstruction, not in contents.
if role == "tool":
# OpenAI tool result → Gemini functionResponse part.
result_str = content if isinstance(content, str) else str(content or "")
tc_id = msg.get("tool_call_id", "")
# Look up function name from the pre-built map; fall back to msg.name.
fn_name = tc_id_to_name.get(tc_id) or msg.get("name", "")
pending_tool_parts.append(
{
"functionResponse": {
"name": fn_name,
"id": tc_id,
"response": {"content": result_str},
}
}
)
continue
_flush_tool_results()
gemini_role = "model" if role == "assistant" else "user"
parts: list[dict[str, Any]] = []
if isinstance(content, str) and content:
parts.append({"text": content})
elif isinstance(content, list):
for block in content:
if not isinstance(block, dict):
continue
if block.get("type") == "text":
text = block.get("text", "")
if text:
parts.append({"text": text})
# Other block types (image_url etc.) skipped.
# Assistant messages may carry OpenAI-style tool_calls.
for tc in msg.get("tool_calls") or []:
fn = tc.get("function", {})
try:
args = json.loads(fn.get("arguments", "{}") or "{}")
except (json.JSONDecodeError, TypeError):
args = {}
tc_id = tc.get("id", str(uuid.uuid4()))
fc_part: dict[str, Any] = {
"functionCall": {
"name": fn.get("name", ""),
"args": args,
"id": tc_id,
}
}
if thought_sigs:
sig = thought_sigs.get(tc_id, "")
if sig:
fc_part["thoughtSignature"] = sig # part-level, not inside functionCall
parts.append(fc_part)
if parts:
contents.append({"role": gemini_role, "parts": parts})
_flush_tool_results()
# Gemini requires the first turn to be a user turn. Drop any leading
# model messages so the API doesn't reject with a 400.
while contents and contents[0].get("role") == "model":
contents.pop(0)
return contents
# ---------------------------------------------------------------------------
# Response parsing helpers
# ---------------------------------------------------------------------------
def _map_finish_reason(reason: str) -> str:
return {"STOP": "stop", "MAX_TOKENS": "max_tokens", "OTHER": "tool_use"}.get(
(reason or "").upper(), "stop"
)
def _parse_complete_response(raw: dict[str, Any], model: str) -> LLMResponse:
"""Parse a non-streaming Antigravity response dict → LLMResponse."""
payload: dict[str, Any] = raw.get("response", raw)
candidates: list[dict[str, Any]] = payload.get("candidates", [])
usage: dict[str, Any] = payload.get("usageMetadata", {})
text_parts: list[str] = []
if candidates:
for part in candidates[0].get("content", {}).get("parts", []):
if "text" in part and not part.get("thought"):
text_parts.append(part["text"])
return LLMResponse(
content="".join(text_parts),
model=payload.get("modelVersion", model),
input_tokens=usage.get("promptTokenCount", 0),
output_tokens=usage.get("candidatesTokenCount", 0),
stop_reason=_map_finish_reason(candidates[0].get("finishReason", "") if candidates else ""),
raw_response=raw,
)
def _parse_sse_stream(
response: Any,
model: str,
on_thought_signature: Callable[[str, str], None] | None = None,
) -> Iterator[StreamEvent]:
"""Parse Antigravity SSE response line-by-line → StreamEvents.
Each SSE line looks like::
data: {"response": {"candidates": [...], "usageMetadata": {...}}, "traceId": "..."}
"""
accumulated = ""
input_tokens = 0
output_tokens = 0
finish_reason = ""
for raw_line in response:
line: str = raw_line.decode("utf-8", errors="replace").rstrip("\r\n")
if not line.startswith("data:"):
continue
data_str = line[5:].strip()
if not data_str or data_str == "[DONE]":
continue
try:
data: dict[str, Any] = json.loads(data_str)
except json.JSONDecodeError:
continue
# The outer envelope is {"response": {...}, "traceId": "..."}.
payload: dict[str, Any] = data.get("response", data)
usage = payload.get("usageMetadata", {})
if usage:
input_tokens = usage.get("promptTokenCount", input_tokens)
output_tokens = usage.get("candidatesTokenCount", output_tokens)
for candidate in payload.get("candidates", []):
fr = candidate.get("finishReason", "")
if fr:
finish_reason = fr
for part in candidate.get("content", {}).get("parts", []):
if "text" in part and not part.get("thought"):
delta: str = part["text"]
accumulated += delta
yield TextDeltaEvent(content=delta, snapshot=accumulated)
elif "functionCall" in part:
fc: dict[str, Any] = part["functionCall"]
tool_use_id = fc.get("id") or str(uuid.uuid4())
thought_sig = part.get("thoughtSignature", "") # sibling of functionCall
if thought_sig and on_thought_signature:
on_thought_signature(tool_use_id, thought_sig)
args = fc.get("args", {})
if isinstance(args, str):
try:
args = json.loads(args)
except json.JSONDecodeError:
args = {}
yield ToolCallEvent(
tool_use_id=tool_use_id,
tool_name=fc.get("name", ""),
tool_input=args,
)
if accumulated:
yield TextEndEvent(full_text=accumulated)
yield FinishEvent(
stop_reason=_map_finish_reason(finish_reason),
input_tokens=input_tokens,
output_tokens=output_tokens,
model=model,
)
# ---------------------------------------------------------------------------
# Provider
# ---------------------------------------------------------------------------
class AntigravityProvider(LLMProvider):
"""LLM provider for Google's internal Antigravity Code Assist gateway.
No local proxy required. Handles OAuth token refresh, Gemini-format
request/response conversion, and SSE streaming directly.
"""
def __init__(self, model: str = "gemini-3-flash") -> None:
# Strip any provider prefix ("openai/gemini-3-flash" → "gemini-3-flash").
if "/" in model:
model = model.split("/", 1)[1]
self.model = model
self._access_token: str | None = None
self._refresh_token: str | None = None
self._project_id: str = _DEFAULT_PROJECT_ID
self._token_expires_at: float = 0.0
self._thought_sigs: dict[str, str] = {} # tool_use_id → thoughtSignature
self._init_credentials()
# --- Credential management -------------------------------------------- #
def _init_credentials(self) -> None:
"""Load credentials from the best available source."""
access, refresh, project_id, expires_at = _load_from_json_file()
if refresh:
self._refresh_token = refresh
self._project_id = project_id
self._access_token = access
self._token_expires_at = expires_at
return
# Fall back to IDE state DB.
access, refresh, expires_at = _load_from_ide_db()
if access:
self._access_token = access
self._refresh_token = refresh
self._token_expires_at = expires_at
def has_credentials(self) -> bool:
"""Return True if any credential is available."""
return bool(self._access_token or self._refresh_token)
def _ensure_token(self) -> str:
"""Return a valid access token, refreshing via OAuth if needed."""
if (
self._access_token
and self._token_expires_at
and time.time() < self._token_expires_at - _TOKEN_REFRESH_BUFFER_SECS
):
return self._access_token
if self._refresh_token:
result = _do_token_refresh(self._refresh_token)
if result:
self._access_token, self._token_expires_at = result
return self._access_token
if self._access_token:
logger.warning("Using potentially stale Antigravity access token")
return self._access_token
raise RuntimeError(
"No valid Antigravity credentials. "
"Run: uv run python core/antigravity_auth.py auth account add"
)
# --- Request building -------------------------------------------------- #
def _build_body(
self,
messages: list[dict[str, Any]],
system: str,
tools: list[Tool] | None,
max_tokens: int,
) -> dict[str, Any]:
contents = _to_gemini_contents(messages, self._thought_sigs)
inner: dict[str, Any] = {
"contents": contents,
"generationConfig": {"maxOutputTokens": max_tokens},
}
if system:
inner["systemInstruction"] = {"parts": [{"text": system}]}
if tools:
inner["tools"] = [
{
"functionDeclarations": [
{
"name": _clean_tool_name(t.name),
"description": t.description,
"parameters": t.parameters
or {
"type": "object",
"properties": {},
},
}
for t in tools
]
}
]
return {
"project": self._project_id,
"model": self.model,
"request": inner,
"requestType": "agent",
"userAgent": "antigravity",
"requestId": f"agent-{uuid.uuid4()}",
}
# --- HTTP transport ---------------------------------------------------- #
def _post(self, body: dict[str, Any], *, streaming: bool) -> Any:
"""POST to the Antigravity endpoint, falling back through the endpoint list."""
import urllib.error # noqa: PLC0415
import urllib.request # noqa: PLC0415
token = self._ensure_token()
body_bytes = json.dumps(body).encode("utf-8")
path = (
"/v1internal:streamGenerateContent?alt=sse"
if streaming
else "/v1internal:generateContent"
)
headers = {
**_BASE_HEADERS,
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
if streaming:
headers["Accept"] = "text/event-stream"
last_exc: Exception | None = None
for base_url in _ENDPOINTS:
url = f"{base_url}{path}"
req = urllib.request.Request(url, data=body_bytes, headers=headers, method="POST")
try:
return urllib.request.urlopen(req, timeout=120) # noqa: S310
except urllib.error.HTTPError as exc:
if exc.code in (401, 403) and self._refresh_token:
# Token rejected — refresh once and retry this endpoint.
result = _do_token_refresh(self._refresh_token)
if result:
self._access_token, self._token_expires_at = result
headers["Authorization"] = f"Bearer {self._access_token}"
req2 = urllib.request.Request(
url, data=body_bytes, headers=headers, method="POST"
)
try:
return urllib.request.urlopen(req2, timeout=120) # noqa: S310
except urllib.error.HTTPError as exc2:
last_exc = exc2
continue
last_exc = exc
continue
elif exc.code >= 500:
last_exc = exc
continue
# Include the API response body in the exception for easier debugging.
try:
err_body = exc.read().decode("utf-8", errors="replace")
except Exception:
err_body = "(unreadable)"
raise RuntimeError(f"Antigravity HTTP {exc.code} from {url}: {err_body}") from exc
except (urllib.error.URLError, OSError) as exc:
last_exc = exc
continue
raise RuntimeError(
f"All Antigravity endpoints failed. Last error: {last_exc}"
) from last_exc
# --- LLMProvider interface --------------------------------------------- #
def complete(
self,
messages: list[dict[str, Any]],
system: str = "",
tools: list[Tool] | None = None,
max_tokens: int = 1024,
response_format: dict[str, Any] | None = None,
json_mode: bool = False,
max_retries: int | None = None,
) -> LLMResponse:
if json_mode:
suffix = "\n\nPlease respond with a valid JSON object."
system = (system + suffix) if system else suffix.strip()
body = self._build_body(messages, system, tools, max_tokens)
resp = self._post(body, streaming=False)
return _parse_complete_response(json.loads(resp.read()), self.model)
async def stream(
self,
messages: list[dict[str, Any]],
system: str = "",
tools: list[Tool] | None = None,
max_tokens: int = 4096,
) -> AsyncIterator[StreamEvent]:
import asyncio # noqa: PLC0415
import concurrent.futures # noqa: PLC0415
loop = asyncio.get_running_loop()
queue: asyncio.Queue[StreamEvent | None] = asyncio.Queue()
def _blocking_work() -> None:
try:
body = self._build_body(messages, system, tools, max_tokens)
http_resp = self._post(body, streaming=True)
for event in _parse_sse_stream(
http_resp, self.model, self._thought_sigs.__setitem__
):
loop.call_soon_threadsafe(queue.put_nowait, event)
except Exception as exc:
logger.error("Antigravity stream error: %s", exc)
loop.call_soon_threadsafe(queue.put_nowait, StreamErrorEvent(error=str(exc)))
finally:
loop.call_soon_threadsafe(queue.put_nowait, None) # sentinel
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
fut = loop.run_in_executor(executor, _blocking_work)
try:
while True:
event = await queue.get()
if event is None:
break
yield event
finally:
await fut
executor.shutdown(wait=False)
+106
View File
@@ -0,0 +1,106 @@
"""Model capability checks for LLM providers.
Vision support rules are derived from official vendor documentation:
- ZAI (z.ai): docs.z.ai/guides/vlm GLM-4.6V variants are vision; GLM-5/4.6/4.7 are text-only
- MiniMax: platform.minimax.io/docs minimax-vl-01 is vision; M2.x are text-only
- DeepSeek: api-docs.deepseek.com deepseek-vl2 is vision; chat/reasoner are text-only
- Cerebras: inference-docs.cerebras.ai no vision models at all
- Groq: console.groq.com/docs/vision vision capable; treat as supported by default
- Ollama/LM Studio/vLLM/llama.cpp: local runners denied by default; model names
don't reliably indicate vision support, so users must configure explicitly
"""
from __future__ import annotations
def _model_name(model: str) -> str:
"""Return the bare model name after stripping any 'provider/' prefix."""
if "/" in model:
return model.split("/", 1)[1]
return model
# Step 1: explicit vision allow-list — these always support images regardless
# of what the provider-level rules say. Checked first so that e.g. glm-4.6v
# is allowed even though glm-4.6 is denied.
_VISION_ALLOW_BARE_PREFIXES: tuple[str, ...] = (
# ZAI/GLM vision models (docs.z.ai/guides/vlm)
"glm-4v", # GLM-4V series (legacy)
"glm-4.6v", # GLM-4.6V, GLM-4.6V-flash, GLM-4.6V-flashx
# DeepSeek vision models
"deepseek-vl", # deepseek-vl2, deepseek-vl2-small, deepseek-vl2-tiny
# MiniMax vision model
"minimax-vl", # minimax-vl-01
)
# Step 2: provider-level deny — every model from this provider is text-only.
_TEXT_ONLY_PROVIDER_PREFIXES: tuple[str, ...] = (
# Cerebras: inference-docs.cerebras.ai lists only text models
"cerebras/",
# Local runners: model names don't reliably indicate vision support
"ollama/",
"ollama_chat/",
"lm_studio/",
"vllm/",
"llamacpp/",
)
# Step 3: per-model deny — text-only models within otherwise mixed providers.
# Matched against the bare model name (provider prefix stripped, lower-cased).
# The vision allow-list above is checked first, so vision variants of the same
# family are already handled before these deny patterns are reached.
_TEXT_ONLY_MODEL_BARE_PREFIXES: tuple[str, ...] = (
# --- ZAI / GLM family ---
# text-only: glm-5, glm-4.6, glm-4.7, glm-4.5, zai-glm-*
# vision: glm-4v, glm-4.6v (caught by allow-list above)
"glm-5",
"glm-4.6", # bare glm-4.6 is text-only; glm-4.6v is caught by allow-list
"glm-4.7",
"glm-4.5",
"zai-glm",
# --- DeepSeek ---
# text-only: deepseek-chat, deepseek-coder, deepseek-reasoner
# vision: deepseek-vl2 (caught by allow-list above)
# Note: LiteLLM's deepseek handler may flatten content lists for some models;
# VL models are allowed through and rely on LiteLLM's native VL support.
"deepseek-chat",
"deepseek-coder",
"deepseek-reasoner",
# --- MiniMax ---
# text-only: minimax-m2.*, minimax-text-*, abab* (legacy)
# vision: minimax-vl-01 (caught by allow-list above)
"minimax-m2",
"minimax-text",
"abab",
)
def supports_image_tool_results(model: str) -> bool:
"""Return whether *model* can receive image content in messages.
Used to gate both user-message images and tool-result image blocks.
Logic (checked in order):
1. Vision allow-list True (known vision model, skip all denies)
2. Provider deny False (entire provider is text-only)
3. Model deny False (specific text-only model within a mixed provider)
4. Default True (assume capable; unknown providers and models)
"""
model_lower = model.lower()
bare = _model_name(model_lower)
# 1. Explicit vision allow — takes priority over all denies
if any(bare.startswith(p) for p in _VISION_ALLOW_BARE_PREFIXES):
return True
# 2. Provider-level deny (all models from this provider are text-only)
if any(model_lower.startswith(p) for p in _TEXT_ONLY_PROVIDER_PREFIXES):
return False
# 3. Per-model deny (text-only variants within mixed-capability families)
if any(bare.startswith(p) for p in _TEXT_ONLY_MODEL_BARE_PREFIXES):
return False
# 5. Default: assume vision capable
# Covers: OpenAI, Anthropic, Google, Mistral, Kimi, and other hosted providers
return True
+98 -3
View File
@@ -9,8 +9,10 @@ See: https://docs.litellm.ai/docs/providers
import ast
import asyncio
import hashlib
import json
import logging
import os
import re
import time
from collections.abc import AsyncIterator
@@ -46,7 +48,10 @@ def _patch_litellm_anthropic_oauth() -> None:
"""
try:
from litellm.llms.anthropic.common_utils import AnthropicModelInfo
from litellm.types.llms.anthropic import ANTHROPIC_OAUTH_TOKEN_PREFIX
from litellm.types.llms.anthropic import (
ANTHROPIC_OAUTH_BETA_HEADER,
ANTHROPIC_OAUTH_TOKEN_PREFIX,
)
except ImportError:
logger.warning(
"Could not apply litellm Anthropic OAuth patch — litellm internals may have "
@@ -71,9 +76,27 @@ def _patch_litellm_anthropic_oauth() -> None:
api_key=api_key,
api_base=api_base,
)
# Check both authorization header and x-api-key for OAuth tokens.
# litellm's optionally_handle_anthropic_oauth only checks headers["authorization"],
# but hive passes OAuth tokens via api_key — so litellm puts them into x-api-key.
# Anthropic rejects OAuth tokens in x-api-key; they must go in Authorization: Bearer.
auth = result.get("authorization", "")
if auth.startswith(f"Bearer {ANTHROPIC_OAUTH_TOKEN_PREFIX}"):
x_api_key = result.get("x-api-key", "")
oauth_prefix = f"Bearer {ANTHROPIC_OAUTH_TOKEN_PREFIX}"
auth_is_oauth = auth.startswith(oauth_prefix)
key_is_oauth = x_api_key.startswith(ANTHROPIC_OAUTH_TOKEN_PREFIX)
if auth_is_oauth or key_is_oauth:
token = x_api_key if key_is_oauth else auth.removeprefix("Bearer ").strip()
result.pop("x-api-key", None)
result["authorization"] = f"Bearer {token}"
# Merge the OAuth beta header with any existing beta headers.
existing_beta = result.get("anthropic-beta", "")
beta_parts = (
[b.strip() for b in existing_beta.split(",") if b.strip()] if existing_beta else []
)
if ANTHROPIC_OAUTH_BETA_HEADER not in beta_parts:
beta_parts.append(ANTHROPIC_OAUTH_BETA_HEADER)
result["anthropic-beta"] = ",".join(beta_parts)
return result
AnthropicModelInfo.validate_environment = _patched_validate_environment
@@ -132,6 +155,9 @@ def _patch_litellm_metadata_nonetype() -> None:
if litellm is not None:
_patch_litellm_anthropic_oauth()
_patch_litellm_metadata_nonetype()
# Let litellm silently drop params unsupported by the target provider
# (e.g. stream_options for Anthropic) instead of forwarding them verbatim.
litellm.drop_params = True
RATE_LIMIT_MAX_RETRIES = 10
RATE_LIMIT_BACKOFF_BASE = 2 # seconds
@@ -162,6 +188,53 @@ def _model_supports_cache_control(model: str) -> bool:
# enforces a coding-agent whitelist that blocks unknown User-Agents.
KIMI_API_BASE = "https://api.kimi.com/coding"
# Claude Code OAuth subscription: the Anthropic API requires a specific
# User-Agent and a billing integrity header for OAuth-authenticated requests.
CLAUDE_CODE_VERSION = "2.1.76"
CLAUDE_CODE_USER_AGENT = f"claude-code/{CLAUDE_CODE_VERSION}"
_CLAUDE_CODE_BILLING_SALT = "59cf53e54c78"
def _sample_js_code_unit(text: str, idx: int) -> str:
"""Return the character at UTF-16 code unit index *idx*, matching JS semantics."""
encoded = text.encode("utf-16-le")
unit_offset = idx * 2
if unit_offset + 2 > len(encoded):
return "0"
code_unit = int.from_bytes(encoded[unit_offset : unit_offset + 2], "little")
return chr(code_unit)
def _claude_code_billing_header(messages: list[dict[str, Any]]) -> str:
"""Build the billing integrity system block required by Anthropic's OAuth path."""
# Find the first user message text
first_text = ""
for msg in messages:
if msg.get("role") != "user":
continue
content = msg.get("content")
if isinstance(content, str):
first_text = content
break
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text" and block.get("text"):
first_text = block["text"]
break
if first_text:
break
sampled = "".join(_sample_js_code_unit(first_text, i) for i in (4, 7, 20))
version_hash = hashlib.sha256(
f"{_CLAUDE_CODE_BILLING_SALT}{sampled}{CLAUDE_CODE_VERSION}".encode()
).hexdigest()
entrypoint = os.environ.get("CLAUDE_CODE_ENTRYPOINT", "").strip() or "cli"
return (
f"x-anthropic-billing-header: cc_version={CLAUDE_CODE_VERSION}.{version_hash[:3]}; "
f"cc_entrypoint={entrypoint}; cch=00000;"
)
# Empty-stream retries use a short fixed delay, not the rate-limit backoff.
# Conversation-structure issues are deterministic — long waits don't help.
EMPTY_STREAM_MAX_RETRIES = 3
@@ -441,11 +514,19 @@ class LiteLLMProvider(LLMProvider):
self.api_key = api_key
self.api_base = api_base or self._default_api_base_for_model(_original_model)
self.extra_kwargs = kwargs
# Detect Claude Code OAuth subscription by checking the api_key prefix.
self._claude_code_oauth = bool(api_key and api_key.startswith("sk-ant-oat"))
if self._claude_code_oauth:
# Anthropic requires a specific User-Agent for OAuth requests.
eh = self.extra_kwargs.setdefault("extra_headers", {})
eh.setdefault("user-agent", CLAUDE_CODE_USER_AGENT)
# The Codex ChatGPT backend (chatgpt.com/backend-api/codex) rejects
# several standard OpenAI params: max_output_tokens, stream_options.
self._codex_backend = bool(
self.api_base and "chatgpt.com/backend-api/codex" in self.api_base
)
# Antigravity routes through a local OpenAI-compatible proxy — no patches needed.
self._antigravity = bool(self.api_base and "localhost:8069" in self.api_base)
if litellm is None:
raise ImportError(
@@ -808,6 +889,9 @@ class LiteLLMProvider(LLMProvider):
return await self._collect_stream_to_response(stream_iter)
full_messages: list[dict[str, Any]] = []
if self._claude_code_oauth:
billing = _claude_code_billing_header(messages)
full_messages.append({"role": "system", "content": billing})
if system:
sys_msg: dict[str, Any] = {"role": "system", "content": system}
if _model_supports_cache_control(self.model):
@@ -869,6 +953,11 @@ class LiteLLMProvider(LLMProvider):
},
}
def _is_anthropic_model(self) -> bool:
"""Return True when the configured model targets Anthropic."""
model = (self.model or "").lower()
return model.startswith("anthropic/") or model.startswith("claude-")
def _is_minimax_model(self) -> bool:
"""Return True when the configured model targets MiniMax."""
model = (self.model or "").lower()
@@ -1479,6 +1568,9 @@ class LiteLLMProvider(LLMProvider):
return
full_messages: list[dict[str, Any]] = []
if self._claude_code_oauth:
billing = _claude_code_billing_header(messages)
full_messages.append({"role": "system", "content": billing})
if system:
sys_msg: dict[str, Any] = {"role": "system", "content": system}
if _model_supports_cache_control(self.model):
@@ -1516,9 +1608,12 @@ class LiteLLMProvider(LLMProvider):
"messages": full_messages,
"max_tokens": max_tokens,
"stream": True,
"stream_options": {"include_usage": True},
**self.extra_kwargs,
}
# stream_options is OpenAI-specific; Anthropic rejects it with 400.
# Only include it for providers that support it.
if not self._is_anthropic_model():
kwargs["stream_options"] = {"include_usage": True}
if self.api_key:
kwargs["api_key"] = self.api_key
if self.api_base:
+2
View File
@@ -45,6 +45,8 @@ class ToolResult:
tool_use_id: str
content: str
is_error: bool = False
image_content: list[dict[str, Any]] | None = None
is_skill_content: bool = False # AS-10: marks activated skill body, protected from pruning
class LLMProvider(ABC):
+16 -1
View File
@@ -30,6 +30,8 @@ from typing import Any
# ContextVar is thread-safe and async-safe - perfect for concurrent agent execution
trace_context: ContextVar[dict[str, Any] | None] = ContextVar("trace_context", default=None)
_STANDARD_LOG_RECORD_FIELDS = set(logging.makeLogRecord({}).__dict__)
# ANSI escape code pattern (matches \033[...m or \x1b[...m)
ANSI_ESCAPE_PATTERN = re.compile(r"\x1b\[[0-9;]*m|\033\[[0-9;]*m")
@@ -92,6 +94,14 @@ class StructuredFormatter(logging.Formatter):
if model is not None:
log_entry["model"] = model
# Preserve arbitrary structured fields passed via ``extra=...``.
for key, value in record.__dict__.items():
if key in _STANDARD_LOG_RECORD_FIELDS or key.startswith("_"):
continue
if key in log_entry:
continue
log_entry[key] = value
# Add exception info if present (strip ANSI codes from exception text too)
if record.exc_info:
exception_text = self.formatException(record.exc_info)
@@ -208,7 +218,12 @@ def configure_logging(
# Suppress noisy LiteLLM INFO logs (model/provider line + Provider List URL
# printed on every single completion call). Warnings and errors still show.
logging.getLogger("LiteLLM").setLevel(logging.WARNING)
# Honour LITELLM_LOG env var so users can opt-in to debug output.
_litellm_level = os.getenv("LITELLM_LOG", "").upper()
if _litellm_level and hasattr(logging, _litellm_level):
logging.getLogger("LiteLLM").setLevel(getattr(logging, _litellm_level))
else:
logging.getLogger("LiteLLM").setLevel(logging.WARNING)
# When in JSON mode, configure known third-party loggers to use JSON formatter
# This ensures libraries like LiteLLM, httpcore also output clean JSON
+2
View File
@@ -1,5 +1,6 @@
"""Agent Runner - load and run exported agents."""
from framework.runner.mcp_registry import MCPRegistry
from framework.runner.orchestrator import AgentOrchestrator
from framework.runner.protocol import (
AgentMessage,
@@ -17,6 +18,7 @@ __all__ = [
"AgentInfo",
"ValidationResult",
"ToolRegistry",
"MCPRegistry",
"tool",
# Multi-agent
"AgentOrchestrator",
+33 -4
View File
@@ -1561,6 +1561,22 @@ def _open_browser(url: str) -> None:
pass # Best-effort — don't crash if browser can't open
def _format_subprocess_output(output: str | bytes | None, limit: int = 2000) -> str:
"""Return subprocess output as trimmed text safe for console logging."""
if not output:
return ""
if isinstance(output, bytes):
text = output.decode(errors="replace")
else:
text = output
text = text.strip()
if len(text) <= limit:
return text
return text[-limit:]
def _build_frontend() -> bool:
"""Build the frontend if source is newer than dist. Returns True if dist exists."""
import subprocess
@@ -1596,18 +1612,25 @@ def _build_frontend() -> bool:
# Need to build
print("Building frontend...")
npm_cmd = "npm.cmd" if sys.platform == "win32" else "npm"
try:
# Incremental tsc caches can drift across branch changes and block builds.
for cache_file in frontend_dir.glob("tsconfig*.tsbuildinfo"):
cache_file.unlink(missing_ok=True)
# Ensure deps are installed
subprocess.run(
["npm", "install", "--no-fund", "--no-audit"],
[npm_cmd, "install", "--no-fund", "--no-audit"],
encoding="utf-8",
errors="replace",
cwd=frontend_dir,
check=True,
capture_output=True,
)
subprocess.run(
["npm", "run", "build"],
[npm_cmd, "run", "build"],
encoding="utf-8",
errors="replace",
cwd=frontend_dir,
check=True,
capture_output=True,
@@ -1618,8 +1641,14 @@ def _build_frontend() -> bool:
print("Node.js not found — skipping frontend build.")
return dist_dir.is_dir()
except subprocess.CalledProcessError as exc:
stderr = exc.stderr.decode(errors="replace") if exc.stderr else ""
print(f"Frontend build failed: {stderr[:500]}")
stdout = _format_subprocess_output(exc.stdout)
stderr = _format_subprocess_output(exc.stderr)
cmd = " ".join(exc.cmd) if isinstance(exc.cmd, (list, tuple)) else str(exc.cmd)
details = "\n".join(part for part in [stdout, stderr] if part).strip()
if details:
print(f"Frontend build failed while running {cmd}:\n{details}")
else:
print(f"Frontend build failed while running {cmd} (exit {exc.returncode}).")
return dist_dir.is_dir()
+168 -20
View File
@@ -1,7 +1,7 @@
"""MCP Client for connecting to Model Context Protocol servers.
This module provides a client for connecting to MCP servers and invoking their tools.
Supports both STDIO and HTTP transports using the official MCP Python SDK.
Supports STDIO, HTTP, UNIX socket, and SSE transports using the official MCP Python SDK.
"""
import asyncio
@@ -22,7 +22,7 @@ class MCPServerConfig:
"""Configuration for an MCP server connection."""
name: str
transport: Literal["stdio", "http"]
transport: Literal["stdio", "http", "unix", "sse"]
# For STDIO transport
command: str | None = None
@@ -33,6 +33,7 @@ class MCPServerConfig:
# For HTTP transport
url: str | None = None
headers: dict[str, str] = field(default_factory=dict)
socket_path: str | None = None
# Optional metadata
description: str = ""
@@ -52,7 +53,7 @@ class MCPClient:
"""
Client for communicating with MCP servers.
Supports both STDIO and HTTP transports using the official MCP SDK.
Supports STDIO, HTTP, UNIX socket, and SSE transports using the official MCP SDK.
Manages the connection lifecycle and provides methods to list and invoke tools.
"""
@@ -68,6 +69,7 @@ class MCPClient:
self._read_stream = None
self._write_stream = None
self._stdio_context = None # Context manager for stdio_client
self._sse_context = None # Context manager for sse_client
self._errlog_handle = None # Track errlog file handle for cleanup
self._http_client: httpx.Client | None = None
self._tools: dict[str, MCPTool] = {}
@@ -141,6 +143,10 @@ class MCPClient:
self._connect_stdio()
elif self.config.transport == "http":
self._connect_http()
elif self.config.transport == "unix":
self._connect_unix()
elif self.config.transport == "sse":
self._connect_sse()
else:
raise ValueError(f"Unsupported transport: {self.config.transport}")
@@ -266,10 +272,94 @@ class MCPClient:
logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}")
# Continue anyway, server might not have health endpoint
def _connect_unix(self) -> None:
"""Connect to MCP server via UNIX domain socket transport."""
if not self.config.url:
raise ValueError("url is required for UNIX transport")
if not self.config.socket_path:
raise ValueError("socket_path is required for UNIX transport")
self._http_client = httpx.Client(
base_url=self.config.url,
headers=self.config.headers,
timeout=30.0,
transport=httpx.HTTPTransport(uds=self.config.socket_path),
)
try:
response = self._http_client.get("/health")
response.raise_for_status()
logger.info(
"Connected to MCP server '%s' via UNIX socket at %s",
self.config.name,
self.config.socket_path,
)
except Exception as e:
logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}")
# Continue anyway, server might not have health endpoint
def _connect_sse(self) -> None:
"""Connect to MCP server via SSE transport using MCP SDK with persistent session."""
if not self.config.url:
raise ValueError("url is required for SSE transport")
try:
loop_started = threading.Event()
connection_ready = threading.Event()
connection_error = []
def run_event_loop():
"""Run event loop in background thread."""
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
loop_started.set()
async def init_connection():
try:
from mcp import ClientSession
from mcp.client.sse import sse_client
self._sse_context = sse_client(
self.config.url,
headers=self.config.headers,
timeout=30.0,
)
(
self._read_stream,
self._write_stream,
) = await self._sse_context.__aenter__()
self._session = ClientSession(self._read_stream, self._write_stream)
await self._session.__aenter__()
await self._session.initialize()
connection_ready.set()
except Exception as e:
connection_error.append(e)
connection_ready.set()
self._loop.create_task(init_connection())
self._loop.run_forever()
self._loop_thread = threading.Thread(target=run_event_loop, daemon=True)
self._loop_thread.start()
loop_started.wait(timeout=5)
if not loop_started.is_set():
raise RuntimeError("Event loop failed to start")
connection_ready.wait(timeout=10)
if connection_error:
raise connection_error[0]
logger.info(f"Connected to MCP server '{self.config.name}' via SSE")
except Exception as e:
raise RuntimeError(f"Failed to connect to MCP server: {e}") from e
def _discover_tools(self) -> None:
"""Discover available tools from the MCP server."""
try:
if self.config.transport == "stdio":
if self.config.transport in {"stdio", "sse"}:
tools_list = self._run_async(self._list_tools_stdio_async())
else:
tools_list = self._list_tools_http()
@@ -371,9 +461,37 @@ class MCPClient:
if self.config.transport == "stdio":
with self._stdio_call_lock:
return self._run_async(self._call_tool_stdio_async(tool_name, arguments))
elif self.config.transport == "sse":
return self._call_tool_with_retry(
lambda: self._run_async(self._call_tool_stdio_async(tool_name, arguments))
)
elif self.config.transport == "unix":
return self._call_tool_with_retry(lambda: self._call_tool_http(tool_name, arguments))
else:
return self._call_tool_http(tool_name, arguments)
def _call_tool_with_retry(self, call: Any) -> Any:
"""Retry transient MCP transport failures once after reconnecting."""
if self.config.transport == "stdio":
return call()
if self.config.transport not in {"unix", "sse"}:
return call()
try:
return call()
except (httpx.ConnectError, httpx.ReadTimeout) as original_error:
logger.warning(
"Retrying MCP tool call after transport error from '%s': %s",
self.config.name,
original_error,
)
self._reconnect()
try:
return call()
except (httpx.ConnectError, httpx.ReadTimeout) as retry_error:
raise original_error from retry_error
async def _call_tool_stdio_async(self, tool_name: str, arguments: dict[str, Any]) -> Any:
"""Call tool via STDIO protocol using persistent session."""
if not self._session:
@@ -391,17 +509,30 @@ class MCPClient:
error_text = content_item.text
raise RuntimeError(f"MCP tool '{tool_name}' failed: {error_text}")
# Extract content
# Extract content — preserve image blocks alongside text
if result.content:
# MCP returns content as a list of content items
if len(result.content) > 0:
content_item = result.content[0]
# Check if it's a text content item
if hasattr(content_item, "text"):
return content_item.text
elif hasattr(content_item, "data"):
return content_item.data
return result.content
text_parts: list[str] = []
image_parts: list[dict[str, Any]] = []
for item in result.content:
if hasattr(item, "text"):
text_parts.append(item.text)
elif hasattr(item, "data") and hasattr(item, "mimeType"):
# MCP ImageContent — preserve as structured image block
image_parts.append(
{
"type": "image_url",
"image_url": {
"url": f"data:{item.mimeType};base64,{item.data}",
},
}
)
elif hasattr(item, "data"):
text_parts.append(str(item.data))
text = "\n".join(text_parts) if text_parts else ""
if image_parts:
return {"_text": text, "_images": image_parts}
return text if text else None
return None
@@ -433,18 +564,24 @@ class MCPClient:
except Exception as e:
raise RuntimeError(f"Failed to call tool via HTTP: {e}") from e
def _reconnect(self) -> None:
"""Reconnect to the configured MCP server."""
logger.info(f"Reconnecting to MCP server '{self.config.name}'...")
self.disconnect()
self.connect()
_CLEANUP_TIMEOUT = 10
_THREAD_JOIN_TIMEOUT = 12
async def _cleanup_stdio_async(self) -> None:
"""Async cleanup for STDIO session and context managers.
"""Async cleanup for persistent MCP session and context managers.
Cleanup order is critical:
- The session must be closed BEFORE the stdio_context because the session
depends on the streams provided by stdio_context.
- This mirrors the initialization order in _connect_stdio(), where
stdio_context is entered first (providing streams), then the session is
created with those streams and entered.
- The session must be closed BEFORE the transport context manager because the
session depends on the streams provided by that context.
- This mirrors the initialization order in _connect_stdio() / _connect_sse(),
where the transport context is entered first (providing streams), then the
session is created with those streams and entered.
- Do not change this ordering without carefully considering these dependencies.
"""
# First: close session (depends on stdio_context streams)
@@ -477,6 +614,16 @@ class MCPClient:
finally:
self._stdio_context = None
try:
if self._sse_context:
await self._sse_context.__aexit__(None, None, None)
except asyncio.CancelledError:
logger.debug("SSE context cleanup was cancelled; proceeding with best-effort shutdown")
except Exception as e:
logger.warning(f"Error closing SSE context: {e}")
finally:
self._sse_context = None
# Third: close errlog file handle if we opened one
if self._errlog_handle is not None:
try:
@@ -552,6 +699,7 @@ class MCPClient:
# Setting None to None is safe and ensures clean state.
self._session = None
self._stdio_context = None
self._sse_context = None
self._read_stream = None
self._write_stream = None
self._loop = None
@@ -0,0 +1,409 @@
"""Shared MCP client connection management."""
import logging
import threading
import httpx
from framework.runner.mcp_client import MCPClient, MCPServerConfig
logger = logging.getLogger(__name__)
_TRANSITION_TIMEOUT = 30.0
class MCPConnectionManager:
"""Process-wide MCP client pool keyed by server name."""
_instance = None
_lock = threading.Lock()
def __init__(self) -> None:
self._pool: dict[str, MCPClient] = {}
self._refcounts: dict[str, int] = {}
self._configs: dict[str, MCPServerConfig] = {}
self._pool_lock = threading.Lock()
self._transitions: dict[str, threading.Event] = {}
@classmethod
def get_instance(cls) -> "MCPConnectionManager":
"""Return the process-level singleton instance."""
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
@staticmethod
def _is_connected(client: MCPClient | None) -> bool:
return bool(client and getattr(client, "_connected", False))
def has_connection(self, server_name: str) -> bool:
"""Return True when a live pooled connection exists for ``server_name``."""
with self._pool_lock:
return self._is_connected(self._pool.get(server_name))
def acquire(self, config: MCPServerConfig) -> MCPClient:
"""Get or create a shared connection and increment its refcount."""
server_name = config.name
while True:
should_connect = False
transition_event: threading.Event | None = None
with self._pool_lock:
client = self._pool.get(server_name)
if self._is_connected(client) and server_name not in self._transitions:
new_refcount = self._refcounts.get(server_name, 0) + 1
self._refcounts[server_name] = new_refcount
self._configs[server_name] = config
logger.debug(
"Reusing pooled connection for MCP server '%s' (refcount=%d)",
server_name,
new_refcount,
)
return client
transition_event = self._transitions.get(server_name)
if transition_event is None:
transition_event = threading.Event()
self._transitions[server_name] = transition_event
self._configs[server_name] = config
should_connect = True
if not should_connect:
if not transition_event.wait(timeout=_TRANSITION_TIMEOUT):
logger.warning(
"Timed out waiting for transition on MCP server '%s', "
"forcing cleanup and retrying",
server_name,
)
with self._pool_lock:
stuck = self._transitions.get(server_name)
if stuck is transition_event:
self._transitions.pop(server_name, None)
transition_event.set()
continue
logger.info("Connecting to MCP server '%s'", server_name)
client = MCPClient(config)
try:
client.connect()
except Exception:
logger.warning(
"Failed to connect to MCP server '%s'",
server_name,
exc_info=True,
)
with self._pool_lock:
current = self._transitions.get(server_name)
if current is transition_event:
self._transitions.pop(server_name, None)
if (
server_name not in self._pool
and self._refcounts.get(server_name, 0) <= 0
):
self._configs.pop(server_name, None)
transition_event.set()
raise
with self._pool_lock:
current = self._transitions.get(server_name)
if current is transition_event:
self._pool[server_name] = client
self._refcounts[server_name] = self._refcounts.get(server_name, 0) + 1
self._configs[server_name] = config
self._transitions.pop(server_name, None)
transition_event.set()
logger.info(
"Connected to MCP server '%s' (refcount=1)",
server_name,
)
return client
# Lost the transition race, clean up and retry
try:
client.disconnect()
except Exception:
logger.debug(
"Error disconnecting stale client for '%s'",
server_name,
exc_info=True,
)
def release(self, server_name: str) -> None:
"""Decrement refcount and disconnect when the last user releases."""
while True:
disconnect_client: MCPClient | None = None
transition_event: threading.Event | None = None
should_disconnect = False
with self._pool_lock:
transition_event = self._transitions.get(server_name)
if transition_event is None:
refcount = self._refcounts.get(server_name, 0)
if refcount <= 0:
return
if refcount > 1:
self._refcounts[server_name] = refcount - 1
logger.debug(
"Released MCP server '%s' (refcount=%d)",
server_name,
refcount - 1,
)
return
disconnect_client = self._pool.pop(server_name, None)
self._refcounts.pop(server_name, None)
self._configs.pop(server_name, None)
transition_event = threading.Event()
self._transitions[server_name] = transition_event
should_disconnect = True
if not should_disconnect:
if not transition_event.wait(timeout=_TRANSITION_TIMEOUT):
logger.warning(
"Timed out waiting for transition on '%s' during release, forcing cleanup",
server_name,
)
with self._pool_lock:
stuck = self._transitions.get(server_name)
if stuck is transition_event:
self._transitions.pop(server_name, None)
transition_event.set()
continue
try:
if disconnect_client is not None:
disconnect_client.disconnect()
logger.info(
"Disconnected MCP server '%s' (last reference released)",
server_name,
)
except Exception:
logger.warning(
"Error disconnecting MCP server '%s' during release",
server_name,
exc_info=True,
)
finally:
with self._pool_lock:
current = self._transitions.get(server_name)
if current is transition_event:
self._transitions.pop(server_name, None)
transition_event.set()
return
def health_check(self, server_name: str) -> bool:
"""Return True when the pooled connection appears healthy."""
while True:
with self._pool_lock:
transition_event = self._transitions.get(server_name)
if transition_event is None:
client = self._pool.get(server_name)
config = self._configs.get(server_name)
break
if not transition_event.wait(timeout=_TRANSITION_TIMEOUT):
logger.warning(
"Timed out waiting for transition on '%s' during health check",
server_name,
)
return False
if client is None or config is None:
return False
try:
match config.transport:
case "stdio":
client.list_tools()
return True
case "http":
if not config.url:
return False
with httpx.Client(
base_url=config.url,
headers=config.headers,
timeout=5.0,
) as http_client:
response = http_client.get("/health")
response.raise_for_status()
return True
case "sse":
client.list_tools()
return True
case "unix":
if not config.socket_path:
return False
with httpx.Client(
base_url=config.url or "http://localhost",
headers=config.headers,
timeout=5.0,
transport=httpx.HTTPTransport(uds=config.socket_path),
) as http_client:
response = http_client.get("/health")
response.raise_for_status()
return True
case _:
logger.warning(
"Unknown transport '%s' for health check on '%s'",
config.transport,
server_name,
)
return False
except Exception:
logger.debug(
"Health check failed for MCP server '%s'",
server_name,
exc_info=True,
)
return False
def reconnect(self, server_name: str) -> MCPClient:
"""Force a disconnect and replace the pooled client with a fresh one."""
while True:
transition_event: threading.Event | None = None
old_client: MCPClient | None = None
with self._pool_lock:
transition_event = self._transitions.get(server_name)
if transition_event is None:
config = self._configs.get(server_name)
if config is None:
raise KeyError(f"Unknown MCP server: {server_name}")
old_client = self._pool.get(server_name)
transition_event = threading.Event()
self._transitions[server_name] = transition_event
break
if not transition_event.wait(timeout=_TRANSITION_TIMEOUT):
logger.warning(
"Timed out waiting for transition on '%s' during reconnect, forcing cleanup",
server_name,
)
with self._pool_lock:
stuck = self._transitions.get(server_name)
if stuck is transition_event:
self._transitions.pop(server_name, None)
transition_event.set()
# Disconnect old client safely
if old_client is not None:
try:
old_client.disconnect()
logger.info("Disconnected old client for '%s'", server_name)
except Exception:
logger.warning(
"Error disconnecting old client for '%s' during reconnect",
server_name,
exc_info=True,
)
logger.info("Reconnecting MCP server '%s'", server_name)
new_client = MCPClient(config)
try:
new_client.connect()
except Exception:
with self._pool_lock:
current = self._transitions.get(server_name)
if current is transition_event:
self._pool.pop(server_name, None)
self._transitions.pop(server_name, None)
transition_event.set()
raise
with self._pool_lock:
current = self._transitions.get(server_name)
if current is transition_event:
current_refcount = self._refcounts.get(server_name, 0)
if current_refcount <= 0:
# All holders released during reconnect. Discard the
# new client instead of creating a phantom reference.
# Caller should acquire() fresh if needed.
self._transitions.pop(server_name, None)
transition_event.set()
logger.info(
"Reconnected MCP server '%s' but refcount dropped to 0, "
"discarding new client",
server_name,
)
try:
new_client.disconnect()
except Exception:
logger.debug(
"Error disconnecting discarded client for '%s'",
server_name,
exc_info=True,
)
raise KeyError(
f"MCP server '{server_name}' was fully released during reconnect"
)
self._pool[server_name] = new_client
self._configs[server_name] = config
self._refcounts[server_name] = current_refcount
self._transitions.pop(server_name, None)
transition_event.set()
logger.info(
"Reconnected MCP server '%s' (refcount=%d)",
server_name,
current_refcount,
)
return new_client
try:
new_client.disconnect()
except Exception:
logger.debug(
"Error disconnecting stale client for '%s' after reconnect race",
server_name,
exc_info=True,
)
return self.acquire(config)
def cleanup_all(self) -> None:
"""Disconnect all pooled clients and clear manager state."""
while True:
with self._pool_lock:
if self._transitions:
pending = list(self._transitions.values())
else:
cleanup_events = {name: threading.Event() for name in self._pool}
clients = list(self._pool.items())
self._transitions.update(cleanup_events)
self._pool.clear()
self._refcounts.clear()
self._configs.clear()
break
all_resolved = all(event.wait(timeout=_TRANSITION_TIMEOUT) for event in pending)
if not all_resolved:
logger.warning(
"Timed out waiting for pending transitions during cleanup, "
"forcing cleanup of stuck transitions",
)
with self._pool_lock:
for sn, evt in list(self._transitions.items()):
if not evt.is_set():
self._transitions.pop(sn, None)
evt.set()
logger.info("Cleaning up %d pooled MCP connections", len(clients))
for server_name, client in clients:
try:
client.disconnect()
logger.debug("Disconnected MCP server '%s' during cleanup", server_name)
except Exception:
logger.warning(
"Error disconnecting MCP server '%s' during cleanup",
server_name,
exc_info=True,
)
with self._pool_lock:
for server_name, event in cleanup_events.items():
current = self._transitions.get(server_name)
if current is event:
self._transitions.pop(server_name, None)
event.set()
+815
View File
@@ -0,0 +1,815 @@
"""MCP Server Registry: local state management for installed MCP servers."""
from __future__ import annotations
import json
import logging
import os
import tempfile
import tomllib
from datetime import UTC, datetime
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from typing import Any, Literal
import httpx
from framework.runner.mcp_client import MCPClient, MCPServerConfig
from framework.runner.mcp_connection_manager import MCPConnectionManager
logger = logging.getLogger(__name__)
DEFAULT_INDEX_URL = (
"https://raw.githubusercontent.com/aden-hive/hive-mcp-registry/main/registry_index.json"
)
DEFAULT_REFRESH_INTERVAL_HOURS = 24
_LAST_FETCHED_FILENAME = "last_fetched"
_LEGACY_LAST_FETCHED_FILENAME = "last_fetched.json"
_DEFAULT_CONFIG = {
"index_url": DEFAULT_INDEX_URL,
"refresh_interval_hours": DEFAULT_REFRESH_INTERVAL_HOURS,
}
class MCPRegistry:
"""Manages local MCP server state in ~/.hive/mcp_registry/."""
def __init__(self, base_path: Path | None = None):
self._base = base_path or Path.home() / ".hive" / "mcp_registry"
self._installed_path = self._base / "installed.json"
self._config_path = self._base / "config.json"
self._cache_dir = self._base / "cache"
# ── Initialization ──────────────────────────────────────────────
def initialize(self) -> None:
"""Create directory structure and default files if missing."""
self._base.mkdir(parents=True, exist_ok=True)
self._cache_dir.mkdir(parents=True, exist_ok=True)
if not self._config_path.exists():
self._write_json(self._config_path, _DEFAULT_CONFIG)
if not self._installed_path.exists():
self._write_json(self._installed_path, {"servers": {}})
# ── Internal I/O ────────────────────────────────────────────────
def _read_installed(self) -> dict:
"""Read installed.json, initializing if needed."""
if not self._installed_path.exists():
self.initialize()
return json.loads(self._installed_path.read_text(encoding="utf-8"))
def _write_installed(self, data: dict) -> None:
"""Write installed.json."""
self._write_json(self._installed_path, data)
def _read_config(self) -> dict:
"""Read config.json."""
if not self._config_path.exists():
self.initialize()
return json.loads(self._config_path.read_text(encoding="utf-8"))
def _read_cached_index(self) -> dict:
"""Read cached registry_index.json."""
index_path = self._cache_dir / "registry_index.json"
if not index_path.exists():
return {"servers": {}}
return json.loads(index_path.read_text(encoding="utf-8"))
def _get_effective_manifest(
self,
name: str,
entry: dict,
cached_index: dict | None = None,
) -> dict:
"""Return the manifest currently in effect for an installed entry."""
manifest = entry.get("manifest", {})
if entry.get("source") != "registry":
return manifest
index = cached_index or self._read_cached_index()
cached_manifest = index.get("servers", {}).get(name)
if cached_manifest is not None:
return cached_manifest
# Fall back to persisted manifest data when the cache is unavailable.
if isinstance(manifest, dict) and manifest:
return manifest
return {}
@staticmethod
def _write_json(path: Path, data: dict) -> None:
"""Write JSON to file atomically (write to temp, fsync, rename)."""
content = json.dumps(data, indent=2) + "\n"
fd, tmp_path = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
try:
with os.fdopen(fd, "w", encoding="utf-8") as f:
f.write(content)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, path)
except BaseException:
try:
os.unlink(tmp_path)
except OSError:
pass
raise
# ── add_local ───────────────────────────────────────────────────
def add_local(
self,
name: str,
transport: str | None = None,
manifest: dict | None = None,
url: str | None = None,
command: str | None = None,
args: list[str] | None = None,
env: dict[str, str] | None = None,
headers: dict[str, str] | None = None,
cwd: str | None = None,
socket_path: str | None = None,
description: str = "",
) -> dict:
"""Register a local/running MCP server.
Can be called with an inline manifest dict, or with individual
transport/url/command params that build a manifest automatically.
"""
data = self._read_installed()
if name in data["servers"]:
raise ValueError(f"Server '{name}' already exists. Use remove first.")
if manifest is not None:
# Inline manifest provided directly
manifest = {**manifest, "name": name}
transport_config = manifest.get("transport", {})
transport = transport or transport_config.get("default", "stdio")
if "transport" not in manifest:
manifest["transport"] = {"supported": [transport], "default": transport}
else:
# Build manifest from individual params
if not transport:
raise ValueError("transport is required when manifest is not provided")
manifest = {
"name": name,
"description": description,
"transport": {"supported": [transport], "default": transport},
}
match transport:
case "http":
if not url:
raise ValueError("url is required for http transport")
manifest["http"] = {"url": url, "headers": headers or {}}
case "stdio":
if not command:
raise ValueError("command is required for stdio transport")
manifest["stdio"] = {
"command": command,
"args": args or [],
"env": env or {},
"cwd": cwd,
}
case "unix":
if not socket_path:
raise ValueError("socket_path is required for unix transport")
manifest["unix"] = {"socket_path": socket_path}
manifest["http"] = {"url": url or "http://localhost"}
case "sse":
if not url:
raise ValueError("url is required for sse transport")
manifest["sse"] = {"url": url}
case _:
raise ValueError(f"Unsupported transport: {transport}")
entry = self._make_entry(
source="local",
manifest=manifest,
transport=transport,
installed_by="hive mcp add",
)
data["servers"][name] = entry
self._write_installed(data)
logger.info("Registered local MCP server '%s' (%s)", name, transport)
return entry
# ── install ─────────────────────────────────────────────────────
def install(self, name: str, transport: str | None = None, version: str | None = None) -> dict:
"""Install a server from the cached remote registry index."""
data = self._read_installed()
if name in data["servers"]:
raise ValueError(f"Server '{name}' already exists. Remove it first or use update.")
index = self._read_cached_index()
manifest = index.get("servers", {}).get(name)
if manifest is None:
raise ValueError(
f"Server '{name}' not found in registry index. "
"Run 'hive mcp update' to refresh the index."
)
# Validate version if specified
if version is not None:
index_version = manifest.get("version")
if index_version is None:
raise ValueError(f"Cannot pin version for '{name}': manifest has no version field.")
if index_version != version:
raise ValueError(
f"Version mismatch for '{name}': requested {version}, "
f"index has {index_version}. "
"Run 'hive mcp update' to refresh the index."
)
transport_config = manifest.get("transport", {})
supported = transport_config.get("supported", [])
if transport is not None:
if supported and transport not in supported:
raise ValueError(
f"Transport '{transport}' not supported by '{name}'. Supported: {supported}"
)
resolved_transport = transport
else:
resolved_transport = transport_config.get("default", "stdio")
entry = self._make_entry(
source="registry",
manifest=self._make_registry_manifest_snapshot(name, manifest),
transport=resolved_transport,
installed_by="hive mcp install",
pinned=version is not None,
auto_update=version is None,
resolved_package_version=manifest.get("version"),
)
data["servers"][name] = entry
self._write_installed(data)
logger.info(
"Installed MCP server '%s' v%s from registry",
name,
entry["manifest_version"],
)
return entry
# ── remove / enable / disable ───────────────────────────────────
def remove(self, name: str) -> None:
"""Remove a server from the registry."""
data = self._read_installed()
if name not in data["servers"]:
raise ValueError(f"Server '{name}' is not installed.")
del data["servers"][name]
self._write_installed(data)
logger.info("Removed MCP server '%s'", name)
def enable(self, name: str) -> None:
"""Enable a disabled server."""
self._set_enabled(name, enabled=True)
def disable(self, name: str) -> None:
"""Disable a server without removing it."""
self._set_enabled(name, enabled=False)
def _set_enabled(self, name: str, *, enabled: bool) -> None:
data = self._read_installed()
if name not in data["servers"]:
raise ValueError(f"Server '{name}' is not installed.")
data["servers"][name]["enabled"] = enabled
self._write_installed(data)
logger.info("%s MCP server '%s'", "Enabled" if enabled else "Disabled", name)
# ── list / get ──────────────────────────────────────────────────
def list_installed(self) -> list[dict]:
"""Return all installed servers as a list of dicts with name included."""
data = self._read_installed()
return [{"name": name, **entry} for name, entry in data["servers"].items()]
def get_server(self, name: str) -> dict | None:
"""Get a single installed server entry by name, or None if not found."""
data = self._read_installed()
entry = data["servers"].get(name)
if entry is None:
return None
return {"name": name, **entry}
def list_available(self) -> list[dict]:
"""List all servers from cached remote index."""
index = self._read_cached_index()
return [{"name": name, **m} for name, m in index.get("servers", {}).items()]
# ── set_override ────────────────────────────────────────────────
def set_override(
self,
name: str,
key: str,
value: str,
override_type: Literal["env", "headers"] = "env",
) -> None:
"""Set an env or header override for a server."""
data = self._read_installed()
if name not in data["servers"]:
raise ValueError(f"Server '{name}' is not installed.")
if override_type not in ("env", "headers"):
raise ValueError(f"Invalid override type: {override_type}")
data["servers"][name]["overrides"][override_type][key] = value
self._write_installed(data)
logger.info("Set %s override %s for MCP server '%s'", override_type, key, name)
# ── search ──────────────────────────────────────────────────────
def search(self, query: str) -> list[dict]:
"""Search registry index by name, tag, description, or tool name."""
query_lower = query.lower()
index = self._read_cached_index()
matches = []
for name, manifest in index.get("servers", {}).items():
if self._matches_query(name, manifest, query_lower):
matches.append({"name": name, **manifest})
return matches
@staticmethod
def _matches_query(name: str, manifest: dict, query: str) -> bool:
"""Check if a manifest matches a search query."""
if query in name.lower():
return True
description = manifest.get("description", "")
if query in description.lower():
return True
for tag in manifest.get("tags", []):
if query in tag.lower():
return True
for tool in manifest.get("tools", []):
tool_name = tool.get("name", "") if isinstance(tool, dict) else str(tool)
if query in tool_name.lower():
return True
return False
# ── update_index ────────────────────────────────────────────────
def is_index_stale(self) -> bool:
"""Check if the cached registry index needs refreshing."""
last_fetched_path = self._cache_dir / _LAST_FETCHED_FILENAME
legacy_path = self._cache_dir / _LEGACY_LAST_FETCHED_FILENAME
if not last_fetched_path.exists() and not legacy_path.exists():
return True
try:
path = last_fetched_path if last_fetched_path.exists() else legacy_path
data = json.loads(path.read_text(encoding="utf-8"))
last_fetched = datetime.fromisoformat(data["timestamp"])
config = self._read_config()
interval_hours = config.get("refresh_interval_hours", DEFAULT_REFRESH_INTERVAL_HOURS)
age_hours = (datetime.now(UTC) - last_fetched).total_seconds() / 3600
return age_hours >= interval_hours
except (KeyError, ValueError, OSError):
return True
def update_index(self) -> int:
"""Fetch the latest registry index from remote and cache it.
Returns the number of servers in the index.
"""
config = self._read_config()
url = config.get("index_url", DEFAULT_INDEX_URL)
response = httpx.get(url, timeout=10.0)
response.raise_for_status()
index = response.json()
self._write_json(self._cache_dir / "registry_index.json", index)
# Write last_fetched atomically too
self._write_json(
self._cache_dir / _LAST_FETCHED_FILENAME,
{"timestamp": datetime.now(UTC).isoformat()},
)
server_count = len(index.get("servers", {}))
logger.info("Updated registry index: %d servers available", server_count)
return server_count
# ── load_agent_selection ────────────────────────────────────────
def load_agent_selection(self, agent_path: Path) -> list[dict[str, Any]]:
"""Load mcp_registry.json from an agent directory and resolve servers.
Returns list of plain dicts compatible with ToolRegistry.register_mcp_server().
"""
registry_json_path = agent_path / "mcp_registry.json"
if not registry_json_path.exists():
return []
selection = json.loads(registry_json_path.read_text(encoding="utf-8"))
# Validate types at the JSON boundary. Bad fields are dropped with a
# warning so the agent still starts (graceful degradation).
expected_types: dict[str, type] = {
"include": list,
"tags": list,
"exclude": list,
"profile": str,
"max_tools": int,
"versions": dict,
}
validated: dict[str, Any] = {}
for field, expected in expected_types.items():
value = selection.get(field)
if value is None:
continue
if not isinstance(value, expected):
logger.warning(
"mcp_registry.json: '%s' must be %s, got %s; ignoring",
field,
expected.__name__,
type(value).__name__,
)
continue
validated[field] = value
configs = self.resolve_for_agent(
include=validated.get("include"),
tags=validated.get("tags"),
exclude=validated.get("exclude"),
profile=validated.get("profile"),
max_tools=validated.get("max_tools"),
versions=validated.get("versions"),
)
return [self._server_config_to_dict(c) for c in configs]
# ── resolve_for_agent ───────────────────────────────────────────
def resolve_for_agent(
self,
include: list[str] | None = None,
tags: list[str] | None = None,
exclude: list[str] | None = None,
profile: str | None = None,
max_tools: int | None = None,
versions: dict[str, str] | None = None,
) -> list[MCPServerConfig]:
"""Resolve installed servers matching agent selection criteria.
Selection precedence per PRD section 7.2:
1. profile expands to server names (union with include + tags)
2. include adds explicit servers
3. tags adds servers whose tags overlap
4. exclude removes (always wins)
5. Load order: include-order first, then alphabetical for tag/profile matches
Returns list of MCPServerConfig objects ready for ToolRegistry.
"""
data = self._read_installed()
servers = data.get("servers", {})
cached_index = self._read_cached_index()
exclude_set = set(exclude or [])
# Phase 1: collect profile-matched servers (alphabetical)
profile_matched: list[str] = []
if profile:
for name, entry in sorted(servers.items()):
if name in exclude_set:
continue
if profile == "all":
profile_matched.append(name)
else:
manifest = self._get_effective_manifest(name, entry, cached_index)
profiles = manifest.get("hive", {}).get("profiles", [])
if profile in profiles:
profile_matched.append(name)
# Phase 2: collect tag-matched servers (alphabetical)
tag_matched: list[str] = []
if tags:
tag_set = set(tags)
for name, entry in sorted(servers.items()):
if name in exclude_set:
continue
manifest = self._get_effective_manifest(name, entry, cached_index)
server_tags = set(manifest.get("tags", []))
if tag_set & server_tags:
tag_matched.append(name)
# Phase 3: build final ordered list
# include-order first, then alphabetical for profile/tag matches
selected: list[str] = []
seen: set[str] = set()
for name in include or []:
if name not in seen and name not in exclude_set:
selected.append(name)
seen.add(name)
for name in profile_matched:
if name not in seen:
selected.append(name)
seen.add(name)
for name in tag_matched:
if name not in seen:
selected.append(name)
seen.add(name)
# Build configs, tracking aggregate tool count for max_tools cap (FR-56)
configs: list[MCPServerConfig] = []
total_tools = 0
for name in selected:
entry = servers.get(name)
if entry is None:
logger.warning(
"Server '%s' requested but not installed. Run: hive mcp install %s",
name,
name,
)
continue
if not entry.get("enabled", True):
continue
manifest = self._get_effective_manifest(name, entry, cached_index)
# Check version pin (VC-6)
if versions and name in versions:
installed_version = entry.get("manifest_version", "0.0.0")
pinned_version = versions[name]
if installed_version != pinned_version:
logger.warning(
"Server '%s' version mismatch: installed=%s, pinned=%s. "
"Run: hive mcp update %s",
name,
installed_version,
pinned_version,
name,
)
continue
# Check tool count cap before adding (FR-56)
manifest_tools = manifest.get("tools", [])
server_tool_count = len(manifest_tools)
if max_tools is not None and server_tool_count == 0:
logger.debug(
"Server '%s' has no declared tools in manifest, skipping max_tools check",
name,
)
elif max_tools is not None and total_tools + server_tool_count > max_tools:
logger.info(
"Skipping server '%s' (%d tools): would exceed max_tools=%d",
name,
server_tool_count,
max_tools,
)
continue
config = self._manifest_to_server_config(
name,
manifest,
entry.get("overrides", {}),
transport_override=entry.get("transport"),
)
if config is not None:
configs.append(config)
total_tools += server_tool_count
return configs
def _manifest_to_server_config(
self,
name: str,
manifest: dict,
overrides: dict | None = None,
transport_override: str | None = None,
) -> MCPServerConfig | None:
"""Convert a manifest and overrides to MCPServerConfig."""
overrides = overrides or {}
transport_config = manifest.get("transport", {})
transport = transport_override or transport_config.get("default", "stdio")
description = manifest.get("description", "")
match transport:
case "stdio":
stdio_config = manifest.get("stdio", {})
merged_env = {
**stdio_config.get("env", {}),
**overrides.get("env", {}),
}
return MCPServerConfig(
name=name,
transport="stdio",
command=stdio_config.get("command"),
args=stdio_config.get("args", []),
env=merged_env,
cwd=stdio_config.get("cwd"),
description=description,
)
case "http":
http_config = manifest.get("http", {})
url = http_config.get("url", "")
merged_headers = {
**http_config.get("headers", {}),
**overrides.get("headers", {}),
}
return MCPServerConfig(
name=name,
transport="http",
url=url,
headers=merged_headers,
description=description,
)
case "unix":
unix_config = manifest.get("unix", {})
http_config = manifest.get("http", {})
merged_headers = {
**http_config.get("headers", {}),
**overrides.get("headers", {}),
}
return MCPServerConfig(
name=name,
transport="unix",
socket_path=unix_config.get("socket_path"),
url=http_config.get("url") or "http://localhost",
headers=merged_headers,
description=description,
)
case "sse":
sse_config = manifest.get("sse", {})
merged_headers = {
**sse_config.get("headers", {}),
**overrides.get("headers", {}),
}
return MCPServerConfig(
name=name,
transport="sse",
url=sse_config.get("url", ""),
headers=merged_headers,
description=description,
)
case _:
logger.warning(
"Unsupported transport '%s' for server '%s'",
transport,
name,
)
return None
@staticmethod
def _server_config_to_dict(config: MCPServerConfig) -> dict[str, Any]:
"""Convert MCPServerConfig to plain dict for ToolRegistry.register_mcp_server()."""
return {
"name": config.name,
"transport": config.transport,
"command": config.command,
"args": config.args,
"env": config.env,
"cwd": config.cwd,
"url": config.url,
"headers": config.headers,
"socket_path": config.socket_path,
"description": config.description,
}
# ── run_health_check ────────────────────────────────────────────
def health_check(self, name: str | None = None) -> dict | dict[str, dict]:
"""Check health of installed server(s). Updates telemetry fields.
If name is None, checks all installed servers and returns
a dict mapping server names to their health results.
"""
if name is None:
results = {}
for server in self.list_installed():
results[server["name"]] = self.health_check(server["name"])
return results
data = self._read_installed()
if name not in data["servers"]:
raise ValueError(f"Server '{name}' is not installed.")
entry = data["servers"][name]
manifest = self._get_effective_manifest(name, entry)
config = self._manifest_to_server_config(
name,
manifest,
entry.get("overrides", {}),
transport_override=entry.get("transport"),
)
now = datetime.now(UTC).isoformat()
result: dict[str, Any] = {
"name": name,
"status": "unknown",
"tools": 0,
"error": None,
}
if config is None:
transport = entry.get("transport", "unknown")
result["status"] = "unhealthy"
result["error"] = f"Unsupported transport '{transport}'"
entry["last_health_status"] = "unhealthy"
entry["last_error"] = result["error"]
entry["last_health_check_at"] = now
self._write_installed(data)
return result
manager = MCPConnectionManager.get_instance()
try:
if manager.has_connection(name):
is_healthy = manager.health_check(name)
if not is_healthy:
raise RuntimeError("Shared MCP connection health check failed")
pooled_client = manager.acquire(config)
try:
tools = pooled_client.list_tools()
finally:
manager.release(name)
else:
with MCPClient(config) as client:
tools = client.list_tools()
result["status"] = "healthy"
result["tools"] = len(tools)
entry["last_health_status"] = "healthy"
entry["last_error"] = None
entry["last_validated_with_hive_version"] = self._get_hive_version()
except Exception as exc:
result["status"] = "unhealthy"
result["error"] = str(exc)
entry["last_health_status"] = "unhealthy"
entry["last_error"] = str(exc)
entry["last_health_check_at"] = now
self._write_installed(data)
return result
def run_health_check(self, name: str | None = None) -> dict | dict[str, dict]:
"""Backward-compatible wrapper for the public health_check API."""
return self.health_check(name)
@staticmethod
def _get_hive_version() -> str:
"""Get the current Hive version."""
try:
return version("framework")
except PackageNotFoundError:
project_toml = Path(__file__).resolve().parents[2] / "pyproject.toml"
if not project_toml.exists():
return "unknown"
try:
with project_toml.open("rb") as f:
data = tomllib.load(f)
return data.get("project", {}).get("version", "unknown")
except (tomllib.TOMLDecodeError, OSError):
return "unknown"
# ── helpers ──────────────────────────────────────────────────────
@staticmethod
def _make_entry(
*,
source: str,
manifest: dict,
transport: str,
installed_by: str,
pinned: bool = False,
auto_update: bool = False,
resolved_package_version: str | None = None,
) -> dict:
"""Build a standard installed server entry."""
now = datetime.now(UTC).isoformat()
return {
"source": source,
"manifest_version": manifest.get("version", "0.0.0"),
"manifest": manifest,
"installed_at": now,
"installed_by": installed_by,
"transport": transport,
"enabled": True,
"pinned": pinned,
"auto_update": auto_update,
"resolved_package_version": resolved_package_version,
"overrides": {"env": {}, "headers": {}},
"last_health_check_at": None,
"last_health_status": None,
"last_error": None,
"last_used_at": None,
"last_validated_with_hive_version": None,
}
@staticmethod
def _make_registry_manifest_snapshot(name: str, manifest: dict) -> dict[str, Any]:
"""Persist a full manifest snapshot for registry-installed servers."""
manifest_snapshot = dict(manifest)
manifest_snapshot["name"] = name
return manifest_snapshot
+4 -1
View File
@@ -1,6 +1,6 @@
"""Pre-load validation for agent graphs.
Runs structural and credential checks before MCP servers are spawned.
Runs structural, credential, and skill-trust checks before MCP servers are spawned.
Fails fast with actionable error messages.
"""
@@ -169,6 +169,9 @@ def run_preload_validation(
1. Graph structure (includes GCU subagent-only checks) non-recoverable
2. Credentials potentially recoverable via interactive setup
Skill discovery and trust gating (AS-13) happen later in runner._setup()
so they have access to agent-level skill configuration.
Raises PreloadValidationError for structural issues.
Raises CredentialError for credential issues.
"""
+398 -10
View File
@@ -552,6 +552,319 @@ def get_kimi_code_token() -> str | None:
return None
# ---------------------------------------------------------------------------
# Antigravity subscription token helpers
# ---------------------------------------------------------------------------
# Antigravity IDE (native macOS/Linux app) stores OAuth tokens in its
# VSCode-style SQLite state database under the key
# "antigravityUnifiedStateSync.oauthToken" as a base64-encoded protobuf blob.
ANTIGRAVITY_IDE_STATE_DB = (
Path.home()
/ "Library"
/ "Application Support"
/ "Antigravity"
/ "User"
/ "globalStorage"
/ "state.vscdb"
)
# Linux fallback for the IDE state DB
ANTIGRAVITY_IDE_STATE_DB_LINUX = (
Path.home() / ".config" / "Antigravity" / "User" / "globalStorage" / "state.vscdb"
)
# Antigravity credentials stored by native OAuth implementation
ANTIGRAVITY_AUTH_FILE = Path.home() / ".hive" / "antigravity-accounts.json"
ANTIGRAVITY_OAUTH_TOKEN_URL = "https://oauth2.googleapis.com/token"
_ANTIGRAVITY_TOKEN_LIFETIME_SECS = 3600 # Google access tokens expire in 1 hour
_ANTIGRAVITY_IDE_STATE_DB_KEY = "antigravityUnifiedStateSync.oauthToken"
def _read_antigravity_ide_credentials() -> dict | None:
"""Read credentials from the Antigravity IDE's SQLite state database.
The Antigravity desktop IDE (VSCode-based) stores its OAuth token as a
base64-encoded protobuf blob in a SQLite database. The access token is
a standard Google OAuth ``ya29.*`` bearer token.
Returns:
Dict with ``accessToken`` and optionally ``refreshToken`` keys,
plus ``_source: "ide"`` to skip file-based save on refresh.
Returns None if the database is absent or the key is not found.
"""
import re
import sqlite3
for db_path in (ANTIGRAVITY_IDE_STATE_DB, ANTIGRAVITY_IDE_STATE_DB_LINUX):
if not db_path.exists():
continue
try:
con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
try:
row = con.execute(
"SELECT value FROM ItemTable WHERE key = ?",
(_ANTIGRAVITY_IDE_STATE_DB_KEY,),
).fetchone()
finally:
con.close()
if not row:
continue
import base64
blob = base64.b64decode(row[0])
# The protobuf blob contains the access token (ya29.*) and
# refresh token (1//*) as length-prefixed UTF-8 strings.
# Decode the inner base64 layer and extract with regex.
inner_b64_candidates = re.findall(rb"[A-Za-z0-9+/=_\-]{40,}", blob)
access_token: str | None = None
refresh_token: str | None = None
for candidate in inner_b64_candidates:
try:
padded = candidate + b"=" * (-len(candidate) % 4)
inner = base64.urlsafe_b64decode(padded)
except Exception:
continue
if not access_token:
m = re.search(rb"ya29\.[A-Za-z0-9_\-\.]+", inner)
if m:
access_token = m.group(0).decode("ascii")
if not refresh_token:
m = re.search(rb"1//[A-Za-z0-9_\-\.]+", inner)
if m:
refresh_token = m.group(0).decode("ascii")
if access_token and refresh_token:
break
if access_token:
return {
"accounts": [
{
"accessToken": access_token,
"refreshToken": refresh_token or "",
}
],
"_source": "ide",
"_db_path": str(db_path),
}
except Exception as exc:
logger.debug("Failed to read Antigravity IDE state DB: %s", exc)
continue
return None
def _read_antigravity_credentials() -> dict | None:
"""Read Antigravity auth data from all supported credential sources.
Checks in order:
1. Antigravity IDE SQLite state database (native macOS/Linux app)
2. Native OAuth credentials file (~/.hive/antigravity-accounts.json)
Returns:
Auth data dict with an ``accounts`` list on success, None otherwise.
"""
# 1. Native Antigravity IDE (primary on macOS)
ide_creds = _read_antigravity_ide_credentials()
if ide_creds:
return ide_creds
# 2. Native OAuth credentials file
if ANTIGRAVITY_AUTH_FILE.exists():
try:
with open(ANTIGRAVITY_AUTH_FILE, encoding="utf-8") as f:
data = json.load(f)
accounts = data.get("accounts", [])
if accounts and isinstance(accounts[0], dict):
return data
except (json.JSONDecodeError, OSError):
pass
return None
def _is_antigravity_token_expired(auth_data: dict) -> bool:
"""Check whether the Antigravity access token is expired or near expiry.
For IDE-sourced credentials: uses the state DB's mtime as last_refresh
since the IDE keeps the DB fresh while it's running.
For JSON-sourced credentials: uses the ``last_refresh`` field or file mtime.
"""
import time
from datetime import datetime
now = time.time()
if auth_data.get("_source") == "ide":
# The IDE refreshes tokens automatically while running.
# Use the DB file's mtime as a proxy for when the token was last updated.
try:
db_path = Path(auth_data.get("_db_path", str(ANTIGRAVITY_IDE_STATE_DB)))
last_refresh: float = db_path.stat().st_mtime
except OSError:
return True
expires_at = last_refresh + _ANTIGRAVITY_TOKEN_LIFETIME_SECS
return now >= (expires_at - _TOKEN_REFRESH_BUFFER_SECS)
last_refresh_val: float | str | None = auth_data.get("last_refresh")
if last_refresh_val is None:
try:
last_refresh_val = ANTIGRAVITY_AUTH_FILE.stat().st_mtime
except OSError:
return True
elif isinstance(last_refresh_val, str):
try:
last_refresh_val = datetime.fromisoformat(
last_refresh_val.replace("Z", "+00:00")
).timestamp()
except (ValueError, TypeError):
return True
expires_at = float(last_refresh_val) + _ANTIGRAVITY_TOKEN_LIFETIME_SECS
return now >= (expires_at - _TOKEN_REFRESH_BUFFER_SECS)
def _refresh_antigravity_token(refresh_token: str) -> dict | None:
"""Refresh the Antigravity access token via Google OAuth.
POSTs form-encoded ``grant_type=refresh_token`` to the Google token
endpoint using Antigravity's public OAuth client ID.
Returns:
Parsed response dict (containing ``access_token``) on success,
None on any error.
"""
import urllib.error
import urllib.parse
import urllib.request
from framework.config import get_antigravity_client_id, get_antigravity_client_secret
client_id = get_antigravity_client_id()
client_secret = get_antigravity_client_secret()
params: dict = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": client_id,
}
if client_secret:
params["client_secret"] = client_secret
data = urllib.parse.urlencode(params).encode("utf-8")
req = urllib.request.Request(
ANTIGRAVITY_OAUTH_TOKEN_URL,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
method="POST",
)
try:
with urllib.request.urlopen(req, timeout=15) as resp: # noqa: S310
return json.loads(resp.read())
except (urllib.error.URLError, json.JSONDecodeError, TimeoutError, OSError) as exc:
logger.debug("Antigravity token refresh failed: %s", exc)
return None
def _save_refreshed_antigravity_credentials(auth_data: dict, token_data: dict) -> None:
"""Write refreshed tokens back to the Antigravity JSON credentials file.
Skipped for IDE-sourced credentials (the IDE manages its own DB).
Updates ``accounts[0].accessToken`` (and ``refreshToken`` if present),
then persists ``last_refresh`` as an ISO-8601 UTC string.
"""
from datetime import datetime
# IDE manages its own state — we do not write back to its SQLite DB
if auth_data.get("_source") == "ide":
return
try:
accounts = auth_data.get("accounts", [])
if not accounts:
return
account = accounts[0]
account["accessToken"] = token_data["access_token"]
if "refresh_token" in token_data:
account["refreshToken"] = token_data["refresh_token"]
auth_data["accounts"] = accounts
auth_data["last_refresh"] = datetime.now(UTC).isoformat()
ANTIGRAVITY_AUTH_FILE.parent.mkdir(parents=True, exist_ok=True)
fd = os.open(ANTIGRAVITY_AUTH_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, "w", encoding="utf-8") as f:
json.dump(auth_data, f, indent=2)
logger.debug("Antigravity credentials refreshed and saved")
except (OSError, KeyError) as exc:
logger.debug("Failed to save refreshed Antigravity credentials: %s", exc)
def get_antigravity_token() -> str | None:
"""Get the OAuth access token from an Antigravity subscription.
Credential sources checked in order:
1. Antigravity IDE SQLite state DB (native app, macOS/Linux)
2. antigravity-auth CLI JSON file
For IDE credentials the token is read directly (the IDE refreshes it
automatically while running). For JSON credentials an automatic OAuth
refresh is attempted when the token is near expiry.
Returns:
The ``ya29.*`` Google OAuth access token, or None if unavailable.
"""
auth_data = _read_antigravity_credentials()
if not auth_data:
return None
accounts = auth_data.get("accounts", [])
if not accounts:
return None
account = accounts[0]
access_token = account.get("accessToken")
if not access_token:
return None
if not _is_antigravity_token_expired(auth_data):
return access_token
# Token is expired or near expiry — attempt a refresh
refresh_token = account.get("refreshToken")
if not refresh_token:
logger.warning(
"Antigravity token expired and no refresh token available. "
"Re-open the Antigravity IDE to refresh, or run 'antigravity-auth accounts add'."
)
return access_token # return stale token; proxy may still accept it briefly
logger.info("Antigravity token expired or near expiry, refreshing...")
token_data = _refresh_antigravity_token(refresh_token)
if token_data and "access_token" in token_data:
_save_refreshed_antigravity_credentials(auth_data, token_data)
return token_data["access_token"]
logger.warning(
"Antigravity token refresh failed. "
"Re-open the Antigravity IDE or run 'antigravity-auth accounts add'."
)
return access_token
def _is_antigravity_proxy_available() -> bool:
"""Return True if antigravity-auth serve is running on localhost:8069."""
import socket
try:
with socket.create_connection(("localhost", 8069), timeout=0.5):
return True
except (OSError, TimeoutError):
return False
@dataclass
class AgentInfo:
"""Information about an exported agent."""
@@ -808,6 +1121,9 @@ class AgentRunner:
if mcp_config_path.exists():
self._load_mcp_servers_from_config(mcp_config_path)
# Auto-discover registry-selected MCP servers from mcp_registry.json
self._load_registry_mcp_servers(agent_path)
@staticmethod
def _import_agent_module(agent_path: Path):
"""Import an agent package from its directory path.
@@ -1111,6 +1427,45 @@ class AgentRunner:
"""Load and register MCP servers from a configuration file."""
self._tool_registry.load_mcp_config(config_path)
def _load_registry_mcp_servers(self, agent_path: Path) -> None:
"""Load and register MCP servers selected via ``mcp_registry.json``."""
from framework.runner.mcp_registry import MCPRegistry
try:
registry = MCPRegistry()
registry.initialize()
server_configs = registry.load_agent_selection(agent_path)
except Exception as exc:
logger.warning(
"Failed to load MCP registry servers for '%s': %s",
agent_path.name,
exc,
)
return
if not server_configs:
return
results = self._tool_registry.load_registry_servers(server_configs)
loaded = [result for result in results if result["status"] == "loaded"]
skipped = [result for result in results if result["status"] != "loaded"]
logger.info(
"Loaded %d/%d MCP registry server(s) for agent '%s'",
len(loaded),
len(results),
agent_path.name,
)
if skipped:
logger.info(
"Skipped MCP registry servers for agent '%s': %s",
agent_path.name,
[
{"server": result["server"], "reason": result["skipped_reason"]}
for result in skipped
],
)
def set_approval_callback(self, callback: Callable) -> None:
"""
Set a callback for human-in-the-loop approval during execution.
@@ -1141,7 +1496,10 @@ class AgentRunner:
# Create LLM provider
# Uses LiteLLM which auto-detects the provider from model name
if self.mock_mode:
# Skip if already injected (e.g. worker agents with a pre-built LLM)
if self._llm is not None:
pass # LLM already configured externally
elif self.mock_mode:
# Use mock LLM for testing without real API calls
from framework.llm.mock import MockLLMProvider
@@ -1155,6 +1513,7 @@ class AgentRunner:
use_claude_code = llm_config.get("use_claude_code_subscription", False)
use_codex = llm_config.get("use_codex_subscription", False)
use_kimi_code = llm_config.get("use_kimi_code_subscription", False)
use_antigravity = llm_config.get("use_antigravity_subscription", False)
api_base = llm_config.get("api_base")
api_key = None
@@ -1162,20 +1521,28 @@ class AgentRunner:
# Get OAuth token from Claude Code subscription
api_key = get_claude_code_token()
if not api_key:
print("Warning: Claude Code subscription configured but no token found.")
print("Run 'claude' to authenticate, then try again.")
logger.warning(
"Claude Code subscription configured but no token found. "
"Run 'claude' to authenticate, then try again."
)
elif use_codex:
# Get OAuth token from Codex subscription
api_key = get_codex_token()
if not api_key:
print("Warning: Codex subscription configured but no token found.")
print("Run 'codex' to authenticate, then try again.")
logger.warning(
"Codex subscription configured but no token found. "
"Run 'codex' to authenticate, then try again."
)
elif use_kimi_code:
# Get API key from Kimi Code CLI config (~/.kimi/config.toml)
api_key = get_kimi_code_token()
if not api_key:
print("Warning: Kimi Code subscription configured but no key found.")
print("Run 'kimi /login' to authenticate, then try again.")
logger.warning(
"Kimi Code subscription configured but no key found. "
"Run 'kimi /login' to authenticate, then try again."
)
elif use_antigravity:
pass # AntigravityProvider handles credentials internally
if api_key and use_claude_code:
# Use litellm's built-in Anthropic OAuth support.
@@ -1214,6 +1581,19 @@ class AgentRunner:
api_key=api_key,
api_base=api_base,
)
elif use_antigravity:
# Direct OAuth to Google's internal Cloud Code Assist gateway.
# No local proxy required — AntigravityProvider handles token
# refresh and Gemini-format request/response conversion natively.
from framework.llm.antigravity import AntigravityProvider # noqa: PLC0415
provider = AntigravityProvider(model=self.model)
if not provider.has_credentials():
print(
"Warning: Antigravity credentials not found. "
"Run: uv run python core/antigravity_auth.py auth account add"
)
self._llm = provider
else:
# Local models (e.g. Ollama) don't need an API key
if self._is_local_model(self.model):
@@ -1245,8 +1625,12 @@ class AgentRunner:
if api_key_env:
os.environ[api_key_env] = api_key
elif api_key_env:
print(f"Warning: {api_key_env} not set. LLM calls will fail.")
print(f"Set it with: export {api_key_env}=your-api-key")
logger.warning(
"%s not set. LLM calls will fail. "
"Set it with: export %s=your-api-key",
api_key_env,
api_key_env,
)
# Fail fast if the agent needs an LLM but none was configured
if self._llm is None:
@@ -1340,7 +1724,7 @@ class AgentRunner:
except Exception:
pass # Best-effort — agent works without account info
# Skill configuration — the runtime handles discovery, loading, and
# Skill configuration — the runtime handles discovery, loading, trust-gating and
# prompt rasterization. The runner just builds the config.
from framework.skills.config import SkillsConfig
from framework.skills.manager import SkillsManagerConfig
@@ -1351,6 +1735,7 @@ class AgentRunner:
skills=getattr(self, "_agent_skills", None),
),
project_root=self.agent_path,
interactive=self._interactive,
)
self._setup_agent_runtime(
@@ -1462,6 +1847,9 @@ class AgentRunner:
accounts_data: list[dict] | None = None,
tool_provider_map: dict[str, str] | None = None,
event_bus=None,
skills_catalog_prompt: str = "",
protocols_prompt: str = "",
skill_dirs: list[str] | None = None,
skills_manager_config=None,
) -> None:
"""Set up multi-entry-point execution using AgentRuntime."""
+148 -31
View File
@@ -16,6 +16,8 @@ from framework.llm.provider import Tool, ToolResult, ToolUse
logger = logging.getLogger(__name__)
_INPUT_LOG_MAX_LEN = 500
# Per-execution context overrides. Each asyncio task (and thus each
# concurrent graph execution) gets its own copy, so there are no races
# when multiple ExecutionStreams run in parallel.
@@ -54,6 +56,8 @@ class ToolRegistry:
def __init__(self):
self._tools: dict[str, RegisteredTool] = {}
self._mcp_clients: list[Any] = [] # List of MCPClient instances
self._mcp_client_servers: dict[int, str] = {} # client id -> server name
self._mcp_managed_clients: set[int] = set() # client ids acquired from the manager
self._session_context: dict[str, Any] = {} # Auto-injected context for tools
self._provider_index: dict[str, set[str]] = {} # provider -> tool names
# MCP resync tracking
@@ -243,6 +247,13 @@ class ToolRegistry:
def _wrap_result(tool_use_id: str, result: Any) -> ToolResult:
if isinstance(result, ToolResult):
return result
# MCP client returns dict with _images when image content is present
if isinstance(result, dict) and "_images" in result:
return ToolResult(
tool_use_id=tool_use_id,
content=result.get("_text", ""),
image_content=result["_images"],
)
return ToolResult(
tool_use_id=tool_use_id,
content=json.dumps(result) if not isinstance(result, str) else result,
@@ -269,6 +280,17 @@ class ToolRegistry:
r = await result
return _wrap_result(tool_use.id, r)
except Exception as exc:
inputs_str = json.dumps(tool_use.input, default=str)
if len(inputs_str) > _INPUT_LOG_MAX_LEN:
inputs_str = inputs_str[:_INPUT_LOG_MAX_LEN] + "...(truncated)"
logger.error(
"Async tool '%s' failed (tool_use_id=%s): %s\nInputs: %s",
tool_use.name,
tool_use.id,
exc,
inputs_str,
exc_info=True,
)
return ToolResult(
tool_use_id=tool_use.id,
content=json.dumps({"error": str(exc)}),
@@ -279,6 +301,17 @@ class ToolRegistry:
return _wrap_result(tool_use.id, result)
except Exception as e:
inputs_str = json.dumps(tool_use.input, default=str)
if len(inputs_str) > _INPUT_LOG_MAX_LEN:
inputs_str = inputs_str[:_INPUT_LOG_MAX_LEN] + "...(truncated)"
logger.error(
"Tool '%s' execution failed for tool_use_id=%s: %s\nInputs: %s",
tool_use.name,
tool_use.id,
e,
inputs_str,
exc_info=True,
)
return ToolResult(
tool_use_id=tool_use.id,
content=json.dumps({"error": str(e)}),
@@ -453,33 +486,85 @@ class ToolRegistry:
# Treat top-level keys as server names
server_list = [{"name": name, **cfg} for name, cfg in config.items()]
for server_config in server_list:
server_config = self._resolve_mcp_server_config(server_config, base_dir)
for _attempt in range(2):
try:
self.register_mcp_server(server_config)
break
except Exception as e:
name = server_config.get("name", "unknown")
if _attempt == 0:
logger.warning(
"MCP server '%s' failed to register, retrying in 2s: %s",
name,
e,
)
import time
time.sleep(2)
else:
logger.warning("MCP server '%s' failed after retry: %s", name, e)
resolved_server_list = [
self._resolve_mcp_server_config(server_config, base_dir)
for server_config in server_list
]
self.load_registry_servers(resolved_server_list, log_summary=False)
# Snapshot credential files and ADEN_API_KEY so we can detect mid-session changes
self._mcp_cred_snapshot = self._snapshot_credentials()
self._mcp_aden_key_snapshot = os.environ.get("ADEN_API_KEY")
def _register_mcp_server_with_retry(
self,
server_config: dict[str, Any],
) -> tuple[bool, int, str | None]:
"""Register a single MCP server with one retry for transient failures."""
name = server_config.get("name", "unknown")
last_error: str | None = None
for attempt in range(2):
try:
count = self.register_mcp_server(server_config)
if count > 0:
return True, count, None
last_error = "registered 0 tools"
except Exception as exc:
last_error = str(exc)
if attempt == 0:
logger.warning(
"MCP server '%s' failed to register, retrying in 2s: %s",
name,
last_error,
)
import time
time.sleep(2)
else:
logger.warning("MCP server '%s' failed after retry: %s", name, last_error)
return False, 0, last_error
def load_registry_servers(
self,
server_list: list[dict[str, Any]],
*,
log_summary: bool = True,
) -> list[dict[str, Any]]:
"""Register resolved registry-selected MCP servers with retry and status tracking."""
results: list[dict[str, Any]] = []
for server_config in server_list:
name = server_config.get("name", "unknown")
success, tools_loaded, error = self._register_mcp_server_with_retry(server_config)
result = {
"server": name,
"status": "loaded" if success else "skipped",
"tools_loaded": tools_loaded,
"skipped_reason": None if success else (error or "unknown error"),
}
results.append(result)
if log_summary:
logger.info(
"MCP registry server resolution",
extra={
"event": "mcp_registry_server_resolution",
"server": result["server"],
"status": result["status"],
"tools_loaded": result["tools_loaded"],
"skipped_reason": result["skipped_reason"],
},
)
return results
def register_mcp_server(
self,
server_config: dict[str, Any],
use_connection_manager: bool = True,
) -> int:
"""
Register an MCP server and discover its tools.
@@ -495,12 +580,14 @@ class ToolRegistry:
- url: Server URL (for http)
- headers: HTTP headers (for http)
- description: Server description (optional)
use_connection_manager: When True, reuse a shared client keyed by server name
Returns:
Number of tools registered from this server
"""
try:
from framework.runner.mcp_client import MCPClient, MCPServerConfig
from framework.runner.mcp_connection_manager import MCPConnectionManager
# Build config object
config = MCPServerConfig(
@@ -512,15 +599,23 @@ class ToolRegistry:
cwd=server_config.get("cwd"),
url=server_config.get("url"),
headers=server_config.get("headers", {}),
socket_path=server_config.get("socket_path"),
description=server_config.get("description", ""),
)
# Create and connect client
client = MCPClient(config)
client.connect()
if use_connection_manager:
client = MCPConnectionManager.get_instance().acquire(config)
else:
client = MCPClient(config)
client.connect()
# Store client for cleanup
self._mcp_clients.append(client)
client_id = id(client)
self._mcp_client_servers[client_id] = config.name
if use_connection_manager:
self._mcp_managed_clients.add(client_id)
# Register each tool
server_name = server_config["name"]
@@ -560,14 +655,25 @@ class ToolRegistry:
}
merged_inputs = {**clean_inputs, **filtered_context}
result = client_ref.call_tool(tool_name, merged_inputs)
# MCP tools return content array, extract the result
# MCP client already extracts content (returns str
# or {"_text": ..., "_images": ...} for image results).
# Handle legacy list format from HTTP transport.
if isinstance(result, list) and len(result) > 0:
if isinstance(result[0], dict) and "text" in result[0]:
return result[0]["text"]
return result[0]
return result
except Exception as e:
logger.error(f"MCP tool '{tool_name}' execution failed: {e}")
inputs_str = json.dumps(inputs, default=str)
if len(inputs_str) > _INPUT_LOG_MAX_LEN:
inputs_str = inputs_str[:_INPUT_LOG_MAX_LEN] + "...(truncated)"
logger.error(
"MCP tool '%s' execution failed: %s\nInputs: %s",
tool_name,
e,
inputs_str,
exc_info=True,
)
return {"error": str(e)}
return executor
@@ -720,12 +826,7 @@ class ToolRegistry:
logger.info("%s — resyncing MCP servers", reason)
# 1. Disconnect existing MCP clients
for client in self._mcp_clients:
try:
client.disconnect()
except Exception as e:
logger.warning(f"Error disconnecting MCP client during resync: {e}")
self._mcp_clients.clear()
self._cleanup_mcp_clients("during resync")
# 2. Remove MCP-registered tools
for name in self._mcp_tool_names:
@@ -740,12 +841,28 @@ class ToolRegistry:
def cleanup(self) -> None:
"""Clean up all MCP client connections."""
self._cleanup_mcp_clients()
def _cleanup_mcp_clients(self, context: str = "") -> None:
"""Disconnect or release all tracked MCP clients for this registry."""
if context:
context = f" {context}"
for client in self._mcp_clients:
client_id = id(client)
server_name = self._mcp_client_servers.get(client_id, client.config.name)
try:
client.disconnect()
if client_id in self._mcp_managed_clients:
from framework.runner.mcp_connection_manager import MCPConnectionManager
MCPConnectionManager.get_instance().release(server_name)
else:
client.disconnect()
except Exception as e:
logger.warning(f"Error disconnecting MCP client: {e}")
logger.warning(f"Error disconnecting MCP client{context}: {e}")
self._mcp_clients.clear()
self._mcp_client_servers.clear()
self._mcp_managed_clients.clear()
def __del__(self):
"""Destructor to ensure cleanup."""
+22 -2
View File
@@ -137,6 +137,7 @@ class AgentRuntime:
# Deprecated — pass skills_manager_config instead.
skills_catalog_prompt: str = "",
protocols_prompt: str = "",
skill_dirs: list[str] | None = None,
):
"""
Initialize agent runtime.
@@ -158,6 +159,9 @@ class AgentRuntime:
event_bus: Optional external EventBus. If provided, the runtime shares
this bus instead of creating its own. Used by SessionManager to
share a single bus between queen, worker, and judge.
skills_catalog_prompt: Available skills catalog for system prompt
protocols_prompt: Default skill operational protocols for system prompt
skill_dirs: Skill base directories for Tier 3 resource access
skills_manager_config: Skill configuration the runtime owns
discovery, loading, and prompt renderation internally.
skills_catalog_prompt: Deprecated. Pre-rendered skills catalog.
@@ -195,6 +199,8 @@ class AgentRuntime:
self._skills_manager = SkillsManager()
self._skills_manager.load()
self.skill_dirs: list[str] = self._skills_manager.allowlisted_dirs
# Primary graph identity
self._graph_id: str = graph_id or "primary"
@@ -341,6 +347,7 @@ class AgentRuntime:
tool_provider_map=self._tool_provider_map,
skills_catalog_prompt=self.skills_catalog_prompt,
protocols_prompt=self.protocols_prompt,
skill_dirs=self.skill_dirs,
)
await stream.start()
self._streams[ep_id] = stream
@@ -977,6 +984,7 @@ class AgentRuntime:
tool_provider_map=self._tool_provider_map,
skills_catalog_prompt=self.skills_catalog_prompt,
protocols_prompt=self.protocols_prompt,
skill_dirs=self.skill_dirs,
)
if self._running:
await stream.start()
@@ -1466,6 +1474,7 @@ class AgentRuntime:
graph_id: str | None = None,
*,
is_client_input: bool = False,
image_content: list[dict[str, Any]] | None = None,
) -> bool:
"""Inject user input into a running client-facing node.
@@ -1478,6 +1487,8 @@ class AgentRuntime:
graph_id: Optional graph to search first (defaults to active graph)
is_client_input: True when the message originates from a real
human user (e.g. /chat endpoint), False for external events.
image_content: Optional list of image content blocks (OpenAI
image_url format) to include alongside the text.
Returns:
True if input was delivered, False if no matching node found
@@ -1489,7 +1500,9 @@ class AgentRuntime:
target = graph_id or self._active_graph_id
if target in self._graphs:
for stream in self._graphs[target].streams.values():
if await stream.inject_input(node_id, content, is_client_input=is_client_input):
if await stream.inject_input(
node_id, content, is_client_input=is_client_input, image_content=image_content
):
return True
# Then search all other graphs
@@ -1497,7 +1510,9 @@ class AgentRuntime:
if gid == target:
continue
for stream in reg.streams.values():
if await stream.inject_input(node_id, content, is_client_input=is_client_input):
if await stream.inject_input(
node_id, content, is_client_input=is_client_input, image_content=image_content
):
return True
return False
@@ -1760,6 +1775,7 @@ def create_agent_runtime(
# Deprecated — pass skills_manager_config instead.
skills_catalog_prompt: str = "",
protocols_prompt: str = "",
skill_dirs: list[str] | None = None,
) -> AgentRuntime:
"""
Create and configure an AgentRuntime with entry points.
@@ -1786,6 +1802,9 @@ def create_agent_runtime(
accounts_data: Raw account data for per-node prompt generation.
tool_provider_map: Tool name to provider name mapping for account routing.
event_bus: Optional external EventBus to share with other components.
skills_catalog_prompt: Available skills catalog for system prompt.
protocols_prompt: Default skill operational protocols for system prompt.
skill_dirs: Skill base directories for Tier 3 resource access.
skills_manager_config: Skill configuration the runtime owns
discovery, loading, and prompt renderation internally.
skills_catalog_prompt: Deprecated. Pre-rendered skills catalog.
@@ -1819,6 +1838,7 @@ def create_agent_runtime(
skills_manager_config=skills_manager_config,
skills_catalog_prompt=skills_catalog_prompt,
protocols_prompt=protocols_prompt,
skill_dirs=skill_dirs,
)
for spec in entry_points:
+3 -2
View File
@@ -117,6 +117,7 @@ class EventType(StrEnum):
# Context management
CONTEXT_COMPACTED = "context_compacted"
CONTEXT_USAGE_UPDATED = "context_usage_updated"
# External triggers
WEBHOOK_RECEIVED = "webhook_received"
@@ -534,8 +535,8 @@ class EventBus:
async with self._semaphore:
try:
await handler(event)
except Exception as e:
logger.error(f"Handler error for {event.type}: {e}")
except Exception:
logger.exception(f"Handler error for {event.type}")
# Run all handlers concurrently
await asyncio.gather(*[run_handler(h) for h in handlers], return_exceptions=True)
+8 -1
View File
@@ -188,6 +188,7 @@ class ExecutionStream:
tool_provider_map: dict[str, str] | None = None,
skills_catalog_prompt: str = "",
protocols_prompt: str = "",
skill_dirs: list[str] | None = None,
):
"""
Initialize execution stream.
@@ -213,6 +214,7 @@ class ExecutionStream:
tool_provider_map: Tool name to provider name mapping for account routing
skills_catalog_prompt: Available skills catalog for system prompt
protocols_prompt: Default skill operational protocols for system prompt
skill_dirs: Skill base directories for Tier 3 resource access
"""
self.stream_id = stream_id
self.entry_spec = entry_spec
@@ -236,6 +238,7 @@ class ExecutionStream:
self._tool_provider_map = tool_provider_map
self._skills_catalog_prompt = skills_catalog_prompt
self._protocols_prompt = protocols_prompt
self._skill_dirs: list[str] = skill_dirs or []
_es_logger = logging.getLogger(__name__)
if protocols_prompt:
@@ -430,6 +433,7 @@ class ExecutionStream:
content: str,
*,
is_client_input: bool = False,
image_content: list[dict[str, Any]] | None = None,
) -> bool:
"""Inject user input into a running client-facing EventLoopNode.
@@ -441,7 +445,9 @@ class ExecutionStream:
for executor in self._active_executors.values():
node = executor.node_registry.get(node_id)
if node is not None and hasattr(node, "inject_event"):
await node.inject_event(content, is_client_input=is_client_input)
await node.inject_event(
content, is_client_input=is_client_input, image_content=image_content
)
return True
return False
@@ -696,6 +702,7 @@ class ExecutionStream:
tool_provider_map=self._tool_provider_map,
skills_catalog_prompt=self._skills_catalog_prompt,
protocols_prompt=self._protocols_prompt,
skill_dirs=self._skill_dirs,
)
# Track executor so inject_input() can reach EventLoopNode instances
self._active_executors[execution_id] = executor
@@ -8,6 +8,7 @@ write. Errors are silently swallowed — this must never break the agent.
import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import IO, Any
@@ -47,6 +48,9 @@ def log_llm_turn(
Never raises.
"""
try:
# Skip logging during test runs to avoid polluting real logs.
if os.environ.get("PYTEST_CURRENT_TEST") or os.environ.get("HIVE_DISABLE_LLM_LOGS"):
return
global _log_file, _log_ready # noqa: PLW0603
if not _log_ready:
_log_file = _open_log()
@@ -62,6 +62,7 @@ async def create_queen(
from framework.agents.queen.nodes.thinking_hook import select_expert_persona
from framework.graph.event_loop_node import HookContext, HookResult
from framework.graph.executor import GraphExecutor
from framework.runner.mcp_registry import MCPRegistry
from framework.runner.tool_registry import ToolRegistry
from framework.runtime.core import Runtime
from framework.runtime.event_bus import AgentEvent, EventType
@@ -86,6 +87,16 @@ async def create_queen(
except Exception:
logger.warning("Queen: MCP config failed to load", exc_info=True)
try:
registry = MCPRegistry()
registry.initialize()
registry_configs = registry.load_agent_selection(queen_pkg_dir)
if registry_configs:
results = queen_registry.load_registry_servers(registry_configs)
logger.info("Queen: loaded MCP registry servers: %s", results)
except Exception:
logger.warning("Queen: MCP registry config failed to load", exc_info=True)
# ---- Phase state --------------------------------------------------
initial_phase = "staging" if worker_identity else "planning"
phase_state = QueenPhaseState(phase=initial_phase, event_bus=session.event_bus)
+1
View File
@@ -37,6 +37,7 @@ DEFAULT_EVENT_TYPES = [
EventType.NODE_RETRY,
EventType.NODE_TOOL_DOOM_LOOP,
EventType.CONTEXT_COMPACTED,
EventType.CONTEXT_USAGE_UPDATED,
EventType.WORKER_LOADED,
EventType.CREDENTIALS_REQUIRED,
EventType.SUBAGENT_REPORT,
+11 -4
View File
@@ -108,7 +108,10 @@ async def handle_chat(request: web.Request) -> web.Response:
The input box is permanently connected to the queen agent.
Worker input is handled separately via /worker-input.
Body: {"message": "hello"}
Body: {"message": "hello", "images": [{"type": "image_url", "image_url": {"url": "data:..."}}]}
The optional ``images`` field accepts a list of OpenAI-format image_url
content blocks. The frontend encodes images as base64 data URIs.
"""
session, err = resolve_session(request)
if err:
@@ -116,15 +119,16 @@ async def handle_chat(request: web.Request) -> web.Response:
body = await request.json()
message = body.get("message", "")
image_content = body.get("images") or None # list[dict] | None
if not message:
if not message and not image_content:
return web.json_response({"error": "message is required"}, status=400)
queen_executor = session.queen_executor
if queen_executor is not None:
node = queen_executor.node_registry.get("queen")
if node is not None and hasattr(node, "inject_event"):
await node.inject_event(message, is_client_input=True)
await node.inject_event(message, is_client_input=True, image_content=image_content)
# Publish to EventBus so the session event log captures user messages
from framework.runtime.event_bus import AgentEvent, EventType
@@ -134,7 +138,10 @@ async def handle_chat(request: web.Request) -> web.Response:
stream_id="queen",
node_id="queen",
execution_id=session.id,
data={"content": message},
data={
"content": message,
"image_count": len(image_content) if image_content else 0,
},
)
)
return web.json_response(
+31 -56
View File
@@ -11,7 +11,6 @@ Session-primary routes:
- GET /api/sessions/{session_id}/entry-points list entry points
- PATCH /api/sessions/{session_id}/triggers/{id} update trigger task
- GET /api/sessions/{session_id}/graphs list graph IDs
- GET /api/sessions/{session_id}/queen-messages queen conversation history
- GET /api/sessions/{session_id}/events/history persisted eventbus log (for replay)
Worker session browsing (persisted execution runs on disk):
@@ -29,6 +28,8 @@ import contextlib
import json
import logging
import shutil
import subprocess
import sys
import time
from pathlib import Path
@@ -52,8 +53,11 @@ def _get_manager(request: web.Request) -> SessionManager:
def _session_to_live_dict(session) -> dict:
"""Serialize a live Session to the session-primary JSON shape."""
from framework.llm.capabilities import supports_image_tool_results
info = session.worker_info
phase_state = getattr(session, "phase_state", None)
queen_model: str = getattr(getattr(session, "runner", None), "model", "") or ""
return {
"session_id": session.id,
"worker_id": session.worker_id,
@@ -69,6 +73,7 @@ def _session_to_live_dict(session) -> dict:
"queen_phase": phase_state.phase
if phase_state
else ("staging" if session.worker_runtime else "planning"),
"queen_supports_images": supports_image_tool_results(queen_model) if queen_model else True,
}
@@ -862,60 +867,6 @@ async def handle_messages(request: web.Request) -> web.Response:
return web.json_response({"messages": all_messages})
async def handle_queen_messages(request: web.Request) -> web.Response:
"""GET /api/sessions/{session_id}/queen-messages — get queen conversation.
Reads directly from disk so it works for both live sessions and cold
(post-server-restart) sessions no live session required.
"""
session_id = request.match_info["session_id"]
queen_dir = Path.home() / ".hive" / "queen" / "session" / session_id
convs_dir = queen_dir / "conversations"
if not convs_dir.exists():
return web.json_response({"messages": [], "session_id": session_id})
all_messages: list[dict] = []
def _read_parts(parts_dir: Path, node_id: str) -> None:
if not parts_dir.exists():
return
for part_file in sorted(parts_dir.iterdir()):
if part_file.suffix != ".json":
continue
try:
part = json.loads(part_file.read_text(encoding="utf-8"))
part["_node_id"] = node_id
# Use file mtime as created_at so frontend can order
# queen and worker messages chronologically.
part.setdefault("created_at", part_file.stat().st_mtime)
all_messages.append(part)
except (json.JSONDecodeError, OSError):
continue
# Flat layout: conversations/parts/*.json
_read_parts(convs_dir / "parts", "queen")
# Node-based layout: conversations/<node_id>/parts/*.json
for node_dir in convs_dir.iterdir():
if not node_dir.is_dir() or node_dir.name == "parts":
continue
_read_parts(node_dir / "parts", node_dir.name)
all_messages.sort(key=lambda m: m.get("created_at", m.get("seq", 0)))
# Filter to client-facing messages only
all_messages = [
m
for m in all_messages
if not m.get("is_transition_marker")
and m["role"] != "tool"
and not (m["role"] == "assistant" and m.get("tool_calls"))
]
return web.json_response({"messages": all_messages, "session_id": session_id})
async def handle_session_events_history(request: web.Request) -> web.Response:
"""GET /api/sessions/{session_id}/events/history — persisted eventbus log.
@@ -1033,6 +984,29 @@ async def handle_discover(request: web.Request) -> web.Response:
return web.json_response(result)
async def handle_reveal_session_folder(request: web.Request) -> web.Response:
"""POST /api/sessions/{session_id}/reveal — open session data folder in the OS file manager."""
manager: SessionManager = request.app["manager"]
session_id = request.match_info["session_id"]
session = manager.get_session(session_id)
storage_session_id = (session.queen_resume_from or session.id) if session else session_id
folder = Path.home() / ".hive" / "queen" / "session" / storage_session_id
folder.mkdir(parents=True, exist_ok=True)
try:
if sys.platform == "darwin":
subprocess.Popen(["open", str(folder)])
elif sys.platform == "win32":
subprocess.Popen(["explorer", str(folder)])
else:
subprocess.Popen(["xdg-open", str(folder)])
except Exception as exc:
return web.json_response({"error": str(exc)}, status=500)
return web.json_response({"path": str(folder)})
# ------------------------------------------------------------------
# Route registration
# ------------------------------------------------------------------
@@ -1057,13 +1031,14 @@ def register_routes(app: web.Application) -> None:
app.router.add_delete("/api/sessions/{session_id}/worker", handle_unload_worker)
# Session info
app.router.add_post("/api/sessions/{session_id}/reveal", handle_reveal_session_folder)
app.router.add_get("/api/sessions/{session_id}/stats", handle_session_stats)
app.router.add_get("/api/sessions/{session_id}/entry-points", handle_session_entry_points)
app.router.add_patch(
"/api/sessions/{session_id}/triggers/{trigger_id}", handle_update_trigger_task
)
app.router.add_get("/api/sessions/{session_id}/graphs", handle_session_graphs)
app.router.add_get("/api/sessions/{session_id}/queen-messages", handle_queen_messages)
app.router.add_get("/api/sessions/{session_id}/events/history", handle_session_events_history)
# Worker session browsing (session-primary)
+74 -16
View File
@@ -96,8 +96,7 @@ class SessionManager:
Internal helper use create_session() or create_session_with_worker().
"""
from framework.config import RuntimeConfig
from framework.llm.litellm import LiteLLMProvider
from framework.config import RuntimeConfig, get_hive_config
from framework.runtime.event_bus import EventBus
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -111,12 +110,20 @@ class SessionManager:
rc = RuntimeConfig(model=model or self._model or RuntimeConfig().model)
# Session owns these — shared with queen and worker
llm = LiteLLMProvider(
model=rc.model,
api_key=rc.api_key,
api_base=rc.api_base,
**rc.extra_kwargs,
)
llm_config = get_hive_config().get("llm", {})
if llm_config.get("use_antigravity_subscription"):
from framework.llm.antigravity import AntigravityProvider
llm = AntigravityProvider(model=rc.model)
else:
from framework.llm.litellm import LiteLLMProvider
llm = LiteLLMProvider(
model=rc.model,
api_key=rc.api_key,
api_base=rc.api_base,
**rc.extra_kwargs,
)
event_bus = EventBus()
session = Session(
@@ -287,7 +294,17 @@ class SessionManager:
try:
# Blocking I/O — load in executor
loop = asyncio.get_running_loop()
resolved_model = model or self._model
# Prioritize: explicit model arg > worker-specific model > session default
from framework.config import (
get_preferred_worker_model,
get_worker_api_base,
get_worker_api_key,
get_worker_llm_extra_kwargs,
)
worker_model = get_preferred_worker_model()
resolved_model = model or worker_model or self._model
runner = await loop.run_in_executor(
None,
lambda: AgentRunner.load(
@@ -299,6 +316,30 @@ class SessionManager:
),
)
# If a worker-specific model is configured, build an LLM provider
# with the correct worker credentials so _setup() doesn't fall back
# to the queen's llm config (which may be a different provider).
if worker_model and not model:
from framework.config import get_hive_config
worker_llm_cfg = get_hive_config().get("worker_llm", {})
if worker_llm_cfg.get("use_antigravity_subscription"):
from framework.llm.antigravity import AntigravityProvider
runner._llm = AntigravityProvider(model=resolved_model)
else:
from framework.llm.litellm import LiteLLMProvider
worker_api_key = get_worker_api_key()
worker_api_base = get_worker_api_base()
worker_extra = get_worker_llm_extra_kwargs()
runner._llm = LiteLLMProvider(
model=resolved_model,
api_key=worker_api_key,
api_base=worker_api_base,
**worker_extra,
)
# Setup with session's event bus
if runner._agent_runtime is None:
await loop.run_in_executor(
@@ -793,10 +834,11 @@ class SessionManager:
exec_id = event.execution_id
if event.type == _ET.EXECUTION_STARTED:
# New run on this execution_id — reset cooldown so the first
# iteration always produces a mid-run snapshot.
# New run on this execution_id — start the cooldown timer so
# mid-run snapshots don't fire immediately at session start.
# The first snapshot will happen after _DIGEST_COOLDOWN seconds.
if exec_id:
_last_digest.pop(exec_id, None)
_last_digest[exec_id] = _time.monotonic()
elif event.type in (
_ET.EXECUTION_COMPLETED,
@@ -923,6 +965,7 @@ class SessionManager:
# then use max+1 as offset so resumed sessions produce monotonically
# increasing iteration values — preventing frontend message ID collisions.
iteration_offset = 0
last_phase = ""
events_path = queen_dir / "events.jsonl"
try:
if events_path.exists():
@@ -934,17 +977,25 @@ class SessionManager:
continue
try:
evt = json.loads(line)
it = evt.get("data", {}).get("iteration")
data = evt.get("data", {})
it = data.get("iteration")
if isinstance(it, int) and it > max_iter:
max_iter = it
# Track the latest queen phase from QUEEN_PHASE_CHANGED events
if evt.get("type") == "queen_phase_changed":
phase = data.get("phase")
if phase:
last_phase = phase
except (json.JSONDecodeError, TypeError):
continue
if max_iter >= 0:
iteration_offset = max_iter + 1
logger.info(
"Session '%s' resuming with iteration_offset=%d (from events.jsonl max)",
"Session '%s' resuming with iteration_offset=%d"
" (from events.jsonl max), last phase: %s",
session.id,
iteration_offset,
last_phase or "unknown",
)
except OSError:
pass
@@ -996,10 +1047,17 @@ class SessionManager:
_consolidation_session_dir = queen_dir
async def _on_compaction(_event) -> None:
# Only consolidate on queen compactions — worker and subagent
# compactions are frequent and don't warrant a memory update.
if getattr(_event, "stream_id", None) != "queen":
return
from framework.agents.queen.queen_memory import consolidate_queen_memory
await consolidate_queen_memory(
session.id, _consolidation_session_dir, _consolidation_llm
asyncio.create_task(
consolidate_queen_memory(
session.id, _consolidation_session_dir, _consolidation_llm
),
name=f"queen-memory-consolidation-{session.id}",
)
from framework.runtime.event_bus import EventType as _ET
+11 -2
View File
@@ -1,8 +1,8 @@
"""Hive Agent Skills — discovery, parsing, and injection of SKILL.md packages.
"""Hive Agent Skills — discovery, parsing, trust gating, and injection of SKILL.md packages.
Implements the open Agent Skills standard (agentskills.io) for portable
skill discovery and activation, plus built-in default skills for runtime
operational discipline.
operational discipline, and AS-13 trust gating for project-scope skills.
"""
from framework.skills.catalog import SkillCatalog
@@ -10,7 +10,10 @@ from framework.skills.config import DefaultSkillConfig, SkillsConfig
from framework.skills.defaults import DefaultSkillManager
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
from framework.skills.manager import SkillsManager, SkillsManagerConfig
from framework.skills.models import TrustStatus
from framework.skills.parser import ParsedSkill, parse_skill_md
from framework.skills.skill_errors import SkillError, SkillErrorCode, log_skill_error
from framework.skills.trust import TrustedRepoStore, TrustGate
__all__ = [
"DefaultSkillConfig",
@@ -22,5 +25,11 @@ __all__ = [
"SkillsConfig",
"SkillsManager",
"SkillsManagerConfig",
"TrustGate",
"TrustedRepoStore",
"TrustStatus",
"parse_skill_md",
"SkillError",
"SkillErrorCode",
"log_skill_error",
]
+10 -1
View File
@@ -10,6 +10,7 @@ import logging
from xml.sax.saxutils import escape
from framework.skills.parser import ParsedSkill
from framework.skills.skill_errors import SkillErrorCode, log_skill_error
logger = logging.getLogger(__name__)
@@ -76,6 +77,7 @@ class SkillCatalog:
lines.append(f" <name>{escape(skill.name)}</name>")
lines.append(f" <description>{escape(skill.description)}</description>")
lines.append(f" <location>{escape(skill.location)}</location>")
lines.append(f" <base_dir>{escape(skill.base_dir)}</base_dir>")
lines.append(" </skill>")
lines.append("</available_skills>")
@@ -96,7 +98,14 @@ class SkillCatalog:
for name in skill_names:
skill = self.get(name)
if skill is None:
logger.warning("Pre-activated skill '%s' not found in catalog", name)
log_skill_error(
logger,
"warning",
SkillErrorCode.SKILL_NOT_FOUND,
what=f"Pre-activated skill '{name}' not found in catalog",
why="The skill was listed for pre-activation but was not discovered.",
fix=f"Check that a SKILL.md for '{name}' exists in a scanned directory.",
)
continue
if self.is_activated(name):
continue # Already activated, skip duplicate
+120
View File
@@ -0,0 +1,120 @@
"""CLI commands for the Hive skill system.
Phase 1 commands (AS-13):
hive skill list list discovered skills across all scopes
hive skill trust <path> permanently trust a project repo's skills
Full CLI suite (CLI-1 through CLI-13) is Phase 2.
"""
from __future__ import annotations
import subprocess
import sys
from pathlib import Path
def register_skill_commands(subparsers) -> None:
"""Register the ``hive skill`` subcommand group."""
skill_parser = subparsers.add_parser("skill", help="Manage skills")
skill_sub = skill_parser.add_subparsers(dest="skill_command", required=True)
# hive skill list
list_parser = skill_sub.add_parser("list", help="List discovered skills across all scopes")
list_parser.add_argument(
"--project-dir",
default=None,
metavar="PATH",
help="Project directory to scan (default: current directory)",
)
list_parser.set_defaults(func=cmd_skill_list)
# hive skill trust
trust_parser = skill_sub.add_parser(
"trust",
help="Permanently trust a project repository so its skills load without prompting",
)
trust_parser.add_argument(
"project_path",
help="Path to the project directory (must contain a .git with a remote origin)",
)
trust_parser.set_defaults(func=cmd_skill_trust)
def cmd_skill_list(args) -> int:
"""List all discovered skills grouped by scope."""
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
project_dir = Path(args.project_dir).resolve() if args.project_dir else Path.cwd()
skills = SkillDiscovery(DiscoveryConfig(project_root=project_dir)).discover()
if not skills:
print("No skills discovered.")
return 0
scope_headers = {
"project": "PROJECT SKILLS",
"user": "USER SKILLS",
"framework": "FRAMEWORK SKILLS",
}
for scope in ("project", "user", "framework"):
scope_skills = [s for s in skills if s.source_scope == scope]
if not scope_skills:
continue
print(f"\n{scope_headers[scope]}")
print("" * 40)
for skill in scope_skills:
print(f"{skill.name}")
print(f" {skill.description}")
print(f" {skill.location}")
return 0
def cmd_skill_trust(args) -> int:
"""Permanently trust a project repository's skills."""
from framework.skills.trust import TrustedRepoStore, _normalize_remote_url
project_path = Path(args.project_path).resolve()
if not project_path.exists():
print(f"Error: path does not exist: {project_path}", file=sys.stderr)
return 1
if not (project_path / ".git").exists():
print(
f"Error: {project_path} is not a git repository (no .git directory).",
file=sys.stderr,
)
return 1
try:
result = subprocess.run(
["git", "-C", str(project_path), "remote", "get-url", "origin"],
capture_output=True,
text=True,
timeout=3,
)
if result.returncode != 0:
print(
"Error: no remote 'origin' configured in this repository.",
file=sys.stderr,
)
return 1
remote_url = result.stdout.strip()
except subprocess.TimeoutExpired:
print("Error: git remote lookup timed out.", file=sys.stderr)
return 1
except (FileNotFoundError, OSError) as e:
print(f"Error reading git remote: {e}", file=sys.stderr)
return 1
repo_key = _normalize_remote_url(remote_url)
store = TrustedRepoStore()
store.trust(repo_key, project_path=str(project_path))
print(f"✓ Trusted: {repo_key}")
print(" Stored in ~/.hive/trusted_repos.json")
print(" Skills from this repository will load without prompting in future runs.")
return 0
+53 -4
View File
@@ -11,6 +11,7 @@ from pathlib import Path
from framework.skills.config import SkillsConfig
from framework.skills.parser import ParsedSkill, parse_skill_md
from framework.skills.skill_errors import SkillErrorCode, log_skill_error
logger = logging.getLogger(__name__)
@@ -60,12 +61,14 @@ class DefaultSkillManager:
self._config = config or SkillsConfig()
self._skills: dict[str, ParsedSkill] = {}
self._loaded = False
self._error_count = 0
def load(self) -> None:
"""Load all enabled default skill SKILL.md files."""
if self._loaded:
return
error_count = 0
for skill_name, dir_name in SKILL_REGISTRY.items():
if not self._config.is_default_enabled(skill_name):
logger.info("Default skill '%s' disabled by config", skill_name)
@@ -73,17 +76,34 @@ class DefaultSkillManager:
skill_path = _DEFAULT_SKILLS_DIR / dir_name / "SKILL.md"
if not skill_path.is_file():
logger.error("Default skill SKILL.md not found: %s", skill_path)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_NOT_FOUND,
what=f"Default skill SKILL.md not found: '{skill_path}'",
why=f"The framework skill '{skill_name}' is missing its SKILL.md file.",
fix="Reinstall the hive framework — this file is part of the package.",
)
error_count += 1
continue
parsed = parse_skill_md(skill_path, source_scope="framework")
if parsed is None:
logger.error("Failed to parse default skill: %s", skill_path)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_PARSE_ERROR,
what=f"Failed to parse default skill '{skill_name}'",
why=f"parse_skill_md returned None for '{skill_path}'.",
fix="Reinstall the hive framework — this file may be corrupted.",
)
error_count += 1
continue
self._skills[skill_name] = parsed
self._loaded = True
self._error_count = error_count
def build_protocols_prompt(self) -> str:
"""Build the combined operational protocols section.
@@ -127,8 +147,23 @@ class DefaultSkillManager:
"""Log which default skills are active and their configuration."""
if not self._skills:
logger.info("Default skills: all disabled")
return
# DX-3: Per-skill structured startup log
for skill_name in SKILL_REGISTRY:
if skill_name in self._skills:
overrides = self._config.get_default_overrides(skill_name)
status = f"loaded overrides={overrides}" if overrides else "loaded"
elif not self._config.is_default_enabled(skill_name):
status = "disabled"
else:
status = "error"
logger.info(
"skill_startup name=%s scope=framework status=%s",
skill_name,
status,
)
# Original active skills log line (preserved for backward compatibility)
active = []
for skill_name in SKILL_REGISTRY:
if skill_name in self._skills:
@@ -138,7 +173,21 @@ class DefaultSkillManager:
else:
active.append(skill_name)
logger.info("Default skills active: %s", ", ".join(active))
if active:
logger.info("Default skills active: %s", ", ".join(active))
# DX-3: Summary line with error count
total = len(SKILL_REGISTRY)
active_count = len(self._skills)
error_count = getattr(self, "_error_count", 0)
disabled_count = total - active_count - error_count
logger.info(
"Skills: %d default (%d active, %d disabled, %d error)",
total,
active_count,
disabled_count,
error_count,
)
@property
def active_skill_names(self) -> list[str]:
+8 -5
View File
@@ -11,6 +11,7 @@ from dataclasses import dataclass
from pathlib import Path
from framework.skills.parser import ParsedSkill, parse_skill_md
from framework.skills.skill_errors import SkillErrorCode, log_skill_error
logger = logging.getLogger(__name__)
@@ -172,11 +173,13 @@ class SkillDiscovery:
for skill in skills:
if skill.name in seen:
existing = seen[skill.name]
logger.warning(
"Skill name collision: '%s' from %s overrides %s",
skill.name,
skill.location,
existing.location,
log_skill_error(
logger,
"warning",
SkillErrorCode.SKILL_COLLISION,
what=f"Skill name collision: '{skill.name}'",
why=f"'{skill.location}' overrides '{existing.location}'.",
fix="Rename one of the conflicting skill directories to use a unique name.",
)
seen[skill.name] = skill
+29
View File
@@ -42,11 +42,14 @@ class SkillsManagerConfig:
When ``None``, community discovery is skipped.
skip_community_discovery: Explicitly skip community scanning
even when ``project_root`` is set.
interactive: Whether trust gating can prompt the user interactively.
When ``False``, untrusted project skills are silently skipped.
"""
skills_config: SkillsConfig = field(default_factory=SkillsConfig)
project_root: Path | None = None
skip_community_discovery: bool = False
interactive: bool = True
class SkillsManager:
@@ -63,6 +66,7 @@ class SkillsManager:
self._loaded = False
self._catalog_prompt: str = ""
self._protocols_prompt: str = ""
self._allowlisted_dirs: list[str] = []
# ------------------------------------------------------------------
# Factory for backwards-compat bridge
@@ -85,6 +89,7 @@ class SkillsManager:
mgr._loaded = True # skip load()
mgr._catalog_prompt = skills_catalog_prompt
mgr._protocols_prompt = protocols_prompt
mgr._allowlisted_dirs = []
return mgr
# ------------------------------------------------------------------
@@ -113,9 +118,18 @@ class SkillsManager:
# 1. Community skill discovery (when project_root is available)
catalog_prompt = ""
if self._config.project_root is not None and not self._config.skip_community_discovery:
from framework.skills.trust import TrustGate
discovery = SkillDiscovery(DiscoveryConfig(project_root=self._config.project_root))
discovered = discovery.discover()
# Trust-gate project-scope skills (AS-13)
discovered = TrustGate(interactive=self._config.interactive).filter_and_gate(
discovered, project_dir=self._config.project_root
)
catalog = SkillCatalog(discovered)
self._allowlisted_dirs = catalog.allowlisted_dirs
catalog_prompt = catalog.to_prompt()
# Pre-activated community skills
@@ -132,6 +146,16 @@ class SkillsManager:
default_mgr.load()
default_mgr.log_active_skills()
protocols_prompt = default_mgr.build_protocols_prompt()
# DX-3: Community skill startup summary
if self._config.project_root is not None and not self._config.skip_community_discovery:
community_count = len(catalog._skills) if catalog_prompt else 0
pre_activated_count = len(skills_config.skills) if skills_config.skills else 0
logger.info(
"Skills: %d community (%d catalog, %d pre-activated)",
community_count,
community_count,
pre_activated_count,
)
# 3. Cache
self._catalog_prompt = catalog_prompt
@@ -160,6 +184,11 @@ class SkillsManager:
"""Default skill operational protocols for system prompt injection."""
return self._protocols_prompt
@property
def allowlisted_dirs(self) -> list[str]:
"""Skill base directories for Tier 3 resource access (AS-6)."""
return self._allowlisted_dirs
@property
def is_loaded(self) -> bool:
return self._loaded
+52
View File
@@ -0,0 +1,52 @@
"""Data models for the Hive skill system (Agent Skills standard)."""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import StrEnum
from pathlib import Path
class SkillScope(StrEnum):
"""Where a skill was discovered."""
PROJECT = "project"
USER = "user"
FRAMEWORK = "framework"
class TrustStatus(StrEnum):
"""Trust state of a skill entry."""
TRUSTED = "trusted"
PENDING_CONSENT = "pending_consent"
DENIED = "denied"
@dataclass
class SkillEntry:
"""In-memory record for a discovered skill (PRD §4.2)."""
name: str
"""Skill name from SKILL.md frontmatter."""
description: str
"""Skill description from SKILL.md frontmatter."""
location: Path
"""Absolute path to SKILL.md."""
base_dir: Path
"""Parent directory of SKILL.md (skill root)."""
source_scope: SkillScope
"""Which scope this skill was found in."""
trust_status: TrustStatus = TrustStatus.TRUSTED
"""Trust state; project-scope skills start as PENDING_CONSENT before gating."""
# Optional frontmatter fields
license: str | None = None
compatibility: list[str] = field(default_factory=list)
allowed_tools: list[str] = field(default_factory=list)
metadata: dict = field(default_factory=dict)
+81 -14
View File
@@ -13,6 +13,8 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Any
from framework.skills.skill_errors import SkillErrorCode, log_skill_error
logger = logging.getLogger(__name__)
# Maximum name length before a warning is logged
@@ -74,17 +76,38 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
try:
content = path.read_text(encoding="utf-8")
except OSError as exc:
logger.error("Failed to read %s: %s", path, exc)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_ACTIVATION_FAILED,
what=f"Failed to read '{path}'",
why=str(exc),
fix="Check the file exists and has read permissions.",
)
return None
if not content.strip():
logger.error("Empty SKILL.md: %s", path)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_PARSE_ERROR,
what=f"Invalid SKILL.md at '{path}'",
why="The file exists but contains no content.",
fix="Add valid YAML frontmatter and a markdown body to the SKILL.md.",
)
return None
# Split on --- delimiters (first two occurrences)
parts = content.split("---", 2)
if len(parts) < 3:
logger.error("SKILL.md missing YAML frontmatter delimiters (---): %s", path)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_PARSE_ERROR,
what=f"Invalid SKILL.md at '{path}'",
why="Missing YAML frontmatter (---).",
fix="Wrap the frontmatter with --- on its own line at the top and bottom.",
)
return None
# parts[0] is content before first --- (should be empty or whitespace)
@@ -94,7 +117,14 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
body = parts[2].strip()
if not raw_yaml:
logger.error("Empty YAML frontmatter in %s", path)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_PARSE_ERROR,
what=f"Invalid SKILL.md at '{path}'",
why="The --- delimiters are present but the YAML block is empty.",
fix="Add at least 'name' and 'description' fields to the frontmatter.",
)
return None
# Parse YAML
@@ -108,19 +138,47 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
try:
fixed = _try_fix_yaml(raw_yaml)
frontmatter = yaml.safe_load(fixed)
logger.warning("Fixed YAML parse issues in %s (unquoted colons)", path)
log_skill_error(
logger,
"warning",
SkillErrorCode.SKILL_YAML_FIXUP,
what=f"Auto-fixed YAML in '{path}'",
why="Unquoted colon values detected in frontmatter.",
fix='Wrap values containing colons in quotes e.g. description: "Use for: research"',
)
except yaml.YAMLError as exc:
logger.error("Unparseable YAML in %s: %s", path, exc)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_PARSE_ERROR,
what=f"Invalid SKILL.md at '{path}'",
why=str(exc),
fix="Validate the YAML frontmatter at https://yaml-online-parser.appspot.com/",
)
return None
if not isinstance(frontmatter, dict):
logger.error("YAML frontmatter is not a mapping in %s", path)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_PARSE_ERROR,
what=f"Invalid SKILL.md at '{path}'",
why="YAML frontmatter is not a key-value mapping.",
fix="Ensure the frontmatter is valid YAML with key: value pairs.",
)
return None
# Required: description
description = frontmatter.get("description")
if not description or not str(description).strip():
logger.error("Missing or empty 'description' in %s — skipping skill", path)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_MISSING_DESCRIPTION,
what=f"Missing 'description' in '{path}'",
why="The 'description' field is required but is absent or empty.",
fix="Add a non-empty 'description' field to the YAML frontmatter.",
)
return None
# Required: name (fallback to parent directory name)
@@ -128,7 +186,14 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
parent_dir_name = path.parent.name
if not name or not str(name).strip():
name = parent_dir_name
logger.warning("Missing 'name' in %s — using directory name '%s'", path, name)
log_skill_error(
logger,
"warning",
SkillErrorCode.SKILL_NAME_MISMATCH,
what=f"Missing 'name' in '{path}' — using directory name '{name}'",
why="The 'name' field is absent from the YAML frontmatter.",
fix=f"Add 'name: {name}' to the frontmatter to make this explicit.",
)
else:
name = str(name).strip()
@@ -137,11 +202,13 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
logger.warning("Skill name exceeds %d chars in %s: '%s'", _MAX_NAME_LENGTH, path, name)
if name != parent_dir_name and not name.endswith(f".{parent_dir_name}"):
logger.warning(
"Skill name '%s' doesn't match parent directory '%s' in %s",
name,
parent_dir_name,
path,
log_skill_error(
logger,
"warning",
SkillErrorCode.SKILL_NAME_MISMATCH,
what=f"Name mismatch in '{path}'",
why=f"Skill name '{name}' doesn't match directory '{parent_dir_name}'.",
fix=f"Rename the directory to '{name}' or set name to '{parent_dir_name}'.",
)
return ParsedSkill(
+70
View File
@@ -0,0 +1,70 @@
"""Structured error codes and diagnostics for the Hive skill system.
Implements DX-1 (structured error codes) and DX-2 (what/why/fix format)
from the skill system PRD §7.5.
"""
from __future__ import annotations
import logging
from enum import Enum
class SkillErrorCode(Enum):
"""Standardized error codes for skill system operations (DX-1)."""
SKILL_NOT_FOUND = "SKILL_NOT_FOUND"
SKILL_PARSE_ERROR = "SKILL_PARSE_ERROR"
SKILL_ACTIVATION_FAILED = "SKILL_ACTIVATION_FAILED"
SKILL_MISSING_DESCRIPTION = "SKILL_MISSING_DESCRIPTION"
SKILL_YAML_FIXUP = "SKILL_YAML_FIXUP"
SKILL_NAME_MISMATCH = "SKILL_NAME_MISMATCH"
SKILL_COLLISION = "SKILL_COLLISION"
class SkillError(Exception):
"""Structured exception for skill system errors (DX-2).
Raised in strict validation paths. Also used as the base
format contract for log_skill_error() log messages.
"""
def __init__(self, code: SkillErrorCode, what: str, why: str, fix: str):
self.code = code
self.what = what
self.why = why
self.fix = fix
self.message = (
f"[{self.code.value}]\nWhat failed: {self.what}\nWhy: {self.why}\nFix: {self.fix}"
)
super().__init__(self.message)
def log_skill_error(
logger: logging.Logger,
level: str,
code: SkillErrorCode,
what: str,
why: str,
fix: str,
) -> None:
"""Emit a structured skill diagnostic log with consistent format (DX-2).
Args:
logger: The module logger to emit to.
level: Log level string 'error', 'warning', or 'info'.
code: Structured error code.
what: What failed (specific skill name and path).
why: Root cause.
fix: Concrete next step for the developer.
"""
msg = f"[{code.value}] What failed: {what} | Why: {why} | Fix: {fix}"
getattr(logger, level)(
msg,
extra={
"skill_error_code": code.value,
"what": what,
"why": why,
"fix": fix,
},
)
+477
View File
@@ -0,0 +1,477 @@
"""Trust gating for project-level skills (PRD AS-13).
Project-level skills from untrusted repositories require explicit user consent
before their instructions are loaded into the agent's system prompt.
Framework and user-scope skills are always trusted.
Trusted repos are persisted at ~/.hive/trusted_repos.json.
"""
from __future__ import annotations
import json
import logging
import subprocess
import sys
from collections.abc import Callable
from dataclasses import dataclass
from datetime import UTC, datetime
from enum import StrEnum
from pathlib import Path
from urllib.parse import urlparse
from framework.skills.parser import ParsedSkill
logger = logging.getLogger(__name__)
# Env var to bypass trust gating in CI/headless pipelines (opt-in).
_ENV_TRUST_ALL = "HIVE_TRUST_PROJECT_SKILLS"
# Env var for comma-separated own-remote glob patterns (e.g. "github.com/myorg/*").
_ENV_OWN_REMOTES = "HIVE_OWN_REMOTES"
_TRUSTED_REPOS_PATH = Path.home() / ".hive" / "trusted_repos.json"
_NOTICE_SENTINEL_PATH = Path.home() / ".hive" / ".skill_trust_notice_shown"
# ---------------------------------------------------------------------------
# Trusted repo store
# ---------------------------------------------------------------------------
@dataclass
class TrustedRepoEntry:
repo_key: str
added_at: datetime
project_path: str = ""
class TrustedRepoStore:
"""Persists permanently-trusted repo keys to ~/.hive/trusted_repos.json."""
def __init__(self, path: Path | None = None) -> None:
self._path = path or _TRUSTED_REPOS_PATH
self._entries: dict[str, TrustedRepoEntry] = {}
self._loaded = False
def is_trusted(self, repo_key: str) -> bool:
self._ensure_loaded()
return repo_key in self._entries
def trust(self, repo_key: str, project_path: str = "") -> None:
self._ensure_loaded()
self._entries[repo_key] = TrustedRepoEntry(
repo_key=repo_key,
added_at=datetime.now(tz=UTC),
project_path=project_path,
)
self._save()
logger.info("skill_trust_store: trusted repo_key=%s", repo_key)
def revoke(self, repo_key: str) -> bool:
self._ensure_loaded()
if repo_key in self._entries:
del self._entries[repo_key]
self._save()
logger.info("skill_trust_store: revoked repo_key=%s", repo_key)
return True
return False
def list_entries(self) -> list[TrustedRepoEntry]:
self._ensure_loaded()
return list(self._entries.values())
def _ensure_loaded(self) -> None:
if not self._loaded:
self._load()
self._loaded = True
def _load(self) -> None:
try:
data = json.loads(self._path.read_text(encoding="utf-8"))
for raw in data.get("entries", []):
repo_key = raw.get("repo_key", "")
if not repo_key:
continue
try:
added_at = datetime.fromisoformat(raw["added_at"])
except (KeyError, ValueError):
added_at = datetime.now(tz=UTC)
self._entries[repo_key] = TrustedRepoEntry(
repo_key=repo_key,
added_at=added_at,
project_path=raw.get("project_path", ""),
)
except FileNotFoundError:
pass
except Exception as e:
logger.warning(
"skill_trust_store: could not read %s (%s); treating as empty",
self._path,
e,
)
def _save(self) -> None:
self._path.parent.mkdir(parents=True, exist_ok=True)
data = {
"version": 1,
"entries": [
{
"repo_key": e.repo_key,
"added_at": e.added_at.isoformat(),
"project_path": e.project_path,
}
for e in self._entries.values()
],
}
# Atomic write: write to .tmp then rename
tmp = self._path.with_suffix(".tmp")
tmp.write_text(json.dumps(data, indent=2), encoding="utf-8")
tmp.replace(self._path)
# ---------------------------------------------------------------------------
# Trust classification
# ---------------------------------------------------------------------------
class ProjectTrustClassification(StrEnum):
ALWAYS_TRUSTED = "always_trusted"
TRUSTED_BY_USER = "trusted_by_user"
UNTRUSTED = "untrusted"
class ProjectTrustDetector:
"""Classifies a project directory as trusted or untrusted.
Algorithm (PRD §4.1 trust note):
1. No project_dir ALWAYS_TRUSTED
2. No .git directory ALWAYS_TRUSTED (not a git repo)
3. No remote 'origin' ALWAYS_TRUSTED (local-only repo)
4. Remote URL repo_key; in TrustedRepoStore TRUSTED_BY_USER
5. Localhost remote ALWAYS_TRUSTED
6. ~/.hive/own_remotes match ALWAYS_TRUSTED
7. HIVE_OWN_REMOTES env match ALWAYS_TRUSTED
8. None of the above UNTRUSTED
"""
def __init__(self, store: TrustedRepoStore | None = None) -> None:
self._store = store or TrustedRepoStore()
def classify(self, project_dir: Path | None) -> tuple[ProjectTrustClassification, str]:
"""Return (classification, repo_key).
repo_key is empty string for ALWAYS_TRUSTED cases without a remote.
"""
if project_dir is None or not project_dir.exists():
return ProjectTrustClassification.ALWAYS_TRUSTED, ""
if not (project_dir / ".git").exists():
return ProjectTrustClassification.ALWAYS_TRUSTED, ""
remote_url = self._get_remote_origin(project_dir)
if not remote_url:
return ProjectTrustClassification.ALWAYS_TRUSTED, ""
repo_key = _normalize_remote_url(remote_url)
# Explicitly trusted by user
if self._store.is_trusted(repo_key):
return ProjectTrustClassification.TRUSTED_BY_USER, repo_key
# Localhost remotes are always trusted
if _is_localhost_remote(remote_url):
return ProjectTrustClassification.ALWAYS_TRUSTED, repo_key
# User-configured own-remote patterns
if self._matches_own_remotes(repo_key):
return ProjectTrustClassification.ALWAYS_TRUSTED, repo_key
return ProjectTrustClassification.UNTRUSTED, repo_key
def _get_remote_origin(self, project_dir: Path) -> str:
"""Run git remote get-url origin. Returns empty string on any failure."""
try:
result = subprocess.run(
["git", "-C", str(project_dir), "remote", "get-url", "origin"],
capture_output=True,
text=True,
timeout=3,
)
if result.returncode == 0:
return result.stdout.strip()
except subprocess.TimeoutExpired:
logger.warning(
"skill_trust: git remote lookup timed out for %s; treating as trusted",
project_dir,
)
except (FileNotFoundError, OSError):
pass # git not found or other OS error
return ""
def _matches_own_remotes(self, repo_key: str) -> bool:
"""Check repo_key against user-configured own-remote glob patterns."""
import fnmatch
patterns: list[str] = []
# From env var
env_patterns = _ENV_OWN_REMOTES
import os
raw = os.environ.get(env_patterns, "")
if raw:
patterns.extend(p.strip() for p in raw.split(",") if p.strip())
# From ~/.hive/own_remotes file
own_remotes_file = Path.home() / ".hive" / "own_remotes"
if own_remotes_file.is_file():
try:
for line in own_remotes_file.read_text(encoding="utf-8").splitlines():
line = line.strip()
if line and not line.startswith("#"):
patterns.append(line)
except OSError:
pass
return any(fnmatch.fnmatch(repo_key, p) for p in patterns)
# ---------------------------------------------------------------------------
# URL helpers (public so CLI can reuse)
# ---------------------------------------------------------------------------
def _normalize_remote_url(url: str) -> str:
"""Normalize a git remote URL to a canonical ``host/org/repo`` key.
Examples:
git@github.com:org/repo.git github.com/org/repo
https://github.com/org/repo github.com/org/repo
ssh://git@github.com/org/repo.git github.com/org/repo
"""
url = url.strip()
# SCP-style SSH: git@github.com:org/repo.git
if url.startswith("git@") and ":" in url and "://" not in url:
url = url[4:] # strip git@
url = url.replace(":", "/", 1)
elif "://" in url:
parsed = urlparse(url)
host = parsed.hostname or ""
path = parsed.path.lstrip("/")
url = f"{host}/{path}"
# Strip .git suffix
if url.endswith(".git"):
url = url[:-4]
return url.lower().strip("/")
def _is_localhost_remote(remote_url: str) -> bool:
"""Return True if the remote points to a local host."""
local_hosts = {"localhost", "127.0.0.1", "::1"}
try:
if "://" in remote_url:
parsed = urlparse(remote_url)
return (parsed.hostname or "").lower() in local_hosts
# SCP-style: git@localhost:org/repo
if "@" in remote_url:
host_part = remote_url.split("@", 1)[1].split(":")[0]
return host_part.lower() in local_hosts
except Exception:
pass
return False
# ---------------------------------------------------------------------------
# Trust gate
# ---------------------------------------------------------------------------
class TrustGate:
"""Filters skill list, running consent flow for untrusted project-scope skills.
Framework and user-scope skills are always allowed through.
Project-scope skills from untrusted repos require consent.
"""
def __init__(
self,
store: TrustedRepoStore | None = None,
detector: ProjectTrustDetector | None = None,
interactive: bool = True,
print_fn: Callable[[str], None] | None = None,
input_fn: Callable[[str], str] | None = None,
) -> None:
self._store = store or TrustedRepoStore()
self._detector = detector or ProjectTrustDetector(self._store)
self._interactive = interactive
self._print = print_fn or print
self._input = input_fn or input
def filter_and_gate(
self,
skills: list[ParsedSkill],
project_dir: Path | None,
) -> list[ParsedSkill]:
"""Return the subset of skills that are trusted for loading.
- Framework and user-scope skills: always included.
- Project-scope skills: classified; consent prompt shown if untrusted.
"""
import os
# Separate project skills from always-trusted scopes
always_trusted = [s for s in skills if s.source_scope != "project"]
project_skills = [s for s in skills if s.source_scope == "project"]
if not project_skills:
return always_trusted
# Env-var CI override: trust all project skills for this invocation
if os.environ.get(_ENV_TRUST_ALL, "").strip() == "1":
logger.info(
"skill_trust: %s=1 set; trusting %d project skill(s) without consent",
_ENV_TRUST_ALL,
len(project_skills),
)
return always_trusted + project_skills
classification, repo_key = self._detector.classify(project_dir)
if classification in (
ProjectTrustClassification.ALWAYS_TRUSTED,
ProjectTrustClassification.TRUSTED_BY_USER,
):
logger.info(
"skill_trust: project skills trusted classification=%s repo=%s count=%d",
classification,
repo_key or "(no remote)",
len(project_skills),
)
return always_trusted + project_skills
# UNTRUSTED — need consent
if not self._interactive or not sys.stdin.isatty():
logger.warning(
"skill_trust: skipping %d project-scope skill(s) from untrusted repo "
"'%s' (non-interactive mode). "
"To trust permanently run: hive skill trust %s",
len(project_skills),
repo_key,
project_dir or ".",
)
logger.info(
"skill_trust_decision repo=%s skills=%d decision=denied mode=headless",
repo_key,
len(project_skills),
)
return always_trusted
# Interactive consent flow
decision = self._run_consent_flow(project_skills, project_dir, repo_key)
logger.info(
"skill_trust_decision repo=%s skills=%d decision=%s mode=interactive",
repo_key,
len(project_skills),
decision,
)
if decision == "session":
return always_trusted + project_skills
if decision == "permanent":
self._store.trust(repo_key, project_path=str(project_dir or ""))
return always_trusted + project_skills
# denied
return always_trusted
def _run_consent_flow(
self,
project_skills: list[ParsedSkill],
project_dir: Path | None,
repo_key: str,
) -> str:
"""Show the security notice (once) and consent prompt.
Return 'session' | 'permanent' | 'denied'."""
from framework.credentials.setup import Colors
if not sys.stdout.isatty():
Colors.disable()
self._maybe_show_security_notice(Colors)
self._print_consent_prompt(project_skills, project_dir, repo_key, Colors)
return self._prompt_consent(Colors)
def _maybe_show_security_notice(self, Colors) -> None: # noqa: N803
"""Show the one-time security notice if not already shown (NFR-5)."""
if _NOTICE_SENTINEL_PATH.exists():
return
self._print("")
self._print(
f"{Colors.YELLOW}Security notice:{Colors.NC} Skills inject instructions "
"into the agent's system prompt."
)
self._print(
" Only load skills from sources you trust. "
"Registry skills at tier 'verified' or 'official' have been audited."
)
self._print("")
try:
_NOTICE_SENTINEL_PATH.parent.mkdir(parents=True, exist_ok=True)
_NOTICE_SENTINEL_PATH.touch()
except OSError:
pass
def _print_consent_prompt(
self,
project_skills: list[ParsedSkill],
project_dir: Path | None,
repo_key: str,
Colors, # noqa: N803
) -> None:
p = self._print
p("")
p(f"{Colors.YELLOW}{'=' * 60}{Colors.NC}")
p(f"{Colors.BOLD} SKILL TRUST REQUIRED{Colors.NC}")
p(f"{Colors.YELLOW}{'=' * 60}{Colors.NC}")
p("")
proj_label = str(project_dir) if project_dir else "this project"
p(
f" The project at {Colors.CYAN}{proj_label}{Colors.NC} wants to load "
f"{len(project_skills)} skill(s)"
)
p(" that will inject instructions into the agent's system prompt.")
if repo_key:
p(f" Source: {Colors.BOLD}{repo_key}{Colors.NC}")
p("")
p(" Skills requesting access:")
for skill in project_skills:
p(f" {Colors.CYAN}{Colors.NC} {Colors.BOLD}{skill.name}{Colors.NC}")
p(f' "{skill.description}"')
p(f" {Colors.DIM}{skill.location}{Colors.NC}")
p("")
p(" Options:")
p(f" {Colors.CYAN}1){Colors.NC} Trust this session only")
p(f" {Colors.CYAN}2){Colors.NC} Trust permanently — remember for future runs")
p(
f" {Colors.DIM}3) Deny"
f" — skip all project-scope skills from this repo{Colors.NC}"
)
p(f"{Colors.YELLOW}{'' * 60}{Colors.NC}")
def _prompt_consent(self, Colors) -> str: # noqa: N803
"""Prompt until a valid choice is entered. Returns 'session'|'permanent'|'denied'."""
mapping = {"1": "session", "2": "permanent", "3": "denied"}
while True:
try:
choice = self._input("Select option (1-3): ").strip()
if choice in mapping:
return mapping[choice]
except (KeyboardInterrupt, EOFError):
return "denied"
self._print(f"{Colors.RED}Invalid choice. Enter 1, 2, or 3.{Colors.NC}")
+103 -15
View File
@@ -8,8 +8,11 @@ written by the queen directly.
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from framework.runner.tool_registry import ToolRegistry
@@ -34,7 +37,7 @@ def write_to_diary(entry: str) -> str:
return "Diary entry recorded."
def recall_diary(query: str = "", days_back: int = 7) -> str:
async def recall_diary(query: str = "", days_back: int = 7) -> str:
"""Search recent diary entries (episodic memory).
Use this when the user asks about what happened in the past "what did we
@@ -45,26 +48,112 @@ def recall_diary(query: str = "", days_back: int = 7) -> str:
Args:
query: Optional keyword or phrase to filter entries. If empty, all
recent entries are returned.
days_back: How many days to look back (130). Defaults to 7.
days_back: How many days to look back (1-30). Defaults to 7.
"""
from datetime import date, timedelta
from framework.agents.queen.queen_memory import read_episodic_memory
from framework.agents.queen.queen_memory_index import (
embeddings_enabled,
hybrid_search,
load_index,
record_retrieval,
resolve_prose,
semantic_search,
)
days_back = max(1, min(days_back, 30))
today = date.today()
results: list[str] = []
total_chars = 0
char_budget = 12_000
# ------------------------------------------------------------------
# Semantic path — used when embedding model is configured and query given
# ------------------------------------------------------------------
if query and embeddings_enabled():
logger.info("queen_memory: semantic recall — query=%r days_back=%d", query, days_back)
oldest = (today - timedelta(days=days_back - 1)).strftime("%Y-%m-%d")
newest = today.strftime("%Y-%m-%d")
index = load_index()
sem_results = await semantic_search(
query, index, k=30, date_range=(oldest, newest)
)
if sem_results:
sem_scores = dict(sem_results)
candidate_ids = [eid for eid, _ in sem_results]
ranked = hybrid_search(query, index, candidate_ids, sem_scores)
results: list[str] = []
total_chars = 0
returned_ids: list[str] = []
for entry_id, _score in ranked:
date_str, ts = entry_id.split(":", 1)
prose = resolve_prose(entry_id)
if not prose:
continue
# Format label from date_str
try:
y, m, d_int = map(int, date_str.split("-"))
d = date(y, m, d_int)
label = d.strftime("%B %-d, %Y")
if d == today:
label = f"Today — {label}"
except ValueError:
label = date_str
section = f"## {label} ({ts})\n\n{prose}"
# Also include linked neighbours (Phase 3 expansion)
raw = index.get("entries", {}).get(entry_id, {})
related_prose_parts: list[str] = []
for related_id in raw.get("related", [])[:2]:
if related_id in (eid for eid, _ in ranked):
continue # will appear in main results
rp = resolve_prose(related_id)
if rp:
r_date_str, r_ts = related_id.split(":", 1)
try:
ry, rm, rd = map(int, r_date_str.split("-"))
r_label = date(ry, rm, rd).strftime("%B %-d, %Y")
except ValueError:
r_label = r_date_str
related_prose_parts.append(
f"_Related ({r_label} {r_ts}):_ {rp[:300]}"
)
if related_prose_parts:
section += "\n\n" + "\n\n".join(related_prose_parts)
if total_chars + len(section) > char_budget:
remaining = char_budget - total_chars
if remaining > 200:
section = section[: remaining - 100] + "\n\n…(truncated)"
results.append(section)
returned_ids.append(entry_id)
break
results.append(section)
returned_ids.append(entry_id)
total_chars += len(section)
if results:
record_retrieval(index, returned_ids)
return "\n\n---\n\n".join(results)
# Fall through to substring if semantic found nothing useful
# ------------------------------------------------------------------
# Substring fallback — original behaviour, unchanged
# ------------------------------------------------------------------
results_fb: list[str] = []
total_chars_fb = 0
for offset in range(days_back):
d = today - timedelta(days=offset)
content = read_episodic_memory(d)
if not content:
continue
# If a query is given, only include entries that mention it
if query:
# Check each section (split by ###) for relevance
sections = content.split("### ")
matched = [s for s in sections if query.lower() in s.lower()]
if not matched:
@@ -74,24 +163,23 @@ def recall_diary(query: str = "", days_back: int = 7) -> str:
if d == today:
label = f"Today — {label}"
entry = f"## {label}\n\n{content}"
if total_chars + len(entry) > char_budget:
remaining = char_budget - total_chars
if total_chars_fb + len(entry) > char_budget:
remaining = char_budget - total_chars_fb
if remaining > 200:
# Fit a partial entry within budget
trimmed = content[: remaining - 100] + "\n\n…(truncated)"
results.append(f"## {label}\n\n{trimmed}")
results_fb.append(f"## {label}\n\n{trimmed}")
else:
results.append(f"## {label}\n\n(truncated — hit size limit)")
results_fb.append(f"## {label}\n\n(truncated — hit size limit)")
break
results.append(entry)
total_chars += len(entry)
results_fb.append(entry)
total_chars_fb += len(entry)
if not results:
if not results_fb:
if query:
return f"No diary entries matching '{query}' in the last {days_back} days."
return f"No diary entries found in the last {days_back} days."
return "\n\n---\n\n".join(results)
return "\n\n---\n\n".join(results_fb)
def register_queen_memory_tools(registry: ToolRegistry) -> None:
-8
View File
@@ -60,7 +60,6 @@
"integrity": "sha512-CGOfOJqWjg2qW/Mb6zNsDm+u5vFQ8DxXfbM09z69p5Z6+mE1ikP2jUXw+j42Pf1XTYED2Rni5f95npYeuwMDQA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@babel/code-frame": "^7.29.0",
"@babel/generator": "^7.29.0",
@@ -1557,7 +1556,6 @@
"integrity": "sha512-4K3bqJpXpqfg2XKGK9bpDTc6xO/xoUP/RBWS7AtRMug6zZFaRekiLzjVtAoZMquxoAbzBvy5nxQ7veS5eYzf8A==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~7.18.0"
}
@@ -1573,7 +1571,6 @@
"resolved": "https://registry.npmjs.org/@types/react/-/react-18.3.28.tgz",
"integrity": "sha512-z9VXpC7MWrhfWipitjNdgCauoMLRdIILQsAEV+ZesIzBq/oUlxk0m3ApZuMFCXdnS4U7KrI+l3WRUEGQ8K1QKw==",
"license": "MIT",
"peer": true,
"dependencies": {
"@types/prop-types": "*",
"csstype": "^3.2.2"
@@ -1786,7 +1783,6 @@
}
],
"license": "MIT",
"peer": true,
"dependencies": {
"baseline-browser-mapping": "^2.9.0",
"caniuse-lite": "^1.0.30001759",
@@ -3564,7 +3560,6 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -3616,7 +3611,6 @@
"resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz",
"integrity": "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"loose-envify": "^1.1.0"
},
@@ -3629,7 +3623,6 @@
"resolved": "https://registry.npmjs.org/react-dom/-/react-dom-18.3.1.tgz",
"integrity": "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==",
"license": "MIT",
"peer": true,
"dependencies": {
"loose-envify": "^1.1.0",
"scheduler": "^0.23.2"
@@ -4190,7 +4183,6 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",
+2 -2
View File
@@ -34,8 +34,8 @@ export const executionApi = {
graph_id: graphId,
}),
chat: (sessionId: string, message: string) =>
api.post<ChatResult>(`/sessions/${sessionId}/chat`, { message }),
chat: (sessionId: string, message: string, images?: { type: string; image_url: { url: string } }[]) =>
api.post<ChatResult>(`/sessions/${sessionId}/chat`, { message, ...(images?.length ? { images } : {}) }),
/** Queue context for the queen without triggering an LLM response. */
queenContext: (sessionId: string, message: string) =>
+4
View File
@@ -81,6 +81,10 @@ export const sessionsApi = {
eventsHistory: (sessionId: string) =>
api.get<{ events: AgentEvent[]; session_id: string }>(`/sessions/${sessionId}/events/history`),
/** Open the session's data folder in the OS file manager. */
revealFolder: (sessionId: string) =>
api.post<{ path: string }>(`/sessions/${sessionId}/reveal`),
/** List all queen sessions on disk — live + cold (post-restart). */
history: () =>
api.get<{ sessions: Array<{ session_id: string; cold: boolean; live: boolean; has_messages: boolean; created_at: number; agent_name?: string | null; agent_path?: string | null }> }>("/sessions/history"),
+3
View File
@@ -14,6 +14,8 @@ export interface LiveSession {
intro_message?: string;
/** Queen operating phase — "planning", "building", "staging", or "running" */
queen_phase?: "planning" | "building" | "staging" | "running";
/** Whether the queen's LLM supports image content in messages */
queen_supports_images?: boolean;
/** Present in 409 conflict responses when worker is still loading */
loading?: boolean;
}
@@ -324,6 +326,7 @@ export type EventTypeName =
| "node_retry"
| "edge_traversed"
| "context_compacted"
| "context_usage_updated"
| "webhook_received"
| "custom"
| "escalation_requested"
+493 -109
View File
@@ -1,8 +1,32 @@
import { memo, useState, useRef, useEffect } from "react";
import { Send, Square, Crown, Cpu, Check, Loader2 } from "lucide-react";
import { memo, useState, useRef, useEffect, useMemo } from "react";
import {
Send,
Square,
Crown,
Cpu,
Check,
Loader2,
Paperclip,
X,
} from "lucide-react";
export interface ImageContent {
type: "image_url";
image_url: { url: string };
}
export interface ContextUsageEntry {
usagePct: number;
messageCount: number;
estimatedTokens: number;
maxTokens: number;
}
import MarkdownContent from "@/components/MarkdownContent";
import QuestionWidget from "@/components/QuestionWidget";
import MultiQuestionWidget from "@/components/MultiQuestionWidget";
import ParallelSubagentBubble, {
type SubagentGroup,
} from "@/components/ParallelSubagentBubble";
export interface ChatMessage {
id: string;
@@ -10,7 +34,13 @@ export interface ChatMessage {
agentColor: string;
content: string;
timestamp: string;
type?: "system" | "agent" | "user" | "tool_status" | "worker_input_request" | "run_divider";
type?:
| "system"
| "agent"
| "user"
| "tool_status"
| "worker_input_request"
| "run_divider";
role?: "queen" | "worker";
/** Which worker thread this message belongs to (worker agent name) */
thread?: string;
@@ -18,11 +48,17 @@ export interface ChatMessage {
createdAt?: number;
/** Queen phase active when this message was created */
phase?: "planning" | "building" | "staging" | "running";
/** Images attached to a user message */
images?: ImageContent[];
/** Backend node_id that produced this message — used for subagent grouping */
nodeId?: string;
/** Backend execution_id for this message */
executionId?: string;
}
interface ChatPanelProps {
messages: ChatMessage[];
onSend: (message: string, thread: string) => void;
onSend: (message: string, thread: string, images?: ImageContent[]) => void;
isWaiting?: boolean;
/** When true a worker is thinking (not yet streaming) */
isWorkerWaiting?: boolean;
@@ -31,6 +67,8 @@ interface ChatPanelProps {
activeThread: string;
/** When true, the input is disabled (e.g. during loading) */
disabled?: boolean;
/** When false, the image attach button is hidden (model lacks vision support) */
supportsImages?: boolean;
/** Called when user clicks the stop button to cancel the queen's current turn */
onCancel?: () => void;
/** Pending question from ask_user — replaces textarea when present */
@@ -38,7 +76,9 @@ interface ChatPanelProps {
/** Options for the pending question */
pendingOptions?: string[] | null;
/** Multiple questions from ask_user_multiple */
pendingQuestions?: { id: string; prompt: string; options?: string[] }[] | null;
pendingQuestions?:
| { id: string; prompt: string; options?: string[] }[]
| null;
/** Called when user submits an answer to the pending question */
onQuestionSubmit?: (answer: string, isOther: boolean) => void;
/** Called when user submits answers to multiple questions */
@@ -47,6 +87,8 @@ interface ChatPanelProps {
onQuestionDismiss?: () => void;
/** Queen operating phase — shown as a tag on queen messages */
queenPhase?: "planning" | "building" | "staging" | "running";
/** Context window usage for queen and workers */
contextUsage?: Record<string, ContextUsageEntry>;
}
const queenColor = "hsl(45,95%,58%)";
@@ -72,7 +114,8 @@ const TOOL_HEX = [
function toolHex(name: string): string {
let hash = 0;
for (let i = 0; i < name.length; i++) hash = (hash * 31 + name.charCodeAt(i)) | 0;
for (let i = 0; i < name.length; i++)
hash = (hash * 31 + name.charCodeAt(i)) | 0;
return TOOL_HEX[Math.abs(hash) % TOOL_HEX.length];
}
@@ -120,12 +163,18 @@ function ToolActivityRow({ content }: { content: string }) {
<span
key={`run-${p.name}`}
className="inline-flex items-center gap-1 text-[11px] px-2.5 py-0.5 rounded-full"
style={{ color: hex, backgroundColor: `${hex}18`, border: `1px solid ${hex}35` }}
style={{
color: hex,
backgroundColor: `${hex}18`,
border: `1px solid ${hex}35`,
}}
>
<Loader2 className="w-2.5 h-2.5 animate-spin" />
{p.name}
{p.count > 1 && (
<span className="text-[10px] font-medium opacity-70">×{p.count}</span>
<span className="text-[10px] font-medium opacity-70">
×{p.count}
</span>
)}
</span>
);
@@ -136,7 +185,11 @@ function ToolActivityRow({ content }: { content: string }) {
<span
key={`done-${p.name}`}
className="inline-flex items-center gap-1 text-[11px] px-2.5 py-0.5 rounded-full"
style={{ color: hex, backgroundColor: `${hex}18`, border: `1px solid ${hex}35` }}
style={{
color: hex,
backgroundColor: `${hex}18`,
border: `1px solid ${hex}35`,
}}
>
<Check className="w-2.5 h-2.5" />
{p.name}
@@ -151,109 +204,249 @@ function ToolActivityRow({ content }: { content: string }) {
);
}
const MessageBubble = memo(function MessageBubble({ msg, queenPhase }: { msg: ChatMessage; queenPhase?: "planning" | "building" | "staging" | "running" }) {
const isUser = msg.type === "user";
const isQueen = msg.role === "queen";
const color = getColor(msg.agent, msg.role);
const MessageBubble = memo(
function MessageBubble({
msg,
queenPhase,
}: {
msg: ChatMessage;
queenPhase?: "planning" | "building" | "staging" | "running";
}) {
const isUser = msg.type === "user";
const isQueen = msg.role === "queen";
const color = getColor(msg.agent, msg.role);
if (msg.type === "run_divider") {
return (
<div className="flex items-center gap-3 py-2 my-1">
<div className="flex-1 h-px bg-border/60" />
<span className="text-[10px] text-muted-foreground font-medium uppercase tracking-wider">
{msg.content}
</span>
<div className="flex-1 h-px bg-border/60" />
</div>
);
}
if (msg.type === "system") {
return (
<div className="flex justify-center py-1">
<span className="text-[11px] text-muted-foreground bg-muted/60 px-3 py-1.5 rounded-full">
{msg.content}
</span>
</div>
);
}
if (msg.type === "tool_status") {
return <ToolActivityRow content={msg.content} />;
}
if (isUser) {
return (
<div className="flex justify-end">
<div className="max-w-[75%] bg-primary text-primary-foreground text-sm leading-relaxed rounded-2xl rounded-br-md px-4 py-3">
<p className="whitespace-pre-wrap break-words">{msg.content}</p>
if (msg.type === "run_divider") {
return (
<div className="flex items-center gap-3 py-2 my-1">
<div className="flex-1 h-px bg-border/60" />
<span className="text-[10px] text-muted-foreground font-medium uppercase tracking-wider">
{msg.content}
</span>
<div className="flex-1 h-px bg-border/60" />
</div>
</div>
);
}
);
}
return (
<div className="flex gap-3">
<div
className={`flex-shrink-0 ${isQueen ? "w-9 h-9" : "w-7 h-7"} rounded-xl flex items-center justify-center`}
style={{
backgroundColor: `${color}18`,
border: `1.5px solid ${color}35`,
boxShadow: isQueen ? `0 0 12px ${color}20` : undefined,
}}
>
{isQueen ? (
<Crown className="w-4 h-4" style={{ color }} />
) : (
<Cpu className="w-3.5 h-3.5" style={{ color }} />
)}
</div>
<div className={`flex-1 min-w-0 ${isQueen ? "max-w-[85%]" : "max-w-[75%]"}`}>
<div className="flex items-center gap-2 mb-1">
<span className={`font-medium ${isQueen ? "text-sm" : "text-xs"}`} style={{ color }}>
{msg.agent}
</span>
<span
className={`text-[10px] font-medium px-1.5 py-0.5 rounded-md ${
isQueen ? "bg-primary/15 text-primary" : "bg-muted text-muted-foreground"
}`}
>
{isQueen
? ((msg.phase ?? queenPhase) === "running"
? "running"
: (msg.phase ?? queenPhase) === "staging"
? "staging"
: (msg.phase ?? queenPhase) === "planning"
? "planning"
: "building")
: "Worker"}
if (msg.type === "system") {
return (
<div className="flex justify-center py-1">
<span className="text-[11px] text-muted-foreground bg-muted/60 px-3 py-1.5 rounded-full">
{msg.content}
</span>
</div>
);
}
if (msg.type === "tool_status") {
return <ToolActivityRow content={msg.content} />;
}
if (isUser) {
return (
<div className="flex justify-end">
<div className="max-w-[75%] bg-primary text-primary-foreground text-sm leading-relaxed rounded-2xl rounded-br-md px-4 py-3">
{msg.images && msg.images.length > 0 && (
<div className="flex flex-wrap gap-2 mb-2">
{msg.images.map((img, i) => (
<img
key={i}
src={img.image_url.url}
alt={`attachment ${i + 1}`}
className="max-h-48 max-w-full rounded-lg object-contain"
/>
))}
</div>
)}
{msg.content && (
<p className="whitespace-pre-wrap break-words">{msg.content}</p>
)}
</div>
</div>
);
}
return (
<div className="flex gap-3">
<div
className={`flex-shrink-0 ${isQueen ? "w-9 h-9" : "w-7 h-7"} rounded-xl flex items-center justify-center`}
style={{
backgroundColor: `${color}18`,
border: `1.5px solid ${color}35`,
boxShadow: isQueen ? `0 0 12px ${color}20` : undefined,
}}
>
{isQueen ? (
<Crown className="w-4 h-4" style={{ color }} />
) : (
<Cpu className="w-3.5 h-3.5" style={{ color }} />
)}
</div>
<div
className={`text-sm leading-relaxed rounded-2xl rounded-tl-md px-4 py-3 ${
isQueen ? "border border-primary/20 bg-primary/5" : "bg-muted/60"
}`}
className={`flex-1 min-w-0 ${isQueen ? "max-w-[85%]" : "max-w-[75%]"}`}
>
<MarkdownContent content={msg.content} />
<div className="flex items-center gap-2 mb-1">
<span
className={`font-medium ${isQueen ? "text-sm" : "text-xs"}`}
style={{ color }}
>
{msg.agent}
</span>
<span
className={`text-[10px] font-medium px-1.5 py-0.5 rounded-md ${
isQueen
? "bg-primary/15 text-primary"
: "bg-muted text-muted-foreground"
}`}
>
{isQueen
? (msg.phase ?? queenPhase) === "running"
? "running"
: (msg.phase ?? queenPhase) === "staging"
? "staging"
: (msg.phase ?? queenPhase) === "planning"
? "planning"
: "building"
: "Worker"}
</span>
</div>
<div
className={`text-sm leading-relaxed rounded-2xl rounded-tl-md px-4 py-3 ${
isQueen ? "border border-primary/20 bg-primary/5" : "bg-muted/60"
}`}
>
<MarkdownContent content={msg.content} />
</div>
</div>
</div>
</div>
);
}, (prev, next) => prev.msg.id === next.msg.id && prev.msg.content === next.msg.content && prev.msg.phase === next.msg.phase && prev.queenPhase === next.queenPhase);
);
},
(prev, next) =>
prev.msg.id === next.msg.id &&
prev.msg.content === next.msg.content &&
prev.msg.phase === next.msg.phase &&
prev.queenPhase === next.queenPhase,
);
export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting, isBusy, activeThread, disabled, onCancel, pendingQuestion, pendingOptions, pendingQuestions, onQuestionSubmit, onMultiQuestionSubmit, onQuestionDismiss, queenPhase }: ChatPanelProps) {
export default function ChatPanel({
messages,
onSend,
isWaiting,
isWorkerWaiting,
isBusy,
activeThread,
disabled,
onCancel,
pendingQuestion,
pendingOptions,
pendingQuestions,
onQuestionSubmit,
onMultiQuestionSubmit,
onQuestionDismiss,
queenPhase,
contextUsage,
supportsImages = true,
}: ChatPanelProps) {
const [input, setInput] = useState("");
const [pendingImages, setPendingImages] = useState<ImageContent[]>([]);
const [readMap, setReadMap] = useState<Record<string, number>>({});
const bottomRef = useRef<HTMLDivElement>(null);
const scrollRef = useRef<HTMLDivElement>(null);
const stickToBottom = useRef(true);
const textareaRef = useRef<HTMLTextAreaElement>(null);
const fileInputRef = useRef<HTMLInputElement>(null);
const threadMessages = messages.filter((m) => {
if (m.type === "system" && !m.thread) return false;
return m.thread === activeThread;
if (m.thread !== activeThread) return false;
// Hide queen messages whose content is whitespace-only — these are
// tool-use-only turns that have no visible text. During live operation
// tool pills provide context, but on resume the pills are gone so
// the empty bubble is meaningless.
if (m.role === "queen" && !m.type && (!m.content || !m.content.trim()))
return false;
return true;
});
// Group subagent messages into parallel bubbles.
// A subagent message has nodeId containing ":subagent:".
// The run only ends on hard boundaries (user messages, run_dividers)
// so interleaved queen/tool/system messages don't fragment the bubble.
type RenderItem =
| { kind: "message"; msg: ChatMessage }
| { kind: "parallel"; groupId: string; groups: SubagentGroup[] };
const renderItems = useMemo<RenderItem[]>(() => {
const items: RenderItem[] = [];
let i = 0;
while (i < threadMessages.length) {
const msg = threadMessages[i];
const isSubagent = msg.nodeId?.includes(":subagent:");
if (!isSubagent) {
items.push({ kind: "message", msg });
i++;
continue;
}
// Start a subagent run. Collect all subagent messages, allowing
// non-subagent messages in between (they render as normal items
// before the bubble). Only break on hard boundaries.
const subagentMsgs: ChatMessage[] = [];
const interleaved: { idx: number; msg: ChatMessage }[] = [];
const firstId = msg.id;
while (i < threadMessages.length) {
const m = threadMessages[i];
const isSa = m.nodeId?.includes(":subagent:");
if (isSa) {
subagentMsgs.push(m);
i++;
continue;
}
// Hard boundary — stop the run
if (m.type === "user" || m.type === "run_divider") break;
// Worker message from a non-subagent node means the graph has
// moved on to the next stage. Close the bubble even if some
// subagents are still streaming in the background.
if (m.role === "worker" && m.nodeId && !m.nodeId.includes(":subagent:"))
break;
// Soft interruption (queen output, system, tool_status without
// nodeId) — render it normally but keep the subagent run going
interleaved.push({ idx: items.length + interleaved.length, msg: m });
i++;
}
// Emit interleaved messages first (before the bubble)
for (const { msg: im } of interleaved) {
items.push({ kind: "message", msg: im });
}
// Build the single parallel bubble from all collected subagent msgs
if (subagentMsgs.length > 0) {
const byNode = new Map<string, ChatMessage[]>();
for (const m of subagentMsgs) {
const nid = m.nodeId!;
if (!byNode.has(nid)) byNode.set(nid, []);
byNode.get(nid)!.push(m);
}
const groups: SubagentGroup[] = [];
for (const [nodeId, msgs] of byNode) {
groups.push({
nodeId,
messages: msgs,
contextUsage: contextUsage?.[nodeId],
});
}
items.push({ kind: "parallel", groupId: `par-${firstId}`, groups });
}
}
return items;
}, [threadMessages, contextUsage]);
// Mark current thread as read
useEffect(() => {
const count = messages.filter((m) => m.thread === activeThread).length;
@@ -284,26 +477,64 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
const handleSubmit = (e: React.FormEvent) => {
e.preventDefault();
if (!input.trim()) return;
onSend(input.trim(), activeThread);
if (!input.trim() && pendingImages.length === 0) return;
onSend(
input.trim(),
activeThread,
pendingImages.length > 0 ? pendingImages : undefined,
);
setInput("");
setPendingImages([]);
if (textareaRef.current) textareaRef.current.style.height = "auto";
};
const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
const files = Array.from(e.target.files ?? []);
if (files.length === 0) return;
files.forEach((file) => {
const reader = new FileReader();
reader.onload = (ev) => {
const url = ev.target?.result as string;
setPendingImages((prev) => [
...prev,
{ type: "image_url", image_url: { url } },
]);
};
reader.readAsDataURL(file);
});
// Reset so the same file can be re-selected
e.target.value = "";
};
return (
<div className="flex flex-col h-full min-w-0">
{/* Compact sub-header */}
<div className="px-5 pt-4 pb-2 flex items-center gap-2">
<p className="text-[11px] text-muted-foreground font-medium uppercase tracking-wider">Conversation</p>
<p className="text-[11px] text-muted-foreground font-medium uppercase tracking-wider">
Conversation
</p>
</div>
{/* Messages */}
<div ref={scrollRef} onScroll={handleScroll} className="flex-1 overflow-auto px-5 py-4 space-y-3">
{threadMessages.map((msg) => (
<div key={msg.id}>
<MessageBubble msg={msg} queenPhase={queenPhase} />
</div>
))}
<div
ref={scrollRef}
onScroll={handleScroll}
className="flex-1 overflow-auto px-5 py-4 space-y-3"
>
{renderItems.map((item) =>
item.kind === "parallel" ? (
<div key={item.groupId}>
<ParallelSubagentBubble
groupId={item.groupId}
groups={item.groups}
/>
</div>
) : (
<div key={item.msg.id}>
<MessageBubble msg={item.msg} queenPhase={queenPhase} />
</div>
),
)}
{/* Show typing indicator while waiting for first queen response (disabled + empty chat) */}
{(isWaiting || (disabled && threadMessages.length === 0)) && (
@@ -320,9 +551,18 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
</div>
<div className="border border-primary/20 bg-primary/5 rounded-2xl rounded-tl-md px-4 py-3">
<div className="flex gap-1.5">
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "0ms" }} />
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "150ms" }} />
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "300ms" }} />
<span
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
style={{ animationDelay: "0ms" }}
/>
<span
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
style={{ animationDelay: "150ms" }}
/>
<span
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
style={{ animationDelay: "300ms" }}
/>
</div>
</div>
</div>
@@ -340,9 +580,18 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
</div>
<div className="bg-muted/60 rounded-2xl rounded-tl-md px-4 py-3">
<div className="flex gap-1.5">
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "0ms" }} />
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "150ms" }} />
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "300ms" }} />
<span
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
style={{ animationDelay: "0ms" }}
/>
<span
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
style={{ animationDelay: "150ms" }}
/>
<span
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
style={{ animationDelay: "300ms" }}
/>
</div>
</div>
</div>
@@ -350,8 +599,99 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
<div ref={bottomRef} />
</div>
{/* Context window usage bar — sits between messages and input */}
{(() => {
if (!contextUsage) return null;
const queenUsage = contextUsage["__queen__"];
const workerEntries = Object.entries(contextUsage).filter(
([k]) => k !== "__queen__",
);
const workerUsage =
workerEntries.length > 0
? workerEntries.reduce(
(best, [, v]) => (v.usagePct > best.usagePct ? v : best),
workerEntries[0][1],
)
: undefined;
if (!queenUsage && !workerUsage) return null;
return (
<div className="flex items-center gap-3 mx-4 px-3 py-1 rounded-lg bg-muted/30 border border-border/20 group/ctx flex-shrink-0">
{queenUsage && (
<div
className="flex items-center gap-2 flex-1 min-w-0"
title={`Queen: ${(queenUsage.estimatedTokens / 1000).toFixed(1)}k / ${(queenUsage.maxTokens / 1000).toFixed(0)}k tokens \u00b7 ${queenUsage.messageCount} messages`}
>
<Crown
className="w-3 h-3 flex-shrink-0"
style={{ color: "hsl(45,95%,58%)" }}
/>
<div className="flex-1 h-1.5 rounded-full bg-muted/50 overflow-hidden min-w-[60px]">
<div
className="h-full rounded-full transition-all duration-500 ease-out"
style={{
width: `${Math.min(queenUsage.usagePct, 100)}%`,
backgroundColor:
queenUsage.usagePct >= 90
? "hsl(0,65%,55%)"
: queenUsage.usagePct >= 70
? "hsl(35,90%,55%)"
: "hsl(45,95%,58%)",
}}
/>
</div>
<span className="text-[10px] text-muted-foreground/70 flex-shrink-0 tabular-nums">
<span className="group-hover/ctx:hidden">
{queenUsage.usagePct}%
</span>
<span className="hidden group-hover/ctx:inline">
{(queenUsage.estimatedTokens / 1000).toFixed(1)}k /{" "}
{(queenUsage.maxTokens / 1000).toFixed(0)}k
</span>
</span>
</div>
)}
{workerUsage && (
<div
className="flex items-center gap-2 flex-1 min-w-0"
title={`Worker: ${(workerUsage.estimatedTokens / 1000).toFixed(1)}k / ${(workerUsage.maxTokens / 1000).toFixed(0)}k tokens \u00b7 ${workerUsage.messageCount} messages`}
>
<Cpu
className="w-3 h-3 flex-shrink-0"
style={{ color: "hsl(220,60%,55%)" }}
/>
<div className="flex-1 h-1.5 rounded-full bg-muted/50 overflow-hidden min-w-[60px]">
<div
className="h-full rounded-full transition-all duration-500 ease-out"
style={{
width: `${Math.min(workerUsage.usagePct, 100)}%`,
backgroundColor:
workerUsage.usagePct >= 90
? "hsl(0,65%,55%)"
: workerUsage.usagePct >= 70
? "hsl(35,90%,55%)"
: "hsl(220,60%,55%)",
}}
/>
</div>
<span className="text-[10px] text-muted-foreground/70 flex-shrink-0 tabular-nums">
<span className="group-hover/ctx:hidden">
{workerUsage.usagePct}%
</span>
<span className="hidden group-hover/ctx:inline">
{(workerUsage.estimatedTokens / 1000).toFixed(1)}k /{" "}
{(workerUsage.maxTokens / 1000).toFixed(0)}k
</span>
</span>
</div>
)}
</div>
);
})()}
{/* Input area — question widget replaces textarea when a question is pending */}
{pendingQuestions && pendingQuestions.length >= 2 && onMultiQuestionSubmit ? (
{pendingQuestions &&
pendingQuestions.length >= 2 &&
onMultiQuestionSubmit ? (
<MultiQuestionWidget
questions={pendingQuestions}
onSubmit={onMultiQuestionSubmit}
@@ -366,7 +706,47 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
/>
) : (
<form onSubmit={handleSubmit} className="p-4">
{/* Image preview strip */}
{pendingImages.length > 0 && (
<div className="flex flex-wrap gap-2 mb-2 px-1">
{pendingImages.map((img, i) => (
<div key={i} className="relative group">
<img
src={img.image_url.url}
alt={`preview ${i + 1}`}
className="h-16 w-16 object-cover rounded-lg border border-border"
/>
<button
type="button"
onClick={() =>
setPendingImages((prev) => prev.filter((_, j) => j !== i))
}
className="absolute -top-1.5 -right-1.5 w-4 h-4 rounded-full bg-destructive text-destructive-foreground flex items-center justify-center opacity-0 group-hover:opacity-100 transition-opacity"
>
<X className="w-2.5 h-2.5" />
</button>
</div>
))}
</div>
)}
<div className="flex items-center gap-3 bg-muted/40 rounded-xl px-4 py-2.5 border border-border focus-within:border-primary/40 transition-colors">
<input
ref={fileInputRef}
type="file"
accept="image/*"
multiple
className="hidden"
onChange={handleFileChange}
/>
<button
type="button"
disabled={disabled || !supportsImages}
onClick={() => supportsImages && fileInputRef.current?.click()}
className="flex-shrink-0 p-1 rounded-md text-muted-foreground hover:text-foreground disabled:opacity-30 transition-colors"
title={supportsImages ? "Attach image" : "Image not supported by the current model"}
>
<Paperclip className="w-4 h-4" />
</button>
<textarea
ref={textareaRef}
rows={1}
@@ -383,7 +763,9 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
handleSubmit(e);
}
}}
placeholder={disabled ? "Connecting to agent..." : "Message Queen Bee..."}
placeholder={
disabled ? "Connecting to agent..." : "Message Queen Bee..."
}
disabled={disabled}
className="flex-1 bg-transparent text-sm text-foreground outline-none placeholder:text-muted-foreground disabled:opacity-50 disabled:cursor-not-allowed resize-none overflow-y-auto"
/>
@@ -398,7 +780,9 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
) : (
<button
type="submit"
disabled={!input.trim() || disabled}
disabled={
(!input.trim() && pendingImages.length === 0) || disabled
}
className="p-2 rounded-lg bg-primary text-primary-foreground disabled:opacity-30 hover:opacity-90 transition-opacity"
>
<Send className="w-4 h-4" />
@@ -28,6 +28,13 @@ export interface SubagentReport {
status?: "running" | "complete" | "error";
}
interface ContextUsage {
usagePct: number;
messageCount: number;
estimatedTokens: number;
maxTokens: number;
}
interface NodeDetailPanelProps {
node: GraphNode | null;
nodeSpec?: NodeSpec | null;
@@ -38,6 +45,7 @@ interface NodeDetailPanelProps {
workerSessionId?: string | null;
nodeLogs?: string[];
actionPlan?: string;
contextUsage?: ContextUsage;
onClose: () => void;
}
@@ -309,7 +317,7 @@ const tabs: { id: Tab; label: string; Icon: React.FC<{ className?: string }> }[]
{ id: "subagents", label: "Subagents", Icon: ({ className }) => <Bot className={className} /> },
];
export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagentReports, sessionId, graphId, workerSessionId, nodeLogs, actionPlan, onClose }: NodeDetailPanelProps) {
export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagentReports, sessionId, graphId, workerSessionId, nodeLogs, actionPlan, contextUsage, onClose }: NodeDetailPanelProps) {
const [activeTab, setActiveTab] = useState<Tab>("overview");
const [realTools, setRealTools] = useState<ToolInfo[] | null>(null);
const [realCriteria, setRealCriteria] = useState<NodeCriteria | null>(null);
@@ -389,6 +397,43 @@ export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagent
</div>
)}
{/* Context window usage */}
{contextUsage && (
<div className="px-4 py-2 border-b border-border/20 flex-shrink-0">
<div className="flex items-center gap-2 mb-1">
<span className="text-[10px] text-muted-foreground font-medium">Context</span>
<span className="text-[10px] text-muted-foreground/70 ml-auto">
{(contextUsage.estimatedTokens / 1000).toFixed(1)}k / {(contextUsage.maxTokens / 1000).toFixed(0)}k tokens
</span>
</div>
<div className="w-full h-1.5 rounded-full bg-muted/50 overflow-hidden">
<div
className="h-full rounded-full transition-all duration-500 ease-out"
style={{
width: `${Math.min(contextUsage.usagePct, 100)}%`,
backgroundColor: contextUsage.usagePct >= 90
? "hsl(0,65%,55%)"
: contextUsage.usagePct >= 70
? "hsl(35,90%,55%)"
: "hsl(45,95%,58%)",
}}
/>
</div>
<div className="flex items-center gap-2 mt-1">
<span className="text-[10px] text-muted-foreground/60">{contextUsage.messageCount} messages</span>
<span className="text-[10px] font-medium ml-auto" style={{
color: contextUsage.usagePct >= 90
? "hsl(0,65%,55%)"
: contextUsage.usagePct >= 70
? "hsl(35,90%,55%)"
: "hsl(45,95%,58%)",
}}>
{contextUsage.usagePct}%
</span>
</div>
</div>
)}
{/* Tab bar */}
<div className="flex border-b border-border/30 flex-shrink-0 px-2 pt-1 overflow-x-auto scrollbar-hide">
{tabs.filter(t => t.id !== "subagents" || (nodeSpec?.sub_agents && nodeSpec.sub_agents.length > 0)).map(tab => (
@@ -0,0 +1,413 @@
import { memo, useState, useRef, useEffect } from "react";
import { ChevronDown, ChevronUp, Cpu } from "lucide-react";
import type { ChatMessage, ContextUsageEntry } from "@/components/ChatPanel";
import MarkdownContent from "@/components/MarkdownContent";
// ---------------------------------------------------------------------------
// Shared helpers
// ---------------------------------------------------------------------------
const workerColor = "hsl(220,60%,55%)";
const SUBAGENT_COLORS = [
"hsl(220,60%,55%)",
"hsl(260,50%,55%)",
"hsl(180,50%,45%)",
"hsl(30,70%,50%)",
"hsl(340,55%,50%)",
"hsl(150,45%,45%)",
"hsl(45,80%,50%)",
"hsl(290,45%,55%)",
];
function colorForIndex(i: number): string {
return SUBAGENT_COLORS[i % SUBAGENT_COLORS.length];
}
function subagentLabel(nodeId: string): string {
const parts = nodeId.split(":subagent:");
const raw = parts.length >= 2 ? parts[1] : nodeId;
return raw
.replace(/:\d+$/, "") // strip instance suffix like ":3"
.replace(/[_-]/g, " ")
.replace(/\b\w/g, (c) => c.toUpperCase())
.trim();
}
function last<T>(arr: T[]): T | undefined {
return arr[arr.length - 1];
}
export interface SubagentGroup {
nodeId: string;
messages: ChatMessage[];
contextUsage?: ContextUsageEntry;
}
interface ParallelSubagentBubbleProps {
groups: SubagentGroup[];
groupId: string;
}
// ---------------------------------------------------------------------------
// Thermometer — vertical context gauge on right edge of each pane
// ---------------------------------------------------------------------------
// ---------------------------------------------------------------------------
// Tool overlay — shown when a tool_status message is active (not all done)
// ---------------------------------------------------------------------------
function ToolOverlay({
toolName,
color,
visible,
}: {
toolName: string;
color: string;
visible: boolean;
}) {
return (
<div
className="absolute inset-0 top-[22px] flex items-center justify-center transition-opacity duration-200 z-10"
style={{
background: "rgba(8,8,14,0.82)",
opacity: visible ? 1 : 0,
pointerEvents: visible ? "auto" : "none",
}}
>
<div className="text-center px-3 py-2 rounded-md border" style={{ borderColor: `${color}40` }}>
<div className="text-[10px] font-medium" style={{ color }}>
{toolName}
</div>
<div className="text-[11px] mt-0.5" style={{ color }}>
{visible ? "..." : "\u2713"}
</div>
</div>
</div>
);
}
// ---------------------------------------------------------------------------
// Single tmux pane
// ---------------------------------------------------------------------------
function MuxPane({
group,
index,
label,
isFocused,
isZoomed,
onClickTitle,
}: {
group: SubagentGroup;
index: number;
label: string;
isFocused: boolean;
isZoomed: boolean;
onClickTitle: () => void;
}) {
const bodyRef = useRef<HTMLDivElement>(null);
const stickRef = useRef(true);
const color = colorForIndex(index);
const pct = group.contextUsage?.usagePct ?? 0;
const streamMsgs = group.messages.filter((m) => m.type !== "tool_status");
const latestContent = last(streamMsgs)?.content ?? "";
const msgCount = streamMsgs.length;
// Detect active tool and finished state from latest tool_status
const latestTool = last(
group.messages.filter((m) => m.type === "tool_status")
);
let activeToolName = "";
let toolRunning = false;
let isFinished = false;
if (latestTool) {
try {
const parsed = JSON.parse(latestTool.content);
const tools: { name: string; done: boolean }[] = parsed.tools || [];
const allDone = parsed.allDone as boolean | undefined;
const running = tools.find((t) => !t.done);
if (running) {
activeToolName = running.name;
toolRunning = true;
}
// Finished when all tools are done and one of them is set_output
// or report_to_parent (terminal tool calls)
if (allDone && tools.length > 0) {
const hasTerminal = tools.some(
(t) =>
t.done &&
(t.name === "set_output" || t.name === "report_to_parent")
);
if (hasTerminal) isFinished = true;
}
} catch {
/* ignore */
}
}
// Auto-scroll
useEffect(() => {
if (stickRef.current && bodyRef.current) {
bodyRef.current.scrollTop = bodyRef.current.scrollHeight;
}
}, [latestContent]);
const handleScroll = () => {
const el = bodyRef.current;
if (!el) return;
stickRef.current = el.scrollHeight - el.scrollTop - el.clientHeight < 30;
};
return (
<div
className="flex flex-col min-h-0 overflow-hidden relative transition-all duration-200"
style={{
borderWidth: 1,
borderStyle: "solid",
borderColor: isFocused && !isFinished ? `${color}60` : "transparent",
opacity: isFinished ? 0.4 : isFocused || isZoomed ? 1 : 0.55,
...(isZoomed
? { gridColumn: "1 / -1", gridRow: "1 / -1", zIndex: 10 }
: {}),
}}
>
{/* Title bar */}
<div
className="flex items-center gap-1.5 px-2 py-[3px] flex-shrink-0 cursor-pointer select-none"
style={{ background: "#0e0e16", borderBottom: "1px solid #1a1a2a" }}
onClick={onClickTitle}
>
{isFinished ? (
<span className="text-[8px] flex-shrink-0 leading-none" style={{ color: "#4a4" }}>&#10003;</span>
) : (
<div
className="w-[6px] h-[6px] rounded-full flex-shrink-0"
style={{ background: color }}
/>
)}
<span className="text-[9px] flex-shrink-0" style={{ color: isFinished ? "#555" : color }}>
{label}
</span>
<span className="flex-1" />
<span className="text-[8px] tabular-nums flex-shrink-0" style={{ color: "#555" }}>
{msgCount}
</span>
<div
className="w-[36px] h-[3px] rounded-full overflow-hidden flex-shrink-0"
style={{ background: "#1a1a2a" }}
>
<div
className="h-full rounded-full transition-all duration-500"
style={{
width: `${Math.min(pct, 100)}%`,
backgroundColor:
pct >= 80 ? "hsl(0,65%,55%)" : pct >= 50 ? "hsl(35,90%,55%)" : color,
}}
/>
</div>
<span className="text-[8px] tabular-nums flex-shrink-0" style={{ color: "#555" }}>
{pct}%
</span>
</div>
{/* Body */}
<div
ref={bodyRef}
onScroll={handleScroll}
className="flex-1 min-h-0 overflow-y-auto px-2 py-1 text-[10px] leading-[1.7]"
style={{ background: "#08080e", color: "#555", fontFamily: "monospace" }}
>
{latestContent ? (
<div style={{ color: "#ccc" }}>
<MarkdownContent content={latestContent} />
</div>
) : (
<span style={{ color: "#333" }}>waiting...</span>
)}
{/* Blinking cursor — hidden when finished */}
{!isFinished && (
<span
className="inline-block w-[6px] h-[11px] align-middle ml-0.5"
style={{
background: color,
animation: "cursorBlink 1s step-end infinite",
}}
/>
)}
</div>
{/* Tool overlay */}
<ToolOverlay
toolName={activeToolName}
color={color}
visible={toolRunning}
/>
</div>
);
}
// ---------------------------------------------------------------------------
// Main component
// ---------------------------------------------------------------------------
const ParallelSubagentBubble = memo(
function ParallelSubagentBubble({ groups }: ParallelSubagentBubbleProps) {
const [expanded, setExpanded] = useState(false);
const [zoomedIdx, setZoomedIdx] = useState<number | null>(null);
// Labels with instance numbers for duplicates
const labels: string[] = (() => {
const countByBase = new Map<string, number>();
const bases = groups.map((g) => subagentLabel(g.nodeId));
for (const b of bases)
countByBase.set(b, (countByBase.get(b) ?? 0) + 1);
const idxByBase = new Map<string, number>();
return bases.map((b) => {
if ((countByBase.get(b) ?? 1) <= 1) return b;
const idx = (idxByBase.get(b) ?? 0) + 1;
idxByBase.set(b, idx);
return `${b} #${idx}`;
});
})();
// Latest-active pane
const latestIdx = groups.reduce<number>((best, g, i) => {
const filtered = g.messages.filter((m) => m.type !== "tool_status");
const lm = last(filtered);
if (!lm) return best;
if (best < 0) return i;
const bm = last(
groups[best].messages.filter((m) => m.type !== "tool_status")
);
if (!bm) return i;
return (lm.createdAt ?? 0) >= (bm.createdAt ?? 0) ? i : best;
}, -1);
// Per-group finished detection (same logic as MuxPane)
const finishedFlags = groups.map((g) => {
const lt = last(g.messages.filter((m) => m.type === "tool_status"));
if (!lt) return false;
try {
const p = JSON.parse(lt.content);
const tools: { name: string; done: boolean }[] = p.tools || [];
if (!p.allDone || tools.length === 0) return false;
return tools.some(
(t) => t.done && (t.name === "set_output" || t.name === "report_to_parent")
);
} catch { return false; }
});
const activeCount = finishedFlags.filter((f) => !f).length;
if (groups.length === 0) return null;
// Grid sizing: 2 columns, auto rows capped at a fixed height
const rows = Math.ceil(groups.length / 2);
const gridHeight = expanded
? Math.min(rows * 200, 480)
: Math.min(rows * 100, 240);
return (
<div className="flex gap-3">
{/* Left icon */}
<div
className="flex-shrink-0 w-7 h-7 rounded-xl flex items-center justify-center mt-1"
style={{
backgroundColor: `${workerColor}18`,
border: `1.5px solid ${workerColor}35`,
}}
>
<Cpu className="w-3.5 h-3.5" style={{ color: workerColor }} />
</div>
<div className="flex-1 min-w-0 max-w-[90%]">
{/* Header */}
<div className="flex items-center gap-2 mb-1">
<span className="font-medium text-xs" style={{ color: workerColor }}>
{groups.length === 1 ? "Sub-agent" : "Parallel Agents"}
</span>
<span className="text-[10px] font-medium px-1.5 py-0.5 rounded-md bg-muted text-muted-foreground">
{activeCount > 0 ? `${activeCount} running` : `${groups.length} done`}
</span>
<button
onClick={() => {
setExpanded((v) => !v);
setZoomedIdx(null);
}}
className="ml-auto text-muted-foreground/60 hover:text-muted-foreground transition-colors p-0.5 rounded"
title={expanded ? "Collapse" : "Expand"}
>
{expanded ? (
<ChevronUp className="w-3.5 h-3.5" />
) : (
<ChevronDown className="w-3.5 h-3.5" />
)}
</button>
</div>
{/* Mux frame */}
<div
className="rounded-lg overflow-hidden"
style={{
border: "2px solid #1a1a2a",
background: "#08080e",
}}
>
{/* Grid */}
<div
className="grid gap-px"
style={{
gridTemplateColumns:
groups.length === 1 ? "1fr" : "1fr 1fr",
gridTemplateRows: `repeat(${rows}, 1fr)`,
height: gridHeight,
background: "#111",
}}
>
{groups.map((group, i) => (
<MuxPane
key={group.nodeId}
group={group}
index={i}
label={labels[i]}
isFocused={latestIdx === i}
isZoomed={zoomedIdx === i}
onClickTitle={() =>
setZoomedIdx(zoomedIdx === i ? null : i)
}
/>
))}
</div>
</div>
</div>
</div>
);
},
(prev, next) =>
prev.groupId === next.groupId &&
prev.groups.length === next.groups.length &&
prev.groups.every(
(g, i) =>
g.nodeId === next.groups[i].nodeId &&
g.messages.length === next.groups[i].messages.length &&
last(g.messages)?.content === last(next.groups[i].messages)?.content &&
g.contextUsage?.usagePct === next.groups[i].contextUsage?.usagePct
)
);
export default ParallelSubagentBubble;
// Injected as a global style (keyframes can't be inline)
if (typeof document !== "undefined") {
const id = "parallel-subagent-keyframes";
if (!document.getElementById(id)) {
const style = document.createElement("style");
style.id = id;
style.textContent = `
@keyframes cursorBlink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } }
@keyframes thermoPulse { 0%, 100% { opacity: 1; } 50% { opacity: 0.4; } }
`;
document.head.appendChild(style);
}
}
+6 -2
View File
@@ -62,7 +62,7 @@ export function sseEventToChatMessage(
const innerSuffix = innerTurn != null && innerTurn > 0 ? `-t${innerTurn}` : "";
const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
if (!snapshot) return null;
if (!snapshot.trim()) return null;
return {
id: `stream-${iterIdKey}${innerSuffix}-${event.node_id}`,
agent: agentDisplayName || event.node_id || "Agent",
@@ -72,6 +72,8 @@ export function sseEventToChatMessage(
role: "worker",
thread,
createdAt,
nodeId: event.node_id || undefined,
executionId: event.execution_id || undefined,
};
}
@@ -100,7 +102,7 @@ export function sseEventToChatMessage(
const llmInnerSuffix = llmInnerTurn != null && llmInnerTurn > 0 ? `-t${llmInnerTurn}` : "";
const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
if (!snapshot) return null;
if (!snapshot.trim()) return null;
return {
id: `stream-${idKey}${llmInnerSuffix}-${event.node_id}`,
agent: event.node_id || "Agent",
@@ -110,6 +112,8 @@ export function sseEventToChatMessage(
role: "worker",
thread,
createdAt,
nodeId: event.node_id || undefined,
executionId: event.execution_id || undefined,
};
}
+77 -9
View File
@@ -1,7 +1,7 @@
import { useState, useCallback, useRef, useEffect, useMemo } from "react";
import ReactDOM from "react-dom";
import { useSearchParams, useNavigate } from "react-router-dom";
import { Plus, KeyRound, Sparkles, Layers, ChevronLeft, Bot, Loader2, WifiOff, X } from "lucide-react";
import { Plus, KeyRound, Sparkles, Layers, ChevronLeft, Bot, Loader2, WifiOff, X, FolderOpen } from "lucide-react";
import type { GraphNode, NodeStatus } from "@/components/graph-types";
import DraftGraph from "@/components/DraftGraph";
import ChatPanel, { type ChatMessage } from "@/components/ChatPanel";
@@ -352,6 +352,10 @@ interface AgentBackendState {
pendingQuestions: { id: string; prompt: string; options?: string[] }[] | null;
/** Whether the pending question came from queen or worker */
pendingQuestionSource: "queen" | "worker" | null;
/** Per-node context window usage (from context_usage_updated events) */
contextUsage: Record<string, { usagePct: number; messageCount: number; estimatedTokens: number; maxTokens: number }>;
/** Whether the queen's LLM supports image content — false disables the attach button */
queenSupportsImages: boolean;
}
function defaultAgentState(): AgentBackendState {
@@ -389,6 +393,8 @@ function defaultAgentState(): AgentBackendState {
pendingOptions: null,
pendingQuestions: null,
pendingQuestionSource: null,
contextUsage: {},
queenSupportsImages: true,
};
}
@@ -630,6 +636,10 @@ export default function Workspace() {
// it was created in (avoids stale-closure when phase change and message
// events arrive in the same React batch).
const queenPhaseRef = useRef<Record<string, string>>({});
// Accumulated queen text across inner_turns within the same iteration.
// Key: `${agentType}:${execution_id}:${iteration}`, value: { [inner_turn]: snapshot }.
// This lets us merge all inner_turn text into one chat bubble per iteration.
const queenIterTextRef = useRef<Record<string, Record<number, string>>>({});
// Timestamp when designingDraft was set — used to enforce minimum spinner duration.
const designingDraftSinceRef = useRef<Record<string, number>>({});
const designingDraftTimerRef = useRef<Record<string, ReturnType<typeof setTimeout>>>({});
@@ -916,6 +926,7 @@ export default function Workspace() {
queenReady: true,
queenPhase: qPhase,
queenBuilding: qPhase === "building",
queenSupportsImages: liveSession.queen_supports_images !== false,
// Restore flowchart overlay from persisted events
...(restoredFlowchartMap ? { flowchartMap: restoredFlowchartMap } : {}),
...(restoredOriginalDraft ? { originalDraft: restoredOriginalDraft, draftGraph: null } : {}),
@@ -1115,6 +1126,7 @@ export default function Workspace() {
displayName,
queenPhase: initialPhase,
queenBuilding: initialPhase === "building",
queenSupportsImages: session.queen_supports_images !== false,
// Restore flowchart overlay from persisted events
...(restoredFlowchartMap ? { flowchartMap: restoredFlowchartMap } : {}),
...(restoredOriginalDraft ? { originalDraft: restoredOriginalDraft, draftGraph: null } : {}),
@@ -1707,14 +1719,29 @@ export default function Workspace() {
if (isQueen) console.log('[QUEEN] chatMsg:', chatMsg?.id, chatMsg?.content?.slice(0, 50), 'turn:', currentTurn);
if (chatMsg && !suppressQueenMessages) {
// Queen emits multiple client_output_delta / llm_text_delta snapshots
// across iterations and inner tool-loop turns. Build a stable ID that
// groups streaming deltas for the *same* output (same execution +
// iteration + inner_turn) into one bubble, while keeping distinct
// outputs as separate bubbles so earlier text isn't overwritten.
// across iterations and inner tool-loop turns. Merge all inner_turns
// within the same iteration into ONE bubble so the queen's multi-step
// tool loop (text → tool → text → tool → text) appears as one cohesive
// message rather than many small fragments.
if (isQueen && (event.type === "client_output_delta" || event.type === "llm_text_delta") && event.execution_id) {
const iter = event.data?.iteration ?? 0;
const inner = event.data?.inner_turn ?? 0;
chatMsg.id = `queen-stream-${event.execution_id}-${iter}-${inner}`;
const inner = (event.data?.inner_turn as number) ?? 0;
const iterKey = `${agentType}:${event.execution_id}:${iter}`;
// Store the latest snapshot for this inner_turn
if (!queenIterTextRef.current[iterKey]) {
queenIterTextRef.current[iterKey] = {};
}
const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
queenIterTextRef.current[iterKey][inner] = snapshot;
// Concatenate all inner_turn snapshots in order
const parts = queenIterTextRef.current[iterKey];
const sortedInners = Object.keys(parts).map(Number).sort((a, b) => a - b);
chatMsg.content = sortedInners.map(k => parts[k]).join("\n");
// Single ID per iteration — no inner_turn in the ID
chatMsg.id = `queen-stream-${event.execution_id}-${iter}`;
}
if (isQueen) {
chatMsg.role = role;
@@ -1989,6 +2016,8 @@ export default function Workspace() {
role,
thread: agentType,
createdAt: eventCreatedAt,
nodeId: event.node_id || undefined,
executionId: event.execution_id || undefined,
});
return {
...prev,
@@ -2060,6 +2089,8 @@ export default function Workspace() {
role,
thread: agentType,
createdAt: eventCreatedAt,
nodeId: event.node_id || undefined,
executionId: event.execution_id || undefined,
});
return {
...prev,
@@ -2136,6 +2167,29 @@ export default function Workspace() {
}
break;
case "context_usage_updated": {
const streamKey = isQueen ? "__queen__" : (event.node_id || streamId);
const usagePct = (event.data?.usage_pct as number) ?? 0;
const messageCount = (event.data?.message_count as number) ?? 0;
const estimatedTokens = (event.data?.estimated_tokens as number) ?? 0;
const maxTokens = (event.data?.max_context_tokens as number) ?? 0;
setAgentStates(prev => {
const state = prev[agentType];
if (!state) return prev;
return {
...prev,
[agentType]: {
...state,
contextUsage: {
...state.contextUsage,
[streamKey]: { usagePct, messageCount, estimatedTokens, maxTokens },
},
},
};
});
}
break;
case "node_action_plan":
if (!isQueen && event.node_id) {
const plan = (event.data?.plan as string) || "";
@@ -2564,7 +2618,7 @@ export default function Workspace() {
});
// --- handleSend ---
const handleSend = useCallback((text: string, thread: string) => {
const handleSend = useCallback((text: string, thread: string, images?: import("@/components/ChatPanel").ImageContent[]) => {
if (!activeSession) return;
const state = agentStates[activeWorker];
@@ -2630,6 +2684,7 @@ export default function Workspace() {
const userMsg: ChatMessage = {
id: makeId(), agent: "You", agentColor: "",
content: text, timestamp: "", type: "user", thread, createdAt: Date.now(),
images,
};
setSessionsByAgent(prev => ({
...prev,
@@ -2641,7 +2696,7 @@ export default function Workspace() {
updateAgentState(activeWorker, { isTyping: true, queenIsTyping: true });
if (state?.sessionId && state?.ready) {
executionApi.chat(state.sessionId, text).catch((err: unknown) => {
executionApi.chat(state.sessionId, text, images).catch((err: unknown) => {
const errMsg = err instanceof Error ? err.message : String(err);
const errorChatMsg: ChatMessage = {
id: makeId(), agent: "System", agentColor: "",
@@ -3057,6 +3112,16 @@ export default function Workspace() {
<KeyRound className="w-3.5 h-3.5" />
Credentials
</button>
{activeAgentState?.sessionId && (
<button
onClick={() => sessionsApi.revealFolder(activeAgentState.sessionId!).catch(() => {})}
className="flex items-center gap-1.5 px-3 py-1.5 rounded-md text-xs font-medium text-muted-foreground hover:text-foreground hover:bg-muted/50 transition-colors flex-shrink-0"
title="Open session data folder"
>
<FolderOpen className="w-3.5 h-3.5" />
Data
</button>
)}
</TopBar>
{/* Main content area */}
@@ -3174,6 +3239,8 @@ export default function Workspace() {
}
onMultiQuestionSubmit={handleMultiQuestionAnswer}
onQuestionDismiss={handleQuestionDismiss}
contextUsage={activeAgentState?.contextUsage}
supportsImages={activeAgentState?.queenSupportsImages ?? true}
/>
)}
</div>
@@ -3377,6 +3444,7 @@ export default function Workspace() {
workerSessionId={null}
nodeLogs={activeAgentState?.nodeLogs[resolvedSelectedNode.id] || []}
actionPlan={activeAgentState?.nodeActionPlans[resolvedSelectedNode.id]}
contextUsage={activeAgentState?.contextUsage[resolvedSelectedNode.id]}
onClose={() => setSelectedNode(null)}
/>
)}
+5 -1
View File
@@ -8,7 +8,7 @@ dependencies = [
"pydantic>=2.0",
"anthropic>=0.40.0",
"httpx>=0.27.0",
"litellm>=1.81.0",
"litellm==1.81.7", # pinned: supply chain attack in >=1.82.7 (adenhq/hive#6783)
"mcp>=1.0.0",
"fastmcp>=2.0.0",
"croniter>=1.4.0",
@@ -62,6 +62,10 @@ lint.isort.section-order = [
"first-party",
"local-folder",
]
[tool.pytest.ini_options]
filterwarnings = [
"ignore::DeprecationWarning:litellm.*"
]
[dependency-groups]
dev = [
@@ -0,0 +1,243 @@
"""Tests for AgentRunner MCP registry integration."""
import json
from pathlib import Path
from framework.graph.edge import GraphSpec
from framework.graph.goal import Goal
from framework.graph.node import NodeSpec
from framework.runner.mcp_registry import MCPRegistry
from framework.runner.runner import AgentRunner
def _make_graph() -> GraphSpec:
return GraphSpec(
id="test-graph",
goal_id="goal-1",
entry_node="start",
terminal_nodes=["start"],
nodes=[NodeSpec(id="start", name="Start", description="Start node")],
edges=[],
)
class _FakeRegistry:
def __init__(self, returned_configs):
self._returned_configs = returned_configs
self.initialize_calls = 0
self.loaded_paths: list[Path] = []
def initialize(self) -> None:
self.initialize_calls += 1
def load_agent_selection(self, agent_path: Path):
self.loaded_paths.append(agent_path)
return list(self._returned_configs)
def test_agent_runner_loads_registry_selected_servers(tmp_path, monkeypatch):
agent_dir = tmp_path / "agent"
agent_dir.mkdir()
(agent_dir / "mcp_registry.json").write_text('{"include": ["jira"]}', encoding="utf-8")
fake_registry = _FakeRegistry(
[
{
"name": "jira",
"transport": "http",
"url": "http://localhost:4010",
"headers": {},
"description": "Jira",
}
]
)
registered: list[dict] = []
monkeypatch.setattr("framework.runner.mcp_registry.MCPRegistry", lambda: fake_registry)
monkeypatch.setattr(
"framework.runner.runner.run_preload_validation",
lambda *args, **kwargs: None,
)
monkeypatch.setattr(AgentRunner, "_resolve_default_model", staticmethod(lambda: "test-model"))
monkeypatch.setattr(
"framework.runner.tool_registry.ToolRegistry.register_mcp_server",
lambda self, server_config, use_connection_manager=True: (
registered.append(server_config) or 1
),
)
AgentRunner(
agent_path=agent_dir,
graph=_make_graph(),
goal=Goal(id="goal-1", name="Goal", description="desc"),
storage_path=tmp_path / "storage",
interactive=False,
skip_credential_validation=True,
)
assert fake_registry.initialize_calls == 1
assert fake_registry.loaded_paths == [agent_dir]
assert [config["name"] for config in registered] == ["jira"]
def test_agent_runner_skips_registry_when_no_servers_selected(tmp_path, monkeypatch):
agent_dir = tmp_path / "agent"
agent_dir.mkdir()
fake_registry = _FakeRegistry([])
registered: list[dict] = []
monkeypatch.setattr("framework.runner.mcp_registry.MCPRegistry", lambda: fake_registry)
monkeypatch.setattr(
"framework.runner.runner.run_preload_validation",
lambda *args, **kwargs: None,
)
monkeypatch.setattr(AgentRunner, "_resolve_default_model", staticmethod(lambda: "test-model"))
monkeypatch.setattr(
"framework.runner.tool_registry.ToolRegistry.register_mcp_server",
lambda self, server_config, use_connection_manager=True: (
registered.append(server_config) or 1
),
)
AgentRunner(
agent_path=agent_dir,
graph=_make_graph(),
goal=Goal(id="goal-1", name="Goal", description="desc"),
storage_path=tmp_path / "storage",
interactive=False,
skip_credential_validation=True,
)
assert fake_registry.initialize_calls == 1
assert fake_registry.loaded_paths == [agent_dir]
assert registered == []
def test_agent_runner_logs_actual_registry_load_results(tmp_path, monkeypatch):
agent_dir = tmp_path / "agent"
agent_dir.mkdir()
(agent_dir / "mcp_registry.json").write_text('{"include": ["jira", "slack"]}', encoding="utf-8")
fake_registry = _FakeRegistry(
[
{"name": "jira", "transport": "http", "url": "http://localhost:4010"},
{"name": "slack", "transport": "http", "url": "http://localhost:4020"},
]
)
log_messages: list[str] = []
monkeypatch.setattr("framework.runner.mcp_registry.MCPRegistry", lambda: fake_registry)
monkeypatch.setattr(
"framework.runner.runner.run_preload_validation",
lambda *args, **kwargs: None,
)
monkeypatch.setattr(AgentRunner, "_resolve_default_model", staticmethod(lambda: "test-model"))
monkeypatch.setattr(
"framework.runner.tool_registry.ToolRegistry.load_registry_servers",
lambda self, server_configs: [
{"server": "jira", "status": "loaded", "tools_loaded": 2, "skipped_reason": None},
{
"server": "slack",
"status": "skipped",
"tools_loaded": 0,
"skipped_reason": "registered 0 tools",
},
],
)
monkeypatch.setattr(
"framework.runner.runner.logger.info",
lambda message, *args: log_messages.append(message % args if args else str(message)),
)
AgentRunner(
agent_path=agent_dir,
graph=_make_graph(),
goal=Goal(id="goal-1", name="Goal", description="desc"),
storage_path=tmp_path / "storage",
interactive=False,
skip_credential_validation=True,
)
assert any("Loaded 1/2 MCP registry server(s)" in message for message in log_messages)
assert any("Skipped MCP registry servers" in message for message in log_messages)
def test_agent_runner_survives_malformed_registry_json(tmp_path, monkeypatch):
"""Agent startup must not crash when mcp_registry.json has invalid JSON."""
agent_dir = tmp_path / "agent"
agent_dir.mkdir()
(agent_dir / "mcp_registry.json").write_text("{bad json", encoding="utf-8")
warnings: list[str] = []
monkeypatch.setattr(
"framework.runner.runner.run_preload_validation",
lambda *args, **kwargs: None,
)
monkeypatch.setattr(AgentRunner, "_resolve_default_model", staticmethod(lambda: "test-model"))
monkeypatch.setattr(
"framework.runner.runner.logger.warning",
lambda message, *args: warnings.append(message % args if args else str(message)),
)
AgentRunner(
agent_path=agent_dir,
graph=_make_graph(),
goal=Goal(id="goal-1", name="Goal", description="desc"),
storage_path=tmp_path / "storage",
interactive=False,
skip_credential_validation=True,
)
assert any("Failed to load MCP registry servers" in w for w in warnings)
def test_integration_real_registry_to_agent_runner(tmp_path, monkeypatch):
"""Integration: real MCPRegistry on disk → mcp_registry.json → AgentRunner."""
# Set up a real registry with a local server
registry_base = tmp_path / "mcp_registry"
registry = MCPRegistry(base_path=registry_base)
registry.initialize()
registry.add_local(name="jira", transport="http", url="http://localhost:4010")
# Write mcp_registry.json in the agent dir
agent_dir = tmp_path / "agent"
agent_dir.mkdir()
(agent_dir / "mcp_registry.json").write_text(
json.dumps({"include": ["jira"]}), encoding="utf-8"
)
# Patch MCPRegistry to use our tmp_path base, but keep real logic
original_init = MCPRegistry.__init__
def patched_init(self, base_path=None):
original_init(self, base_path=registry_base)
monkeypatch.setattr(MCPRegistry, "__init__", patched_init)
monkeypatch.setattr(
"framework.runner.runner.run_preload_validation",
lambda *args, **kwargs: None,
)
monkeypatch.setattr(AgentRunner, "_resolve_default_model", staticmethod(lambda: "test-model"))
registered: list[dict] = []
monkeypatch.setattr(
"framework.runner.tool_registry.ToolRegistry.register_mcp_server",
lambda self, server_config, use_connection_manager=True: (
registered.append(server_config) or 1
),
)
AgentRunner(
agent_path=agent_dir,
graph=_make_graph(),
goal=Goal(id="goal-1", name="Goal", description="desc"),
storage_path=tmp_path / "storage",
interactive=False,
skip_credential_validation=True,
)
assert len(registered) == 1
assert registered[0]["name"] == "jira"
assert registered[0]["transport"] == "http"
assert registered[0]["url"] == "http://localhost:4010"
+172
View File
@@ -0,0 +1,172 @@
"""Integration test: Run a real EventLoopNode against the Antigravity backend.
Run: .venv/bin/python core/tests/test_antigravity_eventloop.py
Requires:
- ~/.hive/antigravity-accounts.json with valid credentials
(run 'uv run python core/antigravity_auth.py auth account add' to authenticate)
"""
import asyncio
import logging
import sys
from unittest.mock import MagicMock
sys.path.insert(0, "core")
logging.basicConfig(level=logging.WARNING, format="%(levelname)s %(name)s: %(message)s")
# Show our provider's retry/stream logs
logging.getLogger("framework.llm.litellm").setLevel(logging.DEBUG)
from framework.config import RuntimeConfig # noqa: E402
from framework.graph.event_loop_node import EventLoopNode, LoopConfig # noqa: E402
from framework.graph.node import NodeContext, NodeResult, NodeSpec, SharedMemory # noqa: E402
from framework.llm.litellm import LiteLLMProvider # noqa: E402
def make_provider() -> LiteLLMProvider:
cfg = RuntimeConfig()
if not cfg.api_key:
print("ERROR: No Antigravity token found.")
print(" 1. Run 'antigravity-auth accounts add' to authenticate.")
print(" 2. Run 'antigravity-auth serve' to start the local proxy.")
print(" 3. Configure Hive: run quickstart.sh and select option 7 (Antigravity).")
sys.exit(1)
print(f"Model : {cfg.model}")
print(f"Base : {cfg.api_base}")
print(f"Antigravity : {'localhost:8069' in (cfg.api_base or '')}")
return LiteLLMProvider(
model=cfg.model,
api_key=cfg.api_key,
api_base=cfg.api_base,
**cfg.extra_kwargs,
)
def make_context(
llm: LiteLLMProvider,
*,
node_id: str = "test",
system_prompt: str = "You are a helpful assistant.",
output_keys: list[str] | None = None,
) -> NodeContext:
if output_keys is None:
output_keys = ["answer"]
spec = NodeSpec(
id=node_id,
name="Test Node",
description="Integration test node",
node_type="event_loop",
output_keys=output_keys,
system_prompt=system_prompt,
)
runtime = MagicMock()
runtime.start_run = MagicMock(return_value="run-1")
runtime.decide = MagicMock(return_value="dec-1")
runtime.record_outcome = MagicMock()
runtime.end_run = MagicMock()
memory = SharedMemory()
return NodeContext(
runtime=runtime,
node_id=node_id,
node_spec=spec,
memory=memory,
input_data={},
llm=llm,
available_tools=[],
max_tokens=4096,
)
async def run_test(
name: str, llm: LiteLLMProvider, system: str, output_keys: list[str]
) -> NodeResult:
print(f"\n{'=' * 60}")
print(f"TEST: {name}")
print(f"{'=' * 60}")
ctx = make_context(llm, system_prompt=system, output_keys=output_keys)
node = EventLoopNode(config=LoopConfig(max_iterations=3))
try:
result = await node.execute(ctx)
print(f" Success : {result.success}")
print(f" Output : {result.output}")
if result.error:
print(f" Error : {result.error}")
return result
except Exception as e:
print(f" EXCEPTION: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
return NodeResult(success=False, error=str(e))
async def main():
llm = make_provider()
print()
# Test 1: Simple text output — the node should call set_output to fill "answer"
r1 = await run_test(
name="Simple text generation",
llm=llm,
system=(
"You are a helpful assistant. When asked a question, use the "
"set_output tool to store your answer in the 'answer' key. "
"Keep answers short (1-2 sentences)."
),
output_keys=["answer"],
)
# Test 2: If test 1 failed, try bare stream() to isolate the issue
if not r1.success:
print(f"\n{'=' * 60}")
print("FALLBACK: Testing bare provider.stream() directly")
print(f"{'=' * 60}")
try:
from framework.llm.stream_events import (
FinishEvent,
StreamErrorEvent,
TextDeltaEvent,
ToolCallEvent,
)
text = ""
events = []
async for event in llm.stream(
messages=[{"role": "user", "content": "Say hello in 3 words."}],
):
events.append(type(event).__name__)
if isinstance(event, TextDeltaEvent):
text = event.snapshot
elif isinstance(event, FinishEvent):
print(
f" Finish: stop={event.stop_reason}"
f" in={event.input_tokens}"
f" out={event.output_tokens}"
)
elif isinstance(event, StreamErrorEvent):
print(f" StreamError: {event.error} (recoverable={event.recoverable})")
elif isinstance(event, ToolCallEvent):
print(f" ToolCall: {event.tool_name}")
print(f" Text : {text!r}")
print(f" Events : {events}")
print(f" RESULT : {'OK' if text else 'EMPTY'}")
except Exception as e:
print(f" EXCEPTION: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
print(f"\n{'=' * 60}")
print("DONE")
print(f"{'=' * 60}")
if __name__ == "__main__":
asyncio.run(main())
+58
View File
@@ -0,0 +1,58 @@
"""Tests for LLM model capability checks."""
from __future__ import annotations
import pytest
from framework.llm.capabilities import supports_image_tool_results
class TestSupportsImageToolResults:
"""Verify the deny-list correctly identifies models that can't handle images."""
@pytest.mark.parametrize(
"model",
[
"gpt-4o",
"gpt-4o-mini",
"gpt-4-turbo",
"openai/gpt-4o",
"anthropic/claude-sonnet-4-20250514",
"claude-haiku-4-5-20251001",
"gemini/gemini-1.5-pro",
"google/gemini-1.5-flash",
"mistral/mistral-large",
"groq/llama3-70b",
"together/meta-llama/Llama-3-70b",
"fireworks_ai/llama-v3-70b",
"azure/gpt-4o",
"kimi/claude-sonnet-4-20250514",
"hive/claude-sonnet-4-20250514",
],
)
def test_supported_models(self, model: str):
assert supports_image_tool_results(model) is True
@pytest.mark.parametrize(
"model",
[
"deepseek/deepseek-chat",
"deepseek/deepseek-coder",
"deepseek-chat",
"deepseek-reasoner",
"ollama/llama3",
"ollama/mistral",
"ollama_chat/llama3",
"lm_studio/my-model",
"vllm/meta-llama/Llama-3-70b",
"llamacpp/model",
"cerebras/llama3-70b",
],
)
def test_unsupported_models(self, model: str):
assert supports_image_tool_results(model) is False
def test_case_insensitive(self):
assert supports_image_tool_results("DeepSeek/deepseek-chat") is False
assert supports_image_tool_results("OLLAMA/llama3") is False
assert supports_image_tool_results("GPT-4o") is True
+68
View File
@@ -0,0 +1,68 @@
import io
import threading
import time
import codex_oauth
def _redirect_url(state: str, code: str) -> str:
return f"{codex_oauth.REDIRECT_URI}?code={code}&state={state}"
def test_wait_for_code_accepts_valid_manual_input_after_invalid_entry():
state = "expected-state"
stdin = io.StringIO(f"not a valid code\n{_redirect_url(state, 'manual-code')}\n")
code = codex_oauth.wait_for_code_from_callback_or_stdin(
state,
[None],
threading.Event(),
timeout_secs=0.5,
poll_interval=0.01,
stdin=stdin,
)
assert code == "manual-code"
def test_wait_for_code_returns_callback_when_stdin_reader_fails():
class BrokenStdin:
def readline(self) -> str:
raise OSError("stdin unavailable")
state = "expected-state"
callback_result: list[str | None] = [None]
callback_done = threading.Event()
def resolve_callback() -> None:
time.sleep(0.02)
callback_result[0] = "callback-code"
callback_done.set()
threading.Thread(target=resolve_callback, daemon=True).start()
code = codex_oauth.wait_for_code_from_callback_or_stdin(
state,
callback_result,
callback_done,
timeout_secs=0.5,
poll_interval=0.01,
stdin=BrokenStdin(),
)
assert code == "callback-code"
def test_open_browser_uses_windows_startfile(monkeypatch):
calls: list[str] = []
monkeypatch.setattr(codex_oauth.platform, "system", lambda: "Windows")
monkeypatch.setattr(codex_oauth.os, "startfile", calls.append, raising=False)
def fail_popen(*args, **kwargs):
raise AssertionError("Windows browser launch should not go through cmd /c start")
monkeypatch.setattr(codex_oauth.subprocess, "Popen", fail_popen)
assert codex_oauth.open_browser("https://example.com/path?a=1&b=2") is True
assert calls == ["https://example.com/path?a=1&b=2"]
+228
View File
@@ -0,0 +1,228 @@
"""Unit tests for MCP client transport and reconnect behavior."""
from types import SimpleNamespace
import httpx
import pytest
from framework.runner import mcp_client as mcp_client_module
from framework.runner.mcp_client import MCPClient, MCPServerConfig, MCPTool
class _FakeResponse:
def __init__(self, payload=None):
self._payload = payload or {}
def raise_for_status(self) -> None:
"""Pretend the request succeeded."""
def json(self):
return self._payload
class _FakeHttpClient:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.get_calls: list[str] = []
self.closed = False
def get(self, path: str) -> _FakeResponse:
self.get_calls.append(path)
return _FakeResponse()
def close(self) -> None:
self.closed = True
def test_connect_unix_transport_uses_socket_path(monkeypatch):
created = {}
class FakeHTTPTransport:
def __init__(self, *, uds: str):
created["uds"] = uds
self.uds = uds
def fake_client_factory(**kwargs):
client = _FakeHttpClient(**kwargs)
created["client"] = client
return client
monkeypatch.setattr(mcp_client_module.httpx, "HTTPTransport", FakeHTTPTransport)
monkeypatch.setattr(mcp_client_module.httpx, "Client", fake_client_factory)
monkeypatch.setattr(MCPClient, "_discover_tools", lambda self: None)
client = MCPClient(
MCPServerConfig(
name="unix-server",
transport="unix",
url="http://localhost",
socket_path="/tmp/test.sock",
)
)
client.connect()
assert created["uds"] == "/tmp/test.sock"
assert client._http_client is created["client"] # noqa: SLF001 - direct unit test
assert created["client"].kwargs["base_url"] == "http://localhost"
assert created["client"].get_calls == ["/health"]
client.disconnect()
assert created["client"].closed is True
def test_connect_sse_and_list_tools(monkeypatch):
pytest.importorskip("mcp")
sse_module = pytest.importorskip("mcp.client.sse")
import mcp
contexts = []
class FakeSSEContext:
def __init__(self, url: str, headers: dict[str, str] | None, timeout: float):
self.url = url
self.headers = headers
self.timeout = timeout
self.exited = False
async def __aenter__(self):
return "read-stream", "write-stream"
async def __aexit__(self, exc_type, exc, tb):
self.exited = True
class FakeSession:
def __init__(self, read_stream, write_stream):
self.read_stream = read_stream
self.write_stream = write_stream
self.closed = False
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
self.closed = True
async def initialize(self):
"""Pretend session initialization succeeded."""
async def list_tools(self):
return SimpleNamespace(
tools=[
SimpleNamespace(
name="search",
description="Search docs",
inputSchema={"type": "object"},
)
]
)
def fake_sse_client(url: str, headers=None, timeout=5, **_kwargs):
context = FakeSSEContext(url=url, headers=headers, timeout=timeout)
contexts.append(context)
return context
monkeypatch.setattr(sse_module, "sse_client", fake_sse_client)
monkeypatch.setattr(mcp, "ClientSession", FakeSession)
client = MCPClient(
MCPServerConfig(
name="sse-server",
transport="sse",
url="http://localhost/sse",
headers={"Authorization": "Bearer token"},
)
)
client.connect()
tools = client.list_tools()
assert [tool.name for tool in tools] == ["search"]
assert tools[0].description == "Search docs"
assert contexts[0].url == "http://localhost/sse"
assert contexts[0].headers == {"Authorization": "Bearer token"}
assert contexts[0].timeout == 30.0
client.disconnect()
assert contexts[0].exited is True
def test_call_tool_retries_once_on_connect_error_for_unix(monkeypatch):
client = MCPClient(MCPServerConfig(name="unix-server", transport="unix"))
client._connected = True # noqa: SLF001 - direct unit test
client._tools = { # noqa: SLF001 - direct unit test
"ping": MCPTool("ping", "Ping tool", {}, "unix-server")
}
first_error = httpx.ConnectError("first failure")
calls = {"count": 0}
reconnects = []
def fake_call_tool_http(tool_name, arguments):
calls["count"] += 1
if calls["count"] == 1:
raise first_error
return [{"type": "text", "text": f"{tool_name}:{arguments['value']}"}]
monkeypatch.setattr(client, "_call_tool_http", fake_call_tool_http)
monkeypatch.setattr(client, "_reconnect", lambda: reconnects.append("reconnected"))
result = client.call_tool("ping", {"value": "ok"})
assert result == [{"type": "text", "text": "ping:ok"}]
assert calls["count"] == 2
assert reconnects == ["reconnected"]
def test_call_tool_retry_exhausted_raises_original_error_for_unix(monkeypatch):
client = MCPClient(MCPServerConfig(name="unix-server", transport="unix"))
client._connected = True # noqa: SLF001 - direct unit test
client._tools = { # noqa: SLF001 - direct unit test
"ping": MCPTool("ping", "Ping tool", {}, "unix-server")
}
first_error = httpx.ConnectError("first failure")
second_error = httpx.ConnectError("second failure")
calls = {"count": 0}
reconnects = []
def fake_call_tool_http(_tool_name, _arguments):
calls["count"] += 1
if calls["count"] == 1:
raise first_error
raise second_error
monkeypatch.setattr(client, "_call_tool_http", fake_call_tool_http)
monkeypatch.setattr(client, "_reconnect", lambda: reconnects.append("reconnected"))
with pytest.raises(httpx.ConnectError) as exc_info:
client.call_tool("ping", {"value": "ok"})
assert exc_info.value is first_error
assert calls["count"] == 2
assert reconnects == ["reconnected"]
def test_call_tool_http_preserves_runtime_error_wrapping(monkeypatch):
client = MCPClient(MCPServerConfig(name="http-server", transport="http"))
client._connected = True # noqa: SLF001 - direct unit test
client._tools = { # noqa: SLF001 - direct unit test
"ping": MCPTool("ping", "Ping tool", {}, "http-server")
}
connect_error = httpx.ConnectError("first failure")
class FailingHttpClient:
def post(self, _path, json):
raise connect_error
client._http_client = FailingHttpClient() # noqa: SLF001 - direct unit test
reconnects = []
monkeypatch.setattr(client, "_reconnect", lambda: reconnects.append("reconnected"))
with pytest.raises(RuntimeError) as exc_info:
client.call_tool("ping", {"value": "ok"})
assert "Failed to call tool via HTTP" in str(exc_info.value)
assert exc_info.value.__cause__ is connect_error
assert reconnects == []
+320
View File
@@ -0,0 +1,320 @@
"""Tests for the shared MCP connection manager."""
import threading
import httpx
import pytest
from framework.runner.mcp_client import MCPServerConfig, MCPTool
from framework.runner.mcp_connection_manager import MCPConnectionManager
class FakeMCPClient:
"""Minimal fake MCP client for connection manager tests."""
instances: list["FakeMCPClient"] = []
def __init__(self, config: MCPServerConfig):
self.config = config
self._connected = False
self.connect_calls = 0
self.disconnect_calls = 0
self.list_tools_calls = 0
self.list_tools_error: Exception | None = None
FakeMCPClient.instances.append(self)
def connect(self) -> None:
self.connect_calls += 1
self._connected = True
def disconnect(self) -> None:
self.disconnect_calls += 1
self._connected = False
def list_tools(self) -> list[MCPTool]:
self.list_tools_calls += 1
if self.list_tools_error is not None:
raise self.list_tools_error
return [MCPTool("ping", "Ping", {"type": "object"}, self.config.name)]
@pytest.fixture
def manager(monkeypatch):
monkeypatch.setattr("framework.runner.mcp_connection_manager.MCPClient", FakeMCPClient)
monkeypatch.setattr(MCPConnectionManager, "_instance", None)
FakeMCPClient.instances.clear()
manager = MCPConnectionManager.get_instance()
yield manager
manager.cleanup_all()
monkeypatch.setattr(MCPConnectionManager, "_instance", None)
FakeMCPClient.instances.clear()
def test_acquire_returns_same_client_for_same_server_name(manager):
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
client_one = manager.acquire(config)
client_two = manager.acquire(config)
assert client_one is client_two
assert manager._refcounts["shared"] == 2 # noqa: SLF001 - state assertion for unit test
assert len(FakeMCPClient.instances) == 1
def test_release_with_refcount_above_one_keeps_connection_open(manager):
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
client = manager.acquire(config)
manager.acquire(config)
manager.release("shared")
assert client.disconnect_calls == 0
assert manager._pool["shared"] is client # noqa: SLF001 - state assertion for unit test
assert manager._refcounts["shared"] == 1 # noqa: SLF001 - state assertion for unit test
def test_release_last_reference_disconnects_and_removes_from_pool(manager):
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
client = manager.acquire(config)
manager.release("shared")
assert client.disconnect_calls == 1
assert "shared" not in manager._pool # noqa: SLF001 - state assertion for unit test
assert "shared" not in manager._refcounts # noqa: SLF001 - state assertion for unit test
def test_concurrent_acquire_and_release_keeps_state_consistent(manager):
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
worker_count = 8
acquire_barrier = threading.Barrier(worker_count + 1)
release_barrier = threading.Barrier(worker_count)
acquired_clients: list[FakeMCPClient] = []
acquired_lock = threading.Lock()
def worker() -> None:
acquire_barrier.wait()
client = manager.acquire(config)
with acquired_lock:
acquired_clients.append(client)
release_barrier.wait()
manager.release("shared")
threads = [threading.Thread(target=worker) for _ in range(worker_count)]
for thread in threads:
thread.start()
acquire_barrier.wait()
for thread in threads:
thread.join()
assert len({id(client) for client in acquired_clients}) == 1
assert len(FakeMCPClient.instances) == 1
assert FakeMCPClient.instances[0].disconnect_calls == 1
assert manager._pool == {} # noqa: SLF001 - state assertion for unit test
assert manager._refcounts == {} # noqa: SLF001 - state assertion for unit test
def test_cleanup_all_disconnects_every_pooled_client(manager):
manager.acquire(MCPServerConfig(name="one", transport="stdio", command="echo"))
manager.acquire(MCPServerConfig(name="two", transport="stdio", command="echo"))
manager.cleanup_all()
assert len(FakeMCPClient.instances) == 2
assert all(client.disconnect_calls == 1 for client in FakeMCPClient.instances)
assert manager._pool == {} # noqa: SLF001 - state assertion for unit test
assert manager._refcounts == {} # noqa: SLF001 - state assertion for unit test
assert manager._configs == {} # noqa: SLF001 - state assertion for unit test
def test_reconnect_replaces_client_even_with_existing_refcount(manager):
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
original_client = manager.acquire(config)
manager.acquire(config)
replacement = manager.reconnect("shared")
assert replacement is not original_client
assert original_client.disconnect_calls == 1
assert manager._pool["shared"] is replacement # noqa: SLF001 - state assertion for unit test
assert manager._refcounts["shared"] == 2 # noqa: SLF001 - state assertion for unit test
def test_health_check_returns_false_when_server_is_unreachable(manager, monkeypatch):
config = MCPServerConfig(name="shared", transport="http", url="http://localhost:9")
manager.acquire(config)
class FailingHttpClient:
def __init__(self, **_kwargs):
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def get(self, _path: str):
raise httpx.ConnectError("unreachable")
monkeypatch.setattr("framework.runner.mcp_connection_manager.httpx.Client", FailingHttpClient)
assert manager.health_check("shared") is False
def test_health_check_for_stdio_returns_true_when_healthy(manager):
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
manager.acquire(config)
assert manager.health_check("shared") is True
def test_health_check_for_stdio_returns_false_on_tools_list_error(manager):
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
client = manager.acquire(config)
client.list_tools_error = RuntimeError("broken")
assert manager.health_check("shared") is False
def test_health_check_for_sse_uses_list_tools(manager):
config = MCPServerConfig(name="stream", transport="sse", url="http://localhost:9000/sse")
client = manager.acquire(config)
assert manager.health_check("stream") is True
assert client.list_tools_calls >= 1
def test_health_check_unknown_server_returns_false(manager):
assert manager.health_check("nonexistent") is False
# ── Failure-path tests ──────────────────────────────────────────────
class FailingConnectClient(FakeMCPClient):
"""Client that raises on connect()."""
def connect(self) -> None:
self.connect_calls += 1
raise ConnectionError("connect failed")
class FailingDisconnectClient(FakeMCPClient):
"""Client that raises on disconnect()."""
def disconnect(self) -> None:
self.disconnect_calls += 1
self._connected = False
raise RuntimeError("disconnect failed")
def test_acquire_cleans_up_transition_when_connect_fails(monkeypatch):
monkeypatch.setattr(
"framework.runner.mcp_connection_manager.MCPClient",
FailingConnectClient,
)
monkeypatch.setattr(MCPConnectionManager, "_instance", None)
FailingConnectClient.instances = []
mgr = MCPConnectionManager.get_instance()
config = MCPServerConfig(name="broken", transport="stdio", command="echo")
with pytest.raises(ConnectionError, match="connect failed"):
mgr.acquire(config)
# Transition should be cleaned up, not stuck
assert "broken" not in mgr._transitions # noqa: SLF001
assert "broken" not in mgr._pool # noqa: SLF001
monkeypatch.setattr(MCPConnectionManager, "_instance", None)
def test_release_handles_disconnect_failure(monkeypatch):
monkeypatch.setattr(
"framework.runner.mcp_connection_manager.MCPClient",
FailingDisconnectClient,
)
monkeypatch.setattr(MCPConnectionManager, "_instance", None)
FailingDisconnectClient.instances = []
mgr = MCPConnectionManager.get_instance()
config = MCPServerConfig(name="flaky", transport="stdio", command="echo")
mgr.acquire(config)
# release should not raise even if disconnect fails
mgr.release("flaky")
# Pool should be cleaned up despite disconnect failure
assert "flaky" not in mgr._pool # noqa: SLF001
assert "flaky" not in mgr._refcounts # noqa: SLF001
assert "flaky" not in mgr._transitions # noqa: SLF001
monkeypatch.setattr(MCPConnectionManager, "_instance", None)
def test_reconnect_handles_old_client_disconnect_failure(monkeypatch):
call_count = 0
class FirstFailsThenWorks(FakeMCPClient):
"""First instance fails disconnect, second works fine."""
def disconnect(self) -> None:
nonlocal call_count
call_count += 1
self.disconnect_calls += 1
self._connected = False
if call_count == 1:
raise RuntimeError("old disconnect failed")
monkeypatch.setattr(
"framework.runner.mcp_connection_manager.MCPClient",
FirstFailsThenWorks,
)
monkeypatch.setattr(MCPConnectionManager, "_instance", None)
FirstFailsThenWorks.instances = []
mgr = MCPConnectionManager.get_instance()
config = MCPServerConfig(name="flaky", transport="stdio", command="echo")
original = mgr.acquire(config)
# reconnect should succeed even if old client disconnect fails
replacement = mgr.reconnect("flaky")
assert replacement is not original
assert "flaky" in mgr._pool # noqa: SLF001
assert "flaky" not in mgr._transitions # noqa: SLF001
mgr.cleanup_all()
monkeypatch.setattr(MCPConnectionManager, "_instance", None)
def test_cleanup_all_handles_disconnect_failure(monkeypatch):
monkeypatch.setattr(
"framework.runner.mcp_connection_manager.MCPClient",
FailingDisconnectClient,
)
monkeypatch.setattr(MCPConnectionManager, "_instance", None)
FailingDisconnectClient.instances = []
mgr = MCPConnectionManager.get_instance()
mgr.acquire(MCPServerConfig(name="a", transport="stdio", command="echo"))
mgr.acquire(MCPServerConfig(name="b", transport="stdio", command="echo"))
# cleanup_all should not raise even if disconnects fail
mgr.cleanup_all()
assert mgr._pool == {} # noqa: SLF001
assert mgr._refcounts == {} # noqa: SLF001
monkeypatch.setattr(MCPConnectionManager, "_instance", None)
def test_reconnect_on_fully_released_server_raises(manager):
config = MCPServerConfig(name="gone", transport="stdio", command="echo")
manager.acquire(config)
manager.release("gone")
with pytest.raises(KeyError, match="Unknown MCP server"):
manager.reconnect("gone")
File diff suppressed because it is too large Load Diff
+60
View File
@@ -0,0 +1,60 @@
"""Tests for frontend build fallback in the runner CLI."""
import subprocess
from framework.runner import cli as runner_cli
def _write_frontend_tree(tmp_path, *, with_dist: bool = False):
frontend_dir = tmp_path / "core" / "frontend"
(frontend_dir / "src").mkdir(parents=True)
(frontend_dir / "package.json").write_text("{}", encoding="utf-8")
(frontend_dir / "src" / "main.tsx").write_text("console.log('hi')", encoding="utf-8")
if with_dist:
(frontend_dir / "dist").mkdir()
(frontend_dir / "dist" / "index.html").write_text("<!doctype html>", encoding="utf-8")
return frontend_dir
def test_build_frontend_handles_text_calledprocesserror(monkeypatch, tmp_path, capsys):
_write_frontend_tree(tmp_path)
monkeypatch.chdir(tmp_path)
def fake_run(cmd, **kwargs):
raise subprocess.CalledProcessError(
1,
cmd,
output="npm output",
stderr="vite config failed",
)
monkeypatch.setattr(subprocess, "run", fake_run)
assert runner_cli._build_frontend() is False
output = capsys.readouterr().out
assert "Frontend build failed while running" in output
assert "vite config failed" in output
def test_build_frontend_cleans_cache_and_uses_windows_npm_cmd(monkeypatch, tmp_path):
frontend_dir = _write_frontend_tree(tmp_path)
cache_file = frontend_dir / "tsconfig.app.tsbuildinfo"
cache_file.write_text("stale", encoding="utf-8")
monkeypatch.chdir(tmp_path)
monkeypatch.setattr(runner_cli.sys, "platform", "win32")
commands = []
def fake_run(cmd, **kwargs):
commands.append(cmd)
return subprocess.CompletedProcess(cmd, 0, stdout="", stderr="")
monkeypatch.setattr(subprocess, "run", fake_run)
assert runner_cli._build_frontend() is True
assert not cache_file.exists()
assert commands == [
["npm.cmd", "install", "--no-fund", "--no-audit"],
["npm.cmd", "run", "build"],
]
+142
View File
@@ -0,0 +1,142 @@
"""Tests for AS-9: Skill directory allowlisting in file-read tool interception."""
from unittest.mock import MagicMock
import pytest
from framework.llm.provider import ToolResult
def _make_tool_call_event(tool_name: str, path: str):
"""Build a minimal ToolCallEvent-like object."""
tc = MagicMock()
tc.tool_use_id = "tc-1"
tc.tool_name = tool_name
tc.tool_input = {"path": path}
return tc
def _make_node(skill_dirs: list[str]):
"""Build a minimal EventLoopNode with skill_dirs set."""
from framework.graph.event_loop_node import EventLoopNode
mock_result = ToolResult(tool_use_id="tc-1", content="from-executor")
node = EventLoopNode(tool_executor=MagicMock(return_value=mock_result))
node._skill_dirs = skill_dirs
return node
class TestSkillFileReadInterception:
@pytest.mark.asyncio
async def test_reads_file_in_skill_dir(self, tmp_path):
"""File under a skill dir is read directly, bypassing the executor."""
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
script = skill_dir / "scripts" / "run.py"
script.parent.mkdir()
script.write_text("print('hello')")
node = _make_node([str(skill_dir)])
tc = _make_tool_call_event("view_file", str(script))
result = await node._execute_tool(tc)
assert result.content == "print('hello')"
assert not result.is_error
node._tool_executor.assert_not_called()
@pytest.mark.asyncio
async def test_skill_md_read_marked_as_skill_content(self, tmp_path):
"""Reading SKILL.md sets is_skill_content=True for AS-10 protection."""
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
skill_md = skill_dir / "SKILL.md"
skill_md.write_text("---\nname: my-skill\ndescription: Test\n---\nInstructions.")
node = _make_node([str(skill_dir)])
tc = _make_tool_call_event("view_file", str(skill_md))
result = await node._execute_tool(tc)
assert result.is_skill_content is True
assert not result.is_error
@pytest.mark.asyncio
async def test_non_skill_md_resource_not_marked(self, tmp_path):
"""Bundled resource (not SKILL.md) is NOT marked as skill_content."""
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
ref = skill_dir / "references" / "api.md"
ref.parent.mkdir()
ref.write_text("# API Reference")
node = _make_node([str(skill_dir)])
tc = _make_tool_call_event("load_data", str(ref))
result = await node._execute_tool(tc)
assert result.is_skill_content is False
assert not result.is_error
@pytest.mark.asyncio
async def test_path_outside_skill_dir_goes_to_executor(self, tmp_path):
"""Path outside skill dirs is passed through to the executor unchanged."""
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
other_file = tmp_path / "other" / "file.txt"
other_file.parent.mkdir()
other_file.write_text("other content")
node = _make_node([str(skill_dir)])
tc = _make_tool_call_event("view_file", str(other_file))
result = await node._execute_tool(tc)
assert result.content == "from-executor"
node._tool_executor.assert_called_once()
@pytest.mark.asyncio
async def test_no_skill_dirs_goes_to_executor(self, tmp_path):
"""When skill_dirs is empty, all tool calls go to executor."""
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
script = skill_dir / "scripts" / "run.py"
script.parent.mkdir()
script.write_text("print('hello')")
node = _make_node([])
tc = _make_tool_call_event("view_file", str(script))
result = await node._execute_tool(tc)
assert result.content == "from-executor"
node._tool_executor.assert_called_once()
@pytest.mark.asyncio
async def test_missing_file_returns_error(self, tmp_path):
"""Non-existent file under skill dir returns is_error=True."""
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
missing = skill_dir / "scripts" / "missing.py"
node = _make_node([str(skill_dir)])
tc = _make_tool_call_event("view_file", str(missing))
result = await node._execute_tool(tc)
assert result.is_error is True
assert "Could not read skill resource" in result.content
@pytest.mark.asyncio
async def test_non_file_read_tool_goes_to_executor(self, tmp_path):
"""Non file-read tools (e.g. web_search) bypass the interceptor."""
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
node = _make_node([str(skill_dir)])
tc = _make_tool_call_event("web_search", str(skill_dir / "SKILL.md"))
result = await node._execute_tool(tc)
assert result.content == "from-executor"
node._tool_executor.assert_called_once()
+8 -1
View File
@@ -69,7 +69,13 @@ class TestSkillCatalog:
def test_to_prompt_xml_generation(self):
skills = [
_make_skill("alpha", "Alpha skill", "project", location="/p/alpha/SKILL.md"),
_make_skill(
"alpha",
"Alpha skill",
"project",
location="/p/alpha/SKILL.md",
base_dir="/p/alpha",
),
_make_skill("beta", "Beta skill", "user", location="/u/beta/SKILL.md"),
]
catalog = SkillCatalog(skills)
@@ -81,6 +87,7 @@ class TestSkillCatalog:
assert "<name>beta</name>" in prompt
assert "<description>Alpha skill</description>" in prompt
assert "<location>/p/alpha/SKILL.md</location>" in prompt
assert "<base_dir>/p/alpha</base_dir>" in prompt
def test_to_prompt_sorted_by_name(self):
skills = [
@@ -0,0 +1,90 @@
"""Tests for AS-10: Activated skill content protected from context pruning."""
import pytest
from framework.graph.conversation import Message, NodeConversation
def _make_conversation() -> NodeConversation:
conv = NodeConversation.__new__(NodeConversation)
conv._messages = []
conv._next_seq = 0
conv._current_phase = None
conv._store = None
return conv
async def _add_tool_msg(conv: NodeConversation, content: str, **kwargs) -> Message:
return await conv.add_tool_result(
tool_use_id=f"tc-{conv._next_seq}",
content=content,
**kwargs,
)
class TestSkillContentProtection:
@pytest.mark.asyncio
async def test_is_skill_content_flag_persists(self):
"""Message created with is_skill_content=True retains the flag."""
conv = _make_conversation()
msg = await _add_tool_msg(conv, "skill instructions", is_skill_content=True)
assert msg.is_skill_content is True
@pytest.mark.asyncio
async def test_regular_message_not_marked(self):
"""Normal tool result messages are not marked as skill content."""
conv = _make_conversation()
msg = await _add_tool_msg(conv, "some tool output")
assert msg.is_skill_content is False
@pytest.mark.asyncio
async def test_skill_content_survives_prune(self):
"""Skill content messages are skipped by prune_old_tool_results."""
conv = _make_conversation()
# Add many regular tool results to push over prune threshold
for _ in range(30):
await _add_tool_msg(conv, "x" * 500) # ~125 tokens each
# Add a skill content message
skill_msg = await _add_tool_msg(
conv,
"## Deep Research\n" + "instructions " * 200,
is_skill_content=True,
)
pruned = await conv.prune_old_tool_results(protect_tokens=500, min_prune_tokens=100)
assert pruned > 0, "Expected some messages to be pruned"
# Find the skill message — it must not be pruned
matching = [m for m in conv._messages if m.seq == skill_msg.seq]
assert matching, "Skill content message was removed"
assert not matching[0].content.startswith("[Pruned tool result")
@pytest.mark.asyncio
async def test_regular_content_can_be_pruned(self):
"""Regular tool results are still pruned when over threshold."""
conv = _make_conversation()
for _ in range(20):
await _add_tool_msg(conv, "regular tool output " * 50)
pruned = await conv.prune_old_tool_results(protect_tokens=500, min_prune_tokens=100)
assert pruned > 0, "Expected regular messages to be pruned"
@pytest.mark.asyncio
async def test_error_messages_also_protected(self):
"""Existing is_error protection still works alongside is_skill_content."""
conv = _make_conversation()
for _ in range(20):
await _add_tool_msg(conv, "output " * 100)
err_msg = await _add_tool_msg(conv, "tool failed", is_error=True)
await conv.prune_old_tool_results(protect_tokens=200, min_prune_tokens=50)
matching = [m for m in conv._messages if m.seq == err_msg.seq]
assert matching
assert not matching[0].content.startswith("[Pruned tool result")
+151
View File
@@ -0,0 +1,151 @@
"""Tests for skill system structured error codes and diagnostics."""
from __future__ import annotations
import logging
from framework.skills.skill_errors import (
SkillError,
SkillErrorCode,
log_skill_error,
)
class TestSkillErrorCode:
def test_all_codes_defined(self):
codes = {e.value for e in SkillErrorCode}
assert "SKILL_NOT_FOUND" in codes
assert "SKILL_PARSE_ERROR" in codes
assert "SKILL_ACTIVATION_FAILED" in codes
assert "SKILL_MISSING_DESCRIPTION" in codes
assert "SKILL_YAML_FIXUP" in codes
assert "SKILL_NAME_MISMATCH" in codes
assert "SKILL_COLLISION" in codes
class TestSkillError:
def test_code_stored(self):
err = SkillError(
code=SkillErrorCode.SKILL_NOT_FOUND,
what="Skill 'my-skill' not found",
why="Not in catalog",
fix="Check discovery paths",
)
assert err.code == SkillErrorCode.SKILL_NOT_FOUND
def test_message_format(self):
err = SkillError(
code=SkillErrorCode.SKILL_MISSING_DESCRIPTION,
what="Missing description in '/path/SKILL.md'",
why="The description field is absent",
fix="Add a description field to the frontmatter",
)
expected = (
"[SKILL_MISSING_DESCRIPTION]\n"
"What failed: Missing description in '/path/SKILL.md'\n"
"Why: The description field is absent\n"
"Fix: Add a description field to the frontmatter"
)
assert str(err) == expected
def test_is_exception(self):
err = SkillError(
code=SkillErrorCode.SKILL_PARSE_ERROR,
what="Parse failed",
why="Invalid YAML",
fix="Fix the YAML",
)
assert isinstance(err, Exception)
def test_what_why_fix_attributes(self):
err = SkillError(
code=SkillErrorCode.SKILL_COLLISION,
what="Name collision",
why="Two skills share the same name",
fix="Rename one skill directory",
)
assert err.what == "Name collision"
assert err.why == "Two skills share the same name"
assert err.fix == "Rename one skill directory"
class TestLogSkillError:
def test_emits_log(self, caplog):
test_logger = logging.getLogger("test_skill")
with caplog.at_level(logging.ERROR, logger="test_skill"):
log_skill_error(
test_logger,
"error",
SkillErrorCode.SKILL_PARSE_ERROR,
what="Invalid SKILL.md at '/path'",
why="Empty file",
fix="Add content",
)
assert "SKILL_PARSE_ERROR" in caplog.text
def test_warning_level(self, caplog):
test_logger = logging.getLogger("test_skill_warn")
with caplog.at_level(logging.WARNING, logger="test_skill_warn"):
log_skill_error(
test_logger,
"warning",
SkillErrorCode.SKILL_YAML_FIXUP,
what="Auto-fixed YAML",
why="Unquoted colons",
fix="Quote values",
)
assert "SKILL_YAML_FIXUP" in caplog.text
def test_message_contains_all_parts(self, caplog):
test_logger = logging.getLogger("test_skill_parts")
with caplog.at_level(logging.ERROR, logger="test_skill_parts"):
log_skill_error(
test_logger,
"error",
SkillErrorCode.SKILL_NOT_FOUND,
what="Skill not found",
why="Not discovered",
fix="Check paths",
)
assert "Skill not found" in caplog.text
assert "Not discovered" in caplog.text
assert "Check paths" in caplog.text
class TestSkillErrorInParser:
def test_missing_description_returns_none(self, tmp_path):
from framework.skills.parser import parse_skill_md
skill_dir = tmp_path / "no-desc"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text("---\nname: no-desc\n---\nBody.\n", encoding="utf-8")
result = parse_skill_md(skill_dir / "SKILL.md")
assert result is None
def test_empty_file_returns_none(self, tmp_path):
from framework.skills.parser import parse_skill_md
skill_dir = tmp_path / "empty"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text("", encoding="utf-8")
result = parse_skill_md(skill_dir / "SKILL.md")
assert result is None
def test_nonexistent_returns_none(self, tmp_path):
from framework.skills.parser import parse_skill_md
result = parse_skill_md(tmp_path / "ghost" / "SKILL.md")
assert result is None
def test_yaml_fixup_still_parses(self, tmp_path):
from framework.skills.parser import parse_skill_md
skill_dir = tmp_path / "colon-test"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\nname: colon-test\ndescription: Use for: research\n---\nBody.\n",
encoding="utf-8",
)
result = parse_skill_md(skill_dir / "SKILL.md")
assert result is not None
assert "research" in result.description
+92
View File
@@ -0,0 +1,92 @@
"""Tests for AS-6 skill resource loading support.
Covers:
- <base_dir> element in catalog XML
- allowlisted_dirs property reflects trusted skill base directories
- skill_dirs propagation to NodeContext
"""
from framework.skills.catalog import SkillCatalog
from framework.skills.parser import ParsedSkill
def _make_skill(
name: str,
base_dir: str,
source_scope: str = "project",
) -> ParsedSkill:
return ParsedSkill(
name=name,
description=f"Skill {name}",
location=f"{base_dir}/SKILL.md",
base_dir=base_dir,
source_scope=source_scope,
body="Instructions.",
)
class TestSkillResourceBaseDir:
def test_base_dir_in_xml(self):
"""Each community skill entry should expose its base_dir in the catalog XML."""
skill = _make_skill("deploy", "/project/.hive/skills/deploy")
catalog = SkillCatalog([skill])
prompt = catalog.to_prompt()
assert "<base_dir>/project/.hive/skills/deploy</base_dir>" in prompt
def test_base_dir_xml_escaped(self):
"""base_dir with XML-special chars should be escaped."""
skill = _make_skill("s", "/path/with <&> chars")
catalog = SkillCatalog([skill])
prompt = catalog.to_prompt()
assert "<base_dir>/path/with &lt;&amp;&gt; chars</base_dir>" in prompt
def test_base_dir_absent_for_framework_skills(self):
"""Framework-scope skills are filtered from the catalog, so no base_dir either."""
skill = _make_skill("fw", "/hive/_default_skills/fw", source_scope="framework")
catalog = SkillCatalog([skill])
assert catalog.to_prompt() == ""
def test_allowlisted_dirs_matches_skills(self):
"""allowlisted_dirs returns all skill base_dirs including framework ones."""
skills = [
_make_skill("a", "/skills/a", "project"),
_make_skill("b", "/skills/b", "user"),
_make_skill("c", "/skills/c", "framework"),
]
catalog = SkillCatalog(skills)
dirs = catalog.allowlisted_dirs
assert "/skills/a" in dirs
assert "/skills/b" in dirs
assert "/skills/c" in dirs
def test_allowlisted_dirs_empty_catalog(self):
assert SkillCatalog().allowlisted_dirs == []
class TestSkillDirsPropagation:
def _make_ctx(self, **kwargs):
from unittest.mock import MagicMock
from framework.graph.node import NodeContext
return NodeContext(
runtime=MagicMock(),
node_id="n",
node_spec=MagicMock(),
memory={},
**kwargs,
)
def test_node_context_skill_dirs_default(self):
"""NodeContext.skill_dirs defaults to empty list."""
ctx = self._make_ctx()
assert ctx.skill_dirs == []
def test_node_context_skill_dirs_set(self):
"""NodeContext.skill_dirs can be populated."""
dirs = ["/skills/a", "/skills/b"]
ctx = self._make_ctx(skill_dirs=dirs)
assert ctx.skill_dirs == dirs
+471
View File
@@ -0,0 +1,471 @@
"""Tests for skill trust gating (AS-13)."""
from __future__ import annotations
import json
from unittest.mock import MagicMock, patch
from framework.skills.parser import ParsedSkill
from framework.skills.trust import (
ProjectTrustClassification,
ProjectTrustDetector,
TrustedRepoStore,
TrustGate,
_is_localhost_remote,
_normalize_remote_url,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_skill(name: str = "test-skill", scope: str = "project") -> ParsedSkill:
return ParsedSkill(
name=name,
description="Test skill",
location=f"/fake/{name}/SKILL.md",
base_dir=f"/fake/{name}",
source_scope=scope,
body="Test skill instructions.",
)
# ---------------------------------------------------------------------------
# _normalize_remote_url
# ---------------------------------------------------------------------------
class TestNormalizeRemoteUrl:
def test_ssh_scp_format(self):
assert _normalize_remote_url("git@github.com:org/repo.git") == "github.com/org/repo"
def test_https_format(self):
assert _normalize_remote_url("https://github.com/org/repo.git") == "github.com/org/repo"
def test_https_no_dot_git(self):
assert _normalize_remote_url("https://github.com/org/repo") == "github.com/org/repo"
def test_ssh_url_format(self):
assert _normalize_remote_url("ssh://git@github.com/org/repo.git") == "github.com/org/repo"
def test_lowercased(self):
assert _normalize_remote_url("git@GitHub.COM:Org/Repo.git") == "github.com/org/repo"
def test_trailing_slash_stripped(self):
assert _normalize_remote_url("https://github.com/org/repo/") == "github.com/org/repo"
def test_gitlab(self):
assert _normalize_remote_url("git@gitlab.com:team/project.git") == "gitlab.com/team/project"
# ---------------------------------------------------------------------------
# _is_localhost_remote
# ---------------------------------------------------------------------------
class TestIsLocalhostRemote:
def test_localhost_https(self):
assert _is_localhost_remote("http://localhost/org/repo")
def test_127_0_0_1(self):
assert _is_localhost_remote("https://127.0.0.1/repo")
def test_github_not_local(self):
assert not _is_localhost_remote("https://github.com/org/repo")
def test_scp_localhost(self):
assert _is_localhost_remote("git@localhost:org/repo")
# ---------------------------------------------------------------------------
# TrustedRepoStore
# ---------------------------------------------------------------------------
class TestTrustedRepoStore:
def test_empty_store_is_not_trusted(self, tmp_path):
store = TrustedRepoStore(tmp_path / "trusted.json")
assert not store.is_trusted("github.com/org/repo")
def test_trust_and_lookup(self, tmp_path):
store = TrustedRepoStore(tmp_path / "trusted.json")
store.trust("github.com/org/repo", project_path="/some/path")
assert store.is_trusted("github.com/org/repo")
def test_revoke(self, tmp_path):
store = TrustedRepoStore(tmp_path / "trusted.json")
store.trust("github.com/org/repo")
assert store.revoke("github.com/org/repo")
assert not store.is_trusted("github.com/org/repo")
def test_revoke_nonexistent_returns_false(self, tmp_path):
store = TrustedRepoStore(tmp_path / "trusted.json")
assert not store.revoke("github.com/nobody/nowhere")
def test_persists_across_instances(self, tmp_path):
path = tmp_path / "trusted.json"
store1 = TrustedRepoStore(path)
store1.trust("github.com/org/repo")
store2 = TrustedRepoStore(path)
assert store2.is_trusted("github.com/org/repo")
def test_atomic_write(self, tmp_path):
"""Save must not leave a .tmp file behind."""
path = tmp_path / "trusted.json"
store = TrustedRepoStore(path)
store.trust("github.com/org/repo")
assert not (tmp_path / "trusted.tmp").exists()
assert path.exists()
def test_corrupted_json_recovers_gracefully(self, tmp_path):
path = tmp_path / "trusted.json"
path.write_text("{not valid json{{", encoding="utf-8")
store = TrustedRepoStore(path)
assert not store.is_trusted("github.com/any/repo") # no crash
def test_json_schema(self, tmp_path):
path = tmp_path / "trusted.json"
store = TrustedRepoStore(path)
store.trust("github.com/org/repo", project_path="/work/repo")
data = json.loads(path.read_text())
assert data["version"] == 1
assert data["entries"][0]["repo_key"] == "github.com/org/repo"
assert "added_at" in data["entries"][0]
def test_list_entries(self, tmp_path):
store = TrustedRepoStore(tmp_path / "t.json")
store.trust("github.com/a/b")
store.trust("github.com/c/d")
entries = store.list_entries()
assert len(entries) == 2
# ---------------------------------------------------------------------------
# ProjectTrustDetector
# ---------------------------------------------------------------------------
class TestProjectTrustDetector:
def test_none_project_dir_always_trusted(self, tmp_path):
store = TrustedRepoStore(tmp_path / "t.json")
det = ProjectTrustDetector(store)
cls, _ = det.classify(None)
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
def test_nonexistent_dir_always_trusted(self, tmp_path):
store = TrustedRepoStore(tmp_path / "t.json")
det = ProjectTrustDetector(store)
cls, _ = det.classify(tmp_path / "nonexistent")
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
def test_no_git_dir_always_trusted(self, tmp_path):
store = TrustedRepoStore(tmp_path / "t.json")
det = ProjectTrustDetector(store)
cls, _ = det.classify(tmp_path)
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
def test_no_remote_always_trusted(self, tmp_path):
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
det = ProjectTrustDetector(store)
# git command returns non-zero (no remote)
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=1, stdout="")
cls, _ = det.classify(tmp_path)
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
def test_localhost_remote_always_trusted(self, tmp_path):
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
det = ProjectTrustDetector(store)
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(
returncode=0, stdout="http://localhost/org/repo.git\n"
)
cls, _ = det.classify(tmp_path)
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
def test_trusted_by_store(self, tmp_path):
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
store.trust("github.com/trusted/repo")
det = ProjectTrustDetector(store)
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(
returncode=0, stdout="git@github.com:trusted/repo.git\n"
)
cls, key = det.classify(tmp_path)
assert cls == ProjectTrustClassification.TRUSTED_BY_USER
assert key == "github.com/trusted/repo"
def test_unknown_remote_untrusted(self, tmp_path):
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
det = ProjectTrustDetector(store)
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/repo.git\n"
)
cls, key = det.classify(tmp_path)
assert cls == ProjectTrustClassification.UNTRUSTED
assert key == "github.com/stranger/repo"
def test_own_remotes_env_var(self, tmp_path, monkeypatch):
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
monkeypatch.setenv("HIVE_OWN_REMOTES", "github.com/myorg/*")
det = ProjectTrustDetector(store)
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(
returncode=0, stdout="git@github.com:myorg/myrepo.git\n"
)
cls, _ = det.classify(tmp_path)
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
def test_git_timeout_treated_as_trusted(self, tmp_path):
import subprocess
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
det = ProjectTrustDetector(store)
with patch("subprocess.run", side_effect=subprocess.TimeoutExpired("git", 3)):
cls, _ = det.classify(tmp_path)
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
def test_git_not_found_treated_as_trusted(self, tmp_path):
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
det = ProjectTrustDetector(store)
with patch("subprocess.run", side_effect=FileNotFoundError("git not found")):
cls, _ = det.classify(tmp_path)
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
# ---------------------------------------------------------------------------
# TrustGate
# ---------------------------------------------------------------------------
class TestTrustGate:
def test_framework_scope_always_passes(self, tmp_path):
skill = make_skill("fw-skill", "framework")
gate = TrustGate(store=TrustedRepoStore(tmp_path / "t.json"), interactive=False)
result = gate.filter_and_gate([skill], project_dir=None)
assert any(s.name == "fw-skill" for s in result)
def test_user_scope_always_passes(self, tmp_path):
skill = make_skill("user-skill", "user")
gate = TrustGate(store=TrustedRepoStore(tmp_path / "t.json"), interactive=False)
result = gate.filter_and_gate([skill], project_dir=None)
assert any(s.name == "user-skill" for s in result)
def test_no_project_skills_returns_early(self, tmp_path):
"""When there are no project-scope skills, trust detection is skipped."""
fw = make_skill("fw", "framework")
gate = TrustGate(store=TrustedRepoStore(tmp_path / "t.json"), interactive=False)
result = gate.filter_and_gate([fw], project_dir=tmp_path)
assert result == [fw]
def test_trusted_project_skills_pass(self, tmp_path):
"""Project skills from a trusted repo pass through."""
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
store.trust("github.com/trusted/repo")
skill = make_skill("proj-skill", "project")
gate = TrustGate(store=store, interactive=False)
with patch("subprocess.run") as m:
m.return_value = MagicMock(returncode=0, stdout="git@github.com:trusted/repo.git\n")
result = gate.filter_and_gate([skill], project_dir=tmp_path)
assert any(s.name == "proj-skill" for s in result)
def test_untrusted_headless_skips_and_logs(self, tmp_path, caplog):
"""In non-interactive mode, untrusted project skills are skipped."""
import logging
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
skill = make_skill("evil-skill", "project")
gate = TrustGate(store=store, interactive=False)
with patch("subprocess.run") as m:
m.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/evil.git\n"
)
with caplog.at_level(logging.WARNING):
result = gate.filter_and_gate([skill], project_dir=tmp_path)
assert not any(s.name == "evil-skill" for s in result)
assert "untrusted" in caplog.text.lower() or "skipping" in caplog.text.lower()
def test_interactive_consent_session_only(self, tmp_path):
"""Option 1 (session only) includes skills without writing to store."""
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
skill = make_skill("session-skill", "project")
outputs = []
gate = TrustGate(
store=store,
interactive=True,
print_fn=outputs.append,
input_fn=lambda _: "1", # trust this session
)
with (
patch("sys.stdin.isatty", return_value=True),
patch("sys.stdout.isatty", return_value=True),
patch("subprocess.run") as m,
):
m.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/repo.git\n"
)
result = gate.filter_and_gate([skill], project_dir=tmp_path)
assert any(s.name == "session-skill" for s in result)
# Must NOT persist to trusted store
assert not store.is_trusted("github.com/stranger/repo")
def test_interactive_consent_permanent(self, tmp_path):
"""Option 2 (permanent) includes skills and persists to trusted store."""
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
skill = make_skill("perm-skill", "project")
gate = TrustGate(
store=store,
interactive=True,
print_fn=lambda _: None,
input_fn=lambda _: "2", # trust permanently
)
with (
patch("sys.stdin.isatty", return_value=True),
patch("sys.stdout.isatty", return_value=True),
patch("subprocess.run") as m,
):
m.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/repo.git\n"
)
result = gate.filter_and_gate([skill], project_dir=tmp_path)
assert any(s.name == "perm-skill" for s in result)
assert store.is_trusted("github.com/stranger/repo")
def test_interactive_consent_deny(self, tmp_path):
"""Option 3 (deny) excludes project skills."""
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
skill = make_skill("bad-skill", "project")
gate = TrustGate(
store=store,
interactive=True,
print_fn=lambda _: None,
input_fn=lambda _: "3", # deny
)
with (
patch("sys.stdin.isatty", return_value=True),
patch("sys.stdout.isatty", return_value=True),
patch("subprocess.run") as m,
):
m.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/repo.git\n"
)
result = gate.filter_and_gate([skill], project_dir=tmp_path)
assert not any(s.name == "bad-skill" for s in result)
def test_env_var_override_trusts_all(self, tmp_path, monkeypatch):
"""HIVE_TRUST_PROJECT_SKILLS=1 bypasses gating entirely."""
monkeypatch.setenv("HIVE_TRUST_PROJECT_SKILLS", "1")
store = TrustedRepoStore(tmp_path / "t.json")
skill = make_skill("env-skill", "project")
gate = TrustGate(store=store, interactive=False)
result = gate.filter_and_gate([skill], project_dir=tmp_path)
assert any(s.name == "env-skill" for s in result)
def test_keyboard_interrupt_treated_as_deny(self, tmp_path):
"""Ctrl-C during consent prompt should deny cleanly."""
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
skill = make_skill("interrupted-skill", "project")
gate = TrustGate(
store=store,
interactive=True,
print_fn=lambda _: None,
input_fn=lambda _: (_ for _ in ()).throw(KeyboardInterrupt()),
)
with (
patch("sys.stdin.isatty", return_value=True),
patch("sys.stdout.isatty", return_value=True),
patch("subprocess.run") as m,
):
m.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/repo.git\n"
)
result = gate.filter_and_gate([skill], project_dir=tmp_path)
assert not any(s.name == "interrupted-skill" for s in result)
def test_security_notice_shown_once(self, tmp_path, monkeypatch):
"""Security notice (NFR-5) should be shown the first time only."""
# Use a temp sentinel path
sentinel = tmp_path / ".skill_trust_notice_shown"
monkeypatch.setattr("framework.skills.trust._NOTICE_SENTINEL_PATH", sentinel)
assert not sentinel.exists()
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
skill = make_skill("notice-skill", "project")
output_lines: list[str] = []
gate = TrustGate(
store=store,
interactive=True,
print_fn=output_lines.append,
input_fn=lambda _: "3",
)
with (
patch("sys.stdin.isatty", return_value=True),
patch("sys.stdout.isatty", return_value=True),
patch("subprocess.run") as m,
):
m.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/repo.git\n"
)
gate.filter_and_gate([skill], project_dir=tmp_path)
assert sentinel.exists()
assert any("Security notice" in line for line in output_lines)
# Second run should NOT show the notice again
output_lines.clear()
skill2 = make_skill("notice-skill-2", "project")
with (
patch("sys.stdin.isatty", return_value=True),
patch("sys.stdout.isatty", return_value=True),
patch("subprocess.run") as m,
):
m.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/repo.git\n"
)
gate.filter_and_gate([skill2], project_dir=tmp_path)
assert not any("Security notice" in line for line in output_lines)
def test_mixed_scopes_only_project_gated(self, tmp_path, monkeypatch):
"""Framework and user skills should pass through even if project skills are denied."""
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
fw_skill = make_skill("fw", "framework")
user_skill = make_skill("usr", "user")
proj_skill = make_skill("proj", "project")
gate = TrustGate(
store=store,
interactive=True,
print_fn=lambda _: None,
input_fn=lambda _: "3", # deny project skills
)
with (
patch("sys.stdin.isatty", return_value=True),
patch("sys.stdout.isatty", return_value=True),
patch("subprocess.run") as m,
):
m.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/repo.git\n"
)
result = gate.filter_and_gate([fw_skill, user_skill, proj_skill], project_dir=tmp_path)
names = {s.name for s in result}
assert "fw" in names
assert "usr" in names
assert "proj" not in names
+302
View File
@@ -6,9 +6,12 @@ ToolResult instances. Historically, invalid JSON in ToolResult.content
could cause a json.JSONDecodeError and crash execution.
"""
import logging
import textwrap
from pathlib import Path
from types import SimpleNamespace
from framework.llm.provider import Tool, ToolUse
from framework.runner.tool_registry import ToolRegistry
@@ -91,3 +94,302 @@ def test_discover_from_module_handles_empty_content(tmp_path):
result = registered.executor({})
assert isinstance(result, dict)
assert result == {}
class _RegistryFakeClient:
def __init__(self, config):
self.config = config
self.connect_calls = 0
self.disconnect_calls = 0
def connect(self) -> None:
self.connect_calls += 1
def disconnect(self) -> None:
self.disconnect_calls += 1
def list_tools(self):
return [
SimpleNamespace(
name="pooled_tool",
description="Tool from MCP",
input_schema={"type": "object", "properties": {}, "required": []},
)
]
def call_tool(self, tool_name, arguments):
return [{"text": f"{tool_name}:{arguments}"}]
def test_register_mcp_server_uses_connection_manager_when_enabled(monkeypatch):
registry = ToolRegistry()
client = _RegistryFakeClient(SimpleNamespace(name="shared"))
manager_calls: list[tuple[str, str]] = []
class FakeManager:
def acquire(self, config):
manager_calls.append(("acquire", config.name))
client.config = config
return client
def release(self, server_name: str) -> None:
manager_calls.append(("release", server_name))
monkeypatch.setattr(
"framework.runner.mcp_connection_manager.MCPConnectionManager.get_instance",
lambda: FakeManager(),
)
count = registry.register_mcp_server(
{"name": "shared", "transport": "stdio", "command": "echo"},
use_connection_manager=True,
)
assert count == 1
assert manager_calls == [("acquire", "shared")]
registry.cleanup()
assert manager_calls == [("acquire", "shared"), ("release", "shared")]
assert client.disconnect_calls == 0
def test_register_mcp_server_defaults_to_connection_manager(monkeypatch):
"""Default behavior uses the connection manager (reuse enabled by default)."""
registry = ToolRegistry()
created_clients: list[_RegistryFakeClient] = []
def fake_client_factory(config):
client = _RegistryFakeClient(config)
created_clients.append(client)
return client
class FakeManager:
def acquire(self, config):
return fake_client_factory(config)
def release(self, server_name):
pass
monkeypatch.setattr(
"framework.runner.mcp_connection_manager.MCPConnectionManager.get_instance",
lambda: FakeManager(),
)
count = registry.register_mcp_server(
{"name": "direct", "transport": "stdio", "command": "echo"},
)
assert count == 1
assert len(created_clients) == 1
def test_register_mcp_server_direct_client_when_manager_disabled(monkeypatch):
"""When use_connection_manager=False, a direct MCPClient is created."""
registry = ToolRegistry()
created_clients: list[_RegistryFakeClient] = []
def fake_client_factory(config):
client = _RegistryFakeClient(config)
created_clients.append(client)
return client
monkeypatch.setattr("framework.runner.mcp_client.MCPClient", fake_client_factory)
count = registry.register_mcp_server(
{"name": "direct", "transport": "stdio", "command": "echo"},
use_connection_manager=False,
)
assert count == 1
assert len(created_clients) == 1
assert created_clients[0].connect_calls == 1
registry.cleanup()
assert created_clients[0].disconnect_calls == 1
def test_load_registry_servers_retries_when_registration_returns_zero(monkeypatch):
registry = ToolRegistry()
attempts = {"count": 0}
def fake_register(server_config, use_connection_manager=True):
attempts["count"] += 1
return 0 if attempts["count"] == 1 else 2
monkeypatch.setattr(registry, "register_mcp_server", fake_register)
monkeypatch.setattr("time.sleep", lambda _: None)
results = registry.load_registry_servers(
[{"name": "jira", "transport": "http", "url": "http://localhost:4010"}],
log_summary=False,
)
assert attempts["count"] == 2
assert results == [
{
"server": "jira",
"status": "loaded",
"tools_loaded": 2,
"skipped_reason": None,
}
]
def test_load_registry_servers_marks_failures_as_skipped(monkeypatch):
registry = ToolRegistry()
monkeypatch.setattr(registry, "register_mcp_server", lambda *args, **kwargs: 0)
monkeypatch.setattr("time.sleep", lambda _: None)
results = registry.load_registry_servers(
[{"name": "jira", "transport": "http", "url": "http://localhost:4010"}],
log_summary=False,
)
assert results == [
{
"server": "jira",
"status": "skipped",
"tools_loaded": 0,
"skipped_reason": "registered 0 tools",
}
]
def test_load_registry_servers_emits_structured_log_fields(monkeypatch):
registry = ToolRegistry()
captured_logs: list[tuple[str, dict | None]] = []
monkeypatch.setattr(registry, "register_mcp_server", lambda *args, **kwargs: 2)
monkeypatch.setattr(
"framework.runner.tool_registry.logger.info",
lambda message, *args, **kwargs: captured_logs.append((message, kwargs.get("extra"))),
)
registry.load_registry_servers(
[{"name": "jira", "transport": "http", "url": "http://localhost:4010"}],
log_summary=True,
)
assert captured_logs == [
(
"MCP registry server resolution",
{
"event": "mcp_registry_server_resolution",
"server": "jira",
"status": "loaded",
"tools_loaded": 2,
"skipped_reason": None,
},
)
]
def test_tool_execution_error_logs_stack_trace_and_context(caplog):
"""ToolRegistry should log stack traces and context when tool execution fails."""
registry = ToolRegistry()
def failing_executor(inputs: dict) -> None:
raise ValueError("Intentional test failure")
tool = Tool(
name="failing_tool",
description="A tool that always fails",
parameters={"type": "object", "properties": {}},
)
registry.register("failing_tool", tool, failing_executor)
tool_use = ToolUse(
id="test_call_123",
name="failing_tool",
input={"param": "value"},
)
with caplog.at_level(logging.ERROR):
executor = registry.get_executor()
result = executor(tool_use)
assert result.is_error is True
assert "Intentional test failure" in result.content
assert any("failing_tool" in record.message for record in caplog.records)
assert any("test_call_123" in record.message for record in caplog.records)
assert any(record.exc_info is not None for record in caplog.records)
def test_tool_execution_error_logs_inputs(caplog):
"""ToolRegistry should log tool inputs when execution fails."""
registry = ToolRegistry()
def failing_executor(inputs: dict) -> None:
raise RuntimeError("Tool failed")
tool = Tool(
name="input_logging_tool",
description="Tests input logging",
parameters={"type": "object", "properties": {"foo": {"type": "string"}}},
)
registry.register("input_logging_tool", tool, failing_executor)
tool_use = ToolUse(
id="call_456",
name="input_logging_tool",
input={"foo": "bar", "nested": {"key": "value"}},
)
with caplog.at_level(logging.ERROR):
executor = registry.get_executor()
executor(tool_use)
log_messages = [record.message for record in caplog.records]
full_log = " ".join(log_messages)
assert '"foo": "bar"' in full_log or "'foo': 'bar'" in full_log
def test_unknown_tool_error_returns_proper_result():
"""ToolRegistry should return proper error for unknown tools."""
registry = ToolRegistry()
tool_use = ToolUse(
id="unknown_call",
name="nonexistent_tool",
input={},
)
executor = registry.get_executor()
result = executor(tool_use)
assert result.is_error is True
assert "Unknown tool" in result.content
assert "nonexistent_tool" in result.content
def test_tool_execution_error_truncates_large_inputs(caplog):
"""ToolRegistry should truncate large inputs in error logs."""
registry = ToolRegistry()
def failing_executor(inputs: dict) -> None:
raise RuntimeError("Tool failed")
tool = Tool(
name="large_input_tool",
description="Tests input truncation",
parameters={"type": "object", "properties": {}},
)
registry.register("large_input_tool", tool, failing_executor)
large_input = {"data": "x" * 1000}
tool_use = ToolUse(
id="call_789",
name="large_input_tool",
input=large_input,
)
with caplog.at_level(logging.ERROR):
executor = registry.get_executor()
executor(tool_use)
log_messages = [record.message for record in caplog.records]
full_log = " ".join(log_messages)
assert "...(truncated)" in full_log
+4 -4
View File
@@ -157,7 +157,7 @@ All bounty types open in parallel. Contributors self-select. Daily progress upda
PR merged with bounty:* label
→ GitHub Action runs bounty-tracker.ts
→ Calculates points from label
→ Resolves GitHub → Discord ID via contributors.yml
→ Resolves GitHub → Discord ID via MongoDB (hive.contributors)
→ Pushes XP to Lurkr API
→ Posts notification to #integrations-announcements
```
@@ -166,7 +166,7 @@ See the [Setup Guide](setup-guide.md) for full configuration (Lurkr, webhooks, s
### Identity Linking
Contributors link GitHub ↔ Discord by opening a [Link Discord Account](https://github.com/aden-hive/hive/issues/new?template=link-discord.yml) issue. A GitHub Action auto-adds them to `contributors.yml` and closes the issue.
Contributors link GitHub ↔ Discord by running `/link-github` in Discord. The bot verifies ownership via a public gist, then stores the mapping in MongoDB.
Without this link, bounties are still tracked but Lurkr can't push XP to your Discord account.
@@ -181,7 +181,7 @@ Without this link, bounties are still tracked but Lurkr can't push XP to your Di
| Agent Builder role | Lurkr bot | Auto-assigned at level 5 |
| OSS Contributor role | Lurkr bot | Auto-assigned at level 15 |
| Core Contributor role | Maintainer | Manual (involves money) |
| Identity linking | contributors.yml | PR-based, reviewed by maintainers |
| Identity linking | Discord bot → MongoDB | `/link-github` command with gist verification |
## Guides
@@ -203,4 +203,4 @@ Without this link, bounties are still tracked but Lurkr can't push XP to your Di
- `.github/workflows/weekly-leaderboard.yml` — Monday leaderboard post
- `scripts/bounty-tracker.ts` — Point calculation, Lurkr API, Discord formatting
- `scripts/setup-bounty-labels.sh` — One-time label setup
- `contributors.yml` — GitHub ↔ Discord identity mapping
- MongoDB `hive.contributors` — GitHub ↔ Discord identity mapping (managed by Discord bot)
+2 -4
View File
@@ -6,9 +6,7 @@ Earn XP, Discord roles, and eventually real money by contributing to the Aden ag
### 1. Link your GitHub and Discord
Open a [Link Discord Account](https://github.com/aden-hive/hive/issues/new?template=link-discord.yml) issue — just paste your Discord ID and submit. A GitHub Action will automatically add you to `contributors.yml` and close the issue.
To find your Discord ID: Discord Settings > Advanced > Enable **Developer Mode**, then right-click your name > **Copy User ID**.
Run `/link-github your-github-username` in Discord. The bot will give you a verification code — create a public gist with that code, then run `/verify`. Done.
Without this link, Lurkr can't push XP to your Discord account.
@@ -154,7 +152,7 @@ A: Yes. Most services have free tiers. The bounty issue links to where you get t
A: Contribute consistently across different bounty types for 4+ weeks. Maintainers will nominate you.
**Q: What if I haven't linked my Discord yet?**
A: You'll still get credit in GitHub, but no Lurkr XP or Discord roles. Add yourself to `contributors.yml`.
A: You'll still get credit in GitHub, but no Lurkr XP or Discord roles. Run `/link-github` in Discord.
## Quick Reference
+4 -2
View File
@@ -104,6 +104,8 @@ Repo Settings > Secrets and variables > Actions:
| `DISCORD_BOUNTY_WEBHOOK_URL` | Webhook URL from Step 5 |
| `LURKR_API_KEY` | Lurkr API key from Step 4f |
| `LURKR_GUILD_ID` | Your Discord server ID\* |
| `BOT_API_URL` | Discord bot API URL |
| `BOT_API_KEY` | Discord bot API key |
\*Enable Developer Mode in Discord, right-click server name > Copy Server ID.
@@ -146,12 +148,12 @@ powerbi, redis
- [ ] All 3 GitHub secrets added
- [ ] Both workflows enabled (`bounty-completed.yml`, `weekly-leaderboard.yml`)
- [ ] Test PR + merge triggers Discord notification
- [ ] `contributors.yml` exists at repo root
- [ ] MongoDB `hive.contributors` collection accessible
## Troubleshooting
**No Discord message:** Check `DISCORD_BOUNTY_WEBHOOK_URL` secret and Action logs.
**Lurkr XP not awarded:** Confirm API key is Read/Write, contributor is in `contributors.yml`, check Action logs for `Lurkr XP push failed`.
**Lurkr XP not awarded:** Confirm API key is Read/Write, contributor has run `/link-github` in Discord, check Action logs for `Lurkr XP push failed`.
**Role not assigned:** Verify role rewards in the Lurkr dashboard or via `/config set`. Lurkr's role must be above the roles it assigns in server hierarchy.
+53 -1
View File
@@ -41,6 +41,12 @@ export ANTHROPIC_API_KEY="sk-ant-..."
# OpenAI (optional, for GPT models via LiteLLM)
export OPENAI_API_KEY="sk-..."
# OpenRouter (optional, for OpenRouter-hosted models)
export OPENROUTER_API_KEY="..."
# Hive LLM (optional, for Hive-managed models)
export HIVE_API_KEY="..."
# Cerebras (optional, used by output cleaner and some nodes)
export CEREBRAS_API_KEY="..."
@@ -50,6 +56,49 @@ export GROQ_API_KEY="..."
The framework supports 100+ LLM providers through [LiteLLM](https://docs.litellm.ai/docs/providers). Set the corresponding environment variable for your provider.
### Provider Examples
OpenRouter:
```json
{
"llm": {
"provider": "openrouter",
"model": "x-ai/grok-4.20-beta",
"max_tokens": 8192,
"api_key_env_var": "OPENROUTER_API_KEY",
"api_base": "https://openrouter.ai/api/v1"
}
}
```
Notes:
- Set `provider` to `openrouter`
- Use the raw OpenRouter model ID in `model`, for example `x-ai/grok-4.20-beta`
- `api_base` should be `https://openrouter.ai/api/v1`
- If you paste a model that already starts with `openrouter/`, Hive tolerates and normalizes it
Hive LLM:
```json
{
"llm": {
"provider": "hive",
"model": "queen",
"max_tokens": 32768,
"api_key_env_var": "HIVE_API_KEY",
"api_base": "https://api.adenhq.com"
}
}
```
Notes:
- Set `provider` to `hive`
- Common Hive model values are `queen`, `kimi-2.5`, and `GLM-5`
- Hive LLM requests use the Hive endpoint at `https://api.adenhq.com`
### Search & Tools (optional)
```bash
@@ -191,13 +240,16 @@ cd core && uv pip install -e .
Ensure the environment variable is set in your current shell session:
```bash
echo $ANTHROPIC_API_KEY # Should print your key
echo $ANTHROPIC_API_KEY # Or echo $OPENROUTER_API_KEY / echo $HIVE_API_KEY
```
On Windows PowerShell:
```powershell
$env:ANTHROPIC_API_KEY = "sk-ant-..."
# Or:
$env:OPENROUTER_API_KEY = "your-openrouter-key"
$env:HIVE_API_KEY = "your-hive-key"
```
### Agent not found
-981
View File
@@ -1,981 +0,0 @@
# Credential Identity & Multi-Account Foundation (Issue #4755)
## Context
Agents are identity-blind. When `gmail_read_email` runs, neither the LLM nor the tool
knows whose inbox it's operating on. One `ADEN_API_KEY` can back N accounts of the same
provider (e.g., 10 Gmail accounts), but today the system can only surface one — the last
one synced silently overwrites all others.
This plan traces the **5-tuple relationship** (Agent Definition → Agent Instance →
Agent Tool → Auth Provider → Auth User Identity) through every layer of the stack,
identifies exactly where things break, and prescribes targeted fixes.
### Motivating Scenarios
**Scenario A — Executive Assistant Agent**: A company deploys an agent that manages
calendars for 5 executives. Each executive has connected their Google account through
Aden. The agent's job is to check each person's availability and schedule meetings.
Today: the agent can only see ONE person's calendar (whichever synced last). The other
4 accounts are silently lost in the index collision. The agent schedules meetings on
the wrong person's calendar with no indication anything is wrong.
**Scenario B — Multi-Channel Support Agent**: A support team agent is connected to
3 Slack workspaces (Engineering, Sales, Support), a shared Gmail inbox, and a personal
Gmail for the team lead. Today: the agent sees one Slack workspace, one Gmail. It
cannot tell which Slack workspace it's posting to or whose Gmail it's reading. It
might reply to a customer email from the team lead's personal inbox.
**Scenario C — Compliance & Audit**: An enterprise client requires audit logs showing
which account was accessed, when, and by which agent. Today: the system logs
`credentials.get("google")` — no record of which of the 10 Google accounts was used.
Impossible to audit.
**Scenario D — Single-Account Agent (backward compat)**: A simple agent uses one
Gmail account and one Slack bot. Nothing should change. `credentials.get("google")`
returns the same token it always did. Zero migration, zero configuration changes.
---
## The 5-Tuple Model
Every credential interaction involves five entities. Understanding how they relate
(and where the relationships break) is the key to the fix.
```
Agent Definition ──→ Agent Instance ──→ Agent Tool ──→ Auth Provider ──→ Auth User Identity
"I need Gmail" "Here's your "Give me a "Here's one "Whose token
Gmail tool" token" token" is this?"
← MISSING
```
### 1. Agent Definition (what tools are needed)
**Files**: `exports/{name}/agent.py`, `nodes/__init__.py`, `mcp_servers.json`
An exported agent declares `NodeSpec.tools = ["gmail_read_email", "gmail_send_email"]`.
The `mcp_servers.json` points to the tools MCP server. The agent definition has NO
credential awareness — it names tools, not credentials. This is intentional: the same
agent definition can run against different credential sets in different environments
(dev vs. prod, tenant A vs. tenant B).
**Business logic**: Agent definitions are portable templates. A "Gmail Triage" agent
built by one team can be deployed to 50 different customers, each with their own
Google accounts. The agent definition never hard-codes credential IDs.
**Status**: Fine. No changes needed.
### 2. Agent Instance (runtime wiring)
**Files**: `runner.py`, `tool_registry.py`, `mcp_client.py`
`AgentRunner.__init__()` does three things in sequence:
1. `validate_agent_credentials(graph.nodes)` — checks presence + health
2. `ToolRegistry.load_mcp_config()``MCPClient` spawns subprocess
3. `_setup()``create_agent_runtime()` with discovered tools
The `ToolRegistry` bridges parent ↔ MCP subprocess:
- `CONTEXT_PARAMS = {"workspace_id", "agent_id", "session_id", "data_dir"}` — stripped
from LLM schema, injected at call time via `make_mcp_executor` closure
- `set_session_context()` — set once at startup
- `set_execution_context()` — per-execution via `contextvars`
The MCP subprocess inherits `os.environ` at spawn time via
`merged_env = {**os.environ, **(config.env or {})}` in `mcp_client.py:157`.
**Business logic**: The agent instance is where "portable template" meets "specific
deployment." An instance knows which Aden API key to use, which workspace it belongs
to, which tools are available. The `CONTEXT_PARAMS` mechanism is how the framework
passes deployment-specific context into tools without the LLM knowing or caring.
This is the natural extension point for `account` routing in the future.
**Scenario**: Two customers both deploy the same "Email Triage" agent. Customer A
has 2 Google accounts; Customer B has 5. Each customer's `AgentRunner` validates
against their own Aden key, discovers different sets of credentials, and wires them
into the same agent graph. The agent definition is identical.
**Status**: Works for single-account. The `CONTEXT_PARAMS` pattern is the right
mechanism for future multi-account routing (adding `account` param).
### 3. Agent Tool (credential consumption)
**Files**: `tools/src/aden_tools/tools/*/`, `tools/mcp_server.py`
Every tool follows the same pattern:
```python
def register_gmail_tools(mcp, credentials=None):
def _get_token():
if credentials is not None:
return credentials.get("google") # ← single token, identity unknown
return os.getenv("GOOGLE_ACCESS_TOKEN")
@mcp.tool()
def gmail_read_email(message_id: str):
token = _get_token()
...
```
The `credentials` object is `CredentialStoreAdapter`, created once at MCP server startup
via `CredentialStoreAdapter.default()`. All tool closures capture this single shared
instance.
**Business logic**: Tools are the consumer endpoint — they need a valid access token
to call external APIs. They don't care about Aden, sync, or storage. They just need
`_get_token()` to return the right token. Today, "right" is undefined because there's
no way to say "the token for alice@company.com, not bob@company.com."
**Where it breaks — Scenario A revisited**: The executive assistant agent calls
`gmail_read_email()` intending to read Alice's inbox. `_get_token()` returns
`credentials.get("google")` which resolves to... Bob's token (he synced last).
The agent reads Bob's emails, thinks they're Alice's, and schedules meetings
accordingly. No error is raised. No indication anything is wrong. The agent is
confidently operating on the wrong person's data.
**Where it breaks — Scenario B revisited**: The support agent calls
`slack_post_message(channel="support-tickets")`. It uses a Slack token from
the Engineering workspace (last synced). The message goes to a channel that
doesn't exist in Engineering, returns an error, and the agent retries in a loop
with no understanding of why it's failing.
### 4. Auth Provider (credential storage & resolution)
**Files**: `store.py`, `aden/storage.py`, `aden/provider.py`, `aden/client.py`
Resolution chain:
```
credentials.get("google")
→ CredentialStoreAdapter.get("google")
→ CredentialStore.get("google")
→ AdenCachedStorage.load("google")
→ _provider_index.get("google") → "google_def456" (last write wins)
→ _load_by_id("google_def456")
→ Returns ONE CredentialObject
```
**The index collision bug** (`storage.py:303`):
```python
def _index_provider(self, credential):
provider_name = integration_type_key.value.get_secret_value()
self._provider_index[provider_name] = credential.id # ← OVERWRITES
```
**Business logic**: The storage layer is responsible for mapping human-readable
provider names ("google") to internal hash-based credential IDs ("google_abc123").
This mapping is essential because Aden generates unique hash IDs per connected account,
but tools reference providers by name. The `_provider_index` is this mapping.
**Why it's a `dict[str, str]` today**: The original design assumed 1:1 between
provider name and credential. "One Google account per API key." This was valid
for simple deployments but breaks fundamentally when an Aden API key backs multiple
accounts of the same provider.
**The collision mechanics**: When `sync_all()` runs, it iterates over all active
integrations from Aden. For a user with 3 Gmail accounts:
1. Sync `google_abc123` (alice@co.com) → `_provider_index["google"] = "google_abc123"`
2. Sync `google_def456` (bob@co.com) → `_provider_index["google"] = "google_def456"` ← Alice lost
3. Sync `google_ghi789` (carol@co.com) → `_provider_index["google"] = "google_ghi789"` ← Bob lost
All three `.enc` files exist on disk. Only Carol's is reachable by name. Alice's and
Bob's tokens are orphaned — encrypted, on disk, but invisible to the resolution chain.
**Why the disk layer is fine**: `EncryptedFileStorage` uses the hash ID as filename:
`google_abc123.enc`, `google_def456.enc`. No collision. The problem is purely in the
in-memory index that maps names to IDs.
### 5. Auth User Identity (THE MISSING PIECE)
**Files**: `models.py` (no identity model), `aden/provider.py` (metadata discarded),
`health_check.py` (identity parsed then discarded), `validation.py` (details ignored)
**Business logic**: Identity answers "whose account is this?" Every external service
provides identity data in its API responses — Gmail returns `emailAddress`, GitHub
returns `login`, Slack returns `team` + `user`. This data already flows through the
system during health checks and Aden syncs. It's parsed, briefly held in local
variables, and then discarded. No model captures it. No property exposes it. No
downstream consumer reads it.
Identity data exists at two sources but is discarded:
| Source | Data Available | What Happens |
|--------|---------------|--------------|
| Aden `metadata.email` | Email of connected account | `_aden_response_to_credential()` ignores `metadata` dict |
| Gmail health check | `emailAddress` field | `OAuthBearerHealthChecker.check()` returns `valid=True`, discards response body |
| GitHub health check | `login` username | Parsed to `details["username"]`, validation ignores `details` |
| Slack health check | `team`, `user` | Parsed to `details`, validation ignores `details` |
| Discord health check | `username`, `id` | Parsed to `details`, validation ignores `details` |
| Calendar health check | Primary calendar `id` = email | `OAuthBearerHealthChecker.check()` discards response body |
**The waste**: Every agent startup already makes these health check API calls. The
identity data is RIGHT THERE in the response body. We parse it for validation logic,
then throw it away. Zero additional API calls needed — we just need to keep what we
already have.
**What identity enables downstream**:
- LLM knows whose inbox it's reading (system prompt awareness)
- Tools can route to specific accounts (future `account` parameter)
- Audit logs can record which identity was accessed
- Users can see which accounts are connected in TUI/dashboard
- Agents can reason about cross-account operations ("forward from alice to bob")
---
## What Changes — Layer by Layer
### Step 1: `CredentialIdentity` model on `CredentialObject`
**File**: `core/framework/credentials/models.py`
**Business logic**: Every credential needs a structured way to answer "who does this
belong to?" Different providers express identity differently:
| Provider | Primary Identity | Secondary Identity |
|----------|-----------------|-------------------|
| Google (Gmail, Calendar, Drive) | Email address | — |
| Slack | Workspace name | Bot username |
| GitHub | Username (login) | — |
| Discord | Username | Account ID |
| HubSpot | Portal ID | — |
| Microsoft 365 | Email address | Tenant ID |
The `CredentialIdentity` model normalizes these into four universal fields:
`email`, `username`, `workspace`, `account_id`. The `label` property picks the
best human-readable identifier for display (email preferred, then username, etc.).
**Why a computed property, not a stored field**: Identity is derived from
`_identity_*` keys that already exist in the credential's key vault. Storing it
as a separate field would create a sync problem (what if keys update but the field
doesn't?). A computed property always reflects current state.
**Scenarios this enables**:
- **Display**: `cred.identity.label``"alice@company.com"` (for system prompts, TUI, logs)
- **Comparison**: `cred.identity.email == "alice@company.com"` (for account routing)
- **Serialization**: `cred.identity.to_dict()``{"email": "alice@company.com"}` (for MCP tool responses)
- **Existence check**: `cred.identity.is_known``True` (skip accounts with no identity)
- **Provider type**: `cred.provider_type``"google"` (from `_integration_type` key)
**Key design decision**: `set_identity(**fields)` persists as `_identity_*` keys using
the existing `set_key()` method. This means identity survives serialization/deserialization
through `EncryptedFileStorage` without any schema migration. Old credentials without
identity keys simply return `CredentialIdentity()` with all `None` fields and
`label == "unknown"`.
```python
class CredentialIdentity(BaseModel):
email: str | None = None
username: str | None = None
workspace: str | None = None
account_id: str | None = None
@property
def label(self) -> str:
return self.email or self.username or self.workspace or self.account_id or "unknown"
@property
def is_known(self) -> bool:
return bool(self.email or self.username or self.workspace or self.account_id)
def to_dict(self) -> dict[str, str]:
return {k: v for k, v in self.model_dump().items() if v is not None}
```
On `CredentialObject`:
```python
@property
def identity(self) -> CredentialIdentity:
fields = {}
for key_name, key_obj in self.keys.items():
if key_name.startswith("_identity_"):
field = key_name[len("_identity_"):]
fields[field] = key_obj.value.get_secret_value()
return CredentialIdentity(**{k: v for k, v in fields.items()
if k in CredentialIdentity.model_fields})
@property
def provider_type(self) -> str | None:
key = self.keys.get("_integration_type")
return key.value.get_secret_value() if key else None
def set_identity(self, **fields: str) -> None:
for field_name, value in fields.items():
if value:
self.set_key(f"_identity_{field_name}", value)
```
---
### Step 2: Fix storage multi-account index
**File**: `core/framework/credentials/aden/storage.py`
**Business logic**: The core bug. When a user connects multiple accounts of the same
provider type through Aden, all but the last one becomes unreachable. This affects
every multi-account deployment silently — no error, no warning, just missing accounts.
**`_provider_index`**: `dict[str, str]``dict[str, list[str]]`
**Before (broken)**:
```
sync google_abc123 (alice) → index["google"] = "google_abc123"
sync google_def456 (bob) → index["google"] = "google_def456" ← alice lost
load("google") → returns bob's token
```
**After (fixed)**:
```
sync google_abc123 (alice) → index["google"] = ["google_abc123"]
sync google_def456 (bob) → index["google"] = ["google_abc123", "google_def456"]
load("google") → returns alice's token (first = backward compat)
load_all_for_provider("google") → returns [alice, bob]
```
**Backward compatibility contract**: Every existing tool calls `credentials.get("google")`
and expects a single token string back. This MUST continue to work. `load("google")`
returns the first credential in the list — same behavior as before for single-account
deployments, deterministic (first-synced-first-served) for multi-account.
**Scenarios**:
- **Single account** (most common today): `index["google"] = ["google_abc123"]`.
`load("google")` returns the only entry. Identical behavior to before.
- **Two accounts, same provider**: `index["google"] = ["google_abc123", "google_def456"]`.
`load("google")` returns first. `load_all_for_provider("google")` returns both.
Existing tools see no change; new APIs can enumerate.
- **Mixed providers**: `index["google"] = ["google_abc123"], index["slack"] = ["slack_xyz"]`.
Each provider resolves independently.
- **Credential removed from Aden**: On next `sync_all()`, `rebuild_provider_index()`
rebuilds from disk. The removed credential's `.enc` file is gone, so it drops from
the index naturally.
- **`exists()` check**: Validation calls `exists("google")` to check if credentials
are available before running health checks. Must return `True` if ANY Google account
exists, not just the last-synced one.
```python
# _index_provider — append, don't overwrite
def _index_provider(self, credential):
...
if provider_name not in self._provider_index:
self._provider_index[provider_name] = []
if credential.id not in self._provider_index[provider_name]:
self._provider_index[provider_name].append(credential.id)
# load — first match (backward compat)
def load(self, credential_id):
resolved_ids = self._provider_index.get(credential_id)
if resolved_ids:
for rid in resolved_ids:
if rid != credential_id:
result = self._load_by_id(rid)
if result is not None:
return result
return self._load_by_id(credential_id)
# NEW: enumerate all accounts
def load_all_for_provider(self, provider_name: str) -> list[CredentialObject]:
results = []
for cid in self._provider_index.get(provider_name, []):
cred = self._load_by_id(cid)
if cred:
results.append(cred)
return results
```
---
### Step 3: Preserve Aden metadata as identity
**File**: `core/framework/credentials/aden/provider.py`
**Business logic**: When a user connects a Google account through Aden's OAuth flow,
the Aden server stores metadata about the connected account — most importantly, the
email address. This metadata comes back in every API response as
`metadata: {"email": "alice@company.com"}`. Today, this metadata is present in
`AdenCredentialResponse.metadata` (the `from_dict()` parser already handles it) but
is never written into the `CredentialObject`'s key vault. It's silently dropped.
**Why Aden metadata is the primary identity source**: Aden captures identity at the
moment of OAuth authorization — the user explicitly grants access, and the Aden server
records who they are. This is more authoritative than health checks because:
1. It's captured at consent time, not at validation time
2. It works even if the health check endpoint is down
3. It's available immediately on first sync, before any health check runs
**When metadata arrives**: Two code paths create/update credentials from Aden responses:
1. **`_aden_response_to_credential()`** — first-time sync. The credential doesn't
exist locally yet. We're building it from scratch. Metadata should be written as
`_identity_*` keys in the initial key dict.
2. **`_update_credential_from_aden()`** — token refresh. The credential already exists.
The access token is updated. Metadata should be written/overwritten as `_identity_*`
keys on the existing credential object.
**Scenario — first sync**: User connects `alice@company.com` through Aden. Aden
returns `{access_token: "...", metadata: {email: "alice@company.com"}}`. The
credential is created with `_identity_email = "alice@company.com"`. Later,
`cred.identity.email` returns `"alice@company.com"`.
**Scenario — token refresh**: Alice's token expires. Aden refreshes it and returns
updated metadata. `_update_credential_from_aden()` updates the access token AND
refreshes `_identity_email`. If Alice changed her email (e.g., name change), the
identity stays current.
**Scenario — no metadata**: Some Aden integrations may not return metadata (e.g.,
a simple API key integration). The loop `for meta_key, meta_value in (metadata or {}).items()`
safely does nothing. The credential has no `_identity_*` keys, and `cred.identity`
returns `CredentialIdentity()` with `label == "unknown"`.
```python
# In _aden_response_to_credential, after building keys dict:
for meta_key, meta_value in (aden_response.metadata or {}).items():
if meta_value and isinstance(meta_value, str):
keys[f"_identity_{meta_key}"] = CredentialKey(
name=f"_identity_{meta_key}",
value=SecretStr(meta_value),
)
# In _update_credential_from_aden, after updating access_token:
for meta_key, meta_value in (aden_response.metadata or {}).items():
if meta_value and isinstance(meta_value, str):
credential.keys[f"_identity_{meta_key}"] = CredentialKey(
name=f"_identity_{meta_key}",
value=SecretStr(meta_value),
)
```
---
### Step 4: Extract identity from health checks
**File**: `tools/src/aden_tools/credentials/health_check.py`
**Business logic**: Health checks are the second identity source. Every agent startup
runs `validate_agent_credentials()` which calls provider-specific health check
endpoints. These endpoints return identity data as a side effect of validation:
| Health Check Endpoint | What It Returns | Identity We Extract |
|----------------------|----------------|-------------------|
| Gmail: `GET /users/me/profile` | `{emailAddress, messagesTotal, ...}` | `email = emailAddress` |
| Calendar: `GET /users/me/calendarList` | `{items: [{id, primary, ...}]}` | `email = primary calendar id` |
| Slack: `POST auth.test` | `{ok, team, user, bot_id, ...}` | `workspace = team, username = user` |
| GitHub: `GET /user` | `{login, id, name, ...}` | `username = login` |
| Discord: `GET /users/@me` | `{username, id, ...}` | `username = username` |
**Why health checks matter as an identity source**:
1. **Fallback when Aden metadata is missing**: Not all Aden integrations return
metadata. The health check always hits the actual service, so identity is always
available on success.
2. **Ground truth verification**: Aden metadata is captured at OAuth time. If the
user's email changed since then, the health check returns the CURRENT identity.
3. **Non-Aden credentials**: When credentials are configured via environment
variables (no Aden), health checks are the ONLY identity source. A dev sets
`GOOGLE_ACCESS_TOKEN` manually — the health check reveals whose token it is.
4. **Zero additional cost**: The health check API call is already happening. We
just need to parse the response body that's currently discarded after the
status code check.
**Design — `_extract_identity()` hook**: The base `OAuthBearerHealthChecker` gets
a new virtual method `_extract_identity(data: dict) -> dict[str, str]` that subclasses
override. The `check()` method calls it when the response is 200 OK:
```python
class OAuthBearerHealthChecker:
def _extract_identity(self, data: dict) -> dict[str, str]:
"""Override to extract identity fields from successful response."""
return {}
def check(self, access_token: str) -> HealthCheckResult:
...
if response.status_code == 200:
identity = {}
try:
data = response.json()
identity = self._extract_identity(data)
except Exception:
pass # Identity extraction is best-effort
return HealthCheckResult(
valid=True,
message=f"{self.service_name} credentials valid",
details={"identity": identity} if identity else {},
)
```
**Why `details["identity"]`**: The existing `HealthCheckResult` has a `details: dict`
field that's used ad-hoc by different checkers. By putting identity under a standardized
`"identity"` key, Step 5 can generically extract it without knowing which checker
ran. Existing `details` fields (`username`, `team`, `bot_id`) continue to exist
alongside — no breaking changes.
**Standalone checkers** (Slack, GitHub, Discord) don't extend `OAuthBearerHealthChecker`.
They already parse identity data into their `details` dict. For these, we simply add
an `"identity"` key with the structured fields alongside existing keys.
**Scenario — Gmail health check enriches a credential without Aden metadata**: A dev
sets `GOOGLE_ACCESS_TOKEN` as an env var. The credential has no `_identity_*` keys.
On startup, the Gmail health check calls `/users/me/profile`, gets
`{emailAddress: "dev@gmail.com"}`, returns `details={"identity": {"email": "dev@gmail.com"}}`.
Step 5 persists this. Now `cred.identity.email` works even without Aden.
**Scenario — health check fails**: Token is expired or revoked. Response is 401.
No identity extracted (identity extraction only runs on 200). The health check
returns `valid=False`. Step 5 skips persistence. The credential's existing identity
(if any, from Aden metadata) remains unchanged.
**Scenario — identity extraction throws**: The response body is malformed or missing
expected fields. The `try/except` in `check()` catches it. Health check still returns
`valid=True` (the token worked). Identity is just not extracted. Best-effort, never
blocks validation.
---
### Step 5: Persist identity during validation
**File**: `core/framework/credentials/validation.py`
**Business logic**: Steps 3 and 4 produce identity data. Step 5 is the bridge that
takes identity from health check results and persists it to the credential store.
This runs during `validate_agent_credentials()`, which is called at every agent startup.
**Why persist during validation**: Validation is the natural lifecycle hook because:
1. It runs on every agent startup (guaranteed execution)
2. It already has access to the credential store
3. It already runs health checks (identity is available in the result)
4. It runs BEFORE the agent executes (identity is available for system prompt injection)
**Flow**:
```
Agent startup
→ validate_agent_credentials()
→ for each credential:
→ check_credential_health(token) → HealthCheckResult
→ if result.valid AND result.details["identity"] exists:
→ cred_obj = store.get_credential(cred_id)
→ cred_obj.set_identity(**identity_data)
→ store.save_credential(cred_obj) ← persisted to disk
```
**Scenario — identity from health check augments Aden metadata**: Aden provides
`metadata.email = "alice@company.com"` (stored as `_identity_email` in Step 3).
The Slack health check returns `identity: {workspace: "Acme Corp", username: "hive-bot"}`.
Step 5 adds `_identity_workspace` and `_identity_username` to the Slack credential.
Now both credentials have rich identity data from their respective sources.
**Scenario — identity update on restart**: Between agent runs, the GitHub user
renamed from `old-username` to `new-username`. On next startup, the health check
returns `identity: {username: "new-username"}`. Step 5 calls `set_identity(username="new-username")`,
which overwrites `_identity_username`. The credential now reflects the current identity.
**Scenario — multiple accounts of same provider**: With the index fix (Step 2),
`validate_agent_credentials()` iterates over all credentials. Each Google account
gets its own health check. Each health check returns a different `emailAddress`.
Each identity is persisted to the correct `CredentialObject`. Account A gets
`_identity_email = "alice@co.com"`, Account B gets `_identity_email = "bob@co.com"`.
**Error handling**: Identity persistence is best-effort. If `get_credential()` fails
or `save_credential()` fails, the exception is caught and swallowed. The agent still
starts. The credential still works. It just won't have identity data for that account.
This is acceptable because identity is informational, not functional.
```python
if result.valid:
identity_data = result.details.get("identity")
if identity_data and isinstance(identity_data, dict):
try:
cred_obj = store.get_credential(cred_id, refresh_if_needed=False)
if cred_obj:
cred_obj.set_identity(**identity_data)
store.save_credential(cred_obj)
except Exception:
pass # Identity persistence is best-effort
```
---
### Step 6: Account listing & identity APIs
**Files**: `core/framework/credentials/store.py`, `tools/src/aden_tools/credentials/store_adapter.py`
**Business logic**: Steps 1-5 populate identity data. Step 6 exposes it through
clean APIs. Two layers need new methods:
1. **`CredentialStore`** (framework layer) — knows about `CredentialObject` and storage
2. **`CredentialStoreAdapter`** (tool boundary) — wraps the store with `CredentialSpec`-aware
APIs, sits in the MCP subprocess, consumed by tools
**Why two layers**: The store is a framework concept (core/). The adapter is a tools
concept (tools/). Tools never import from core directly. The adapter bridges the gap,
translating between credential IDs and spec names, handling the "is this credential
configured and available?" logic.
**APIs added to `CredentialStore`**:
- `list_accounts(provider_name)` — returns all accounts for a provider type with
their identities. Delegates to `storage.load_all_for_provider()` (Step 2). Returns
a list of dicts, not raw `CredentialObject`s, to avoid leaking secrets upstream.
- `get_credential_by_identity(provider_name, label)` — finds a specific account by
matching `cred.identity.label` against the provided label. This is the resolution
mechanism for future multi-account routing: "give me the token for alice@co.com."
**APIs added to `CredentialStoreAdapter`**:
- `get_identity(name)` — returns the identity dict for a named credential spec.
Used by tools that want to know whose token they're using for logging/display.
- `list_accounts(provider_name)` — delegates to store. Used by the `get_account_info`
MCP tool (Step 8).
- `get_all_account_info()` — iterates over all configured credential specs, collects
all accounts across all providers. Used to build the system prompt (Step 7).
Deduplicates by provider name to avoid listing the same provider's accounts twice
when multiple specs map to the same provider.
- `get_by_identity(provider_name, label)` — resolves a specific account's token by
identity label. Used by future multi-account routing (Step 9). Returns a raw token
string, not a `CredentialObject`.
**Scenario — system prompt building**: At agent startup, the runner calls
`adapter.get_all_account_info()`. The adapter iterates over specs:
`{"gmail": CredentialSpec(credential_id="google"), "gcal": CredentialSpec(credential_id="google"), "slack": CredentialSpec(...)}`.
It deduplicates by provider: `google` and `slack`. For `google`, `list_accounts("google")`
returns 2 accounts. For `slack`, 1 account. Result: 3 account entries for the system prompt.
**Scenario — identity-based routing (future)**: The LLM calls
`gmail_read_email(account="alice@co.com")`. The tool calls
`credentials.get_by_identity("google", "alice@co.com")`. The adapter delegates to
`store.get_credential_by_identity("google", "alice@co.com")` which scans all Google
credentials, finds the one where `identity.label == "alice@co.com"`, and returns
its access token. The right inbox is read.
```python
# CredentialStore
def list_accounts(self, provider_name: str) -> list[dict[str, Any]]:
if hasattr(self._storage, 'load_all_for_provider'):
creds = self._storage.load_all_for_provider(provider_name)
else:
cred = self.get_credential(provider_name)
creds = [cred] if cred else []
return [
{"credential_id": c.id, "provider": provider_name,
"identity": c.identity.to_dict(), "label": c.identity.label}
for c in creds
]
def get_credential_by_identity(self, provider_name: str, label: str) -> CredentialObject | None:
if hasattr(self._storage, 'load_all_for_provider'):
for cred in self._storage.load_all_for_provider(provider_name):
if cred.identity.label == label:
return cred
return None
```
```python
# CredentialStoreAdapter
def get_all_account_info(self) -> list[dict[str, Any]]:
accounts = []
seen: set[str] = set()
for name, spec in self._specs.items():
provider = spec.credential_id or name
if provider in seen or not self.is_available(name):
continue
seen.add(provider)
accounts.extend(self._store.list_accounts(provider))
return accounts
def get_by_identity(self, provider_name: str, label: str) -> str | None:
cred = self._store.get_credential_by_identity(provider_name, label)
return cred.get_default_key() if cred else None
```
---
### Step 7: Surface identity to LLM via system prompt
**Files**: `prompt_composer.py`, `executor.py`, `event_loop_node.py`, `node.py`, `runner.py`
**Business logic**: The LLM needs to know what accounts are connected so it can:
1. **Communicate clearly to the user**: "I checked alice@company.com's inbox and
found 3 unread messages" vs. "I checked the inbox and found 3 unread messages"
2. **Disambiguate operations**: When asked "check my emails," the LLM can respond
"You have 2 Google accounts connected: alice@company.com and bob@company.com.
Which would you like me to check?" (requires Step 9 routing, but awareness comes first)
3. **Prevent hallucination**: Without account info, the LLM might invent account
names or assume capabilities it doesn't have. With the accounts prompt, it knows
exactly what's available.
4. **Cross-account reasoning**: "Forward the email from alice's inbox to bob's inbox"
requires knowing both accounts exist and which is which.
**Where it sits in the three-layer prompt**:
```
Layer 1 — Identity: "You are a thorough email management agent."
Accounts: "Connected accounts:
- google: alice@company.com (email: alice@company.com)
- google: bob@company.com (email: bob@company.com)
- slack: Acme Corp (workspace: Acme Corp, username: hive-bot)"
Layer 2 — Narrative: "We've triaged 15 emails so far..."
Layer 3 — Focus: "Your current task: categorize remaining unread emails"
```
Accounts sit between identity (static personality) and narrative (dynamic state)
because connected accounts are semi-static — they don't change during a session but
are deployment-specific (different from the agent definition).
**Injection path through the framework**:
```
AgentRunner._setup()
→ CredentialStoreAdapter.get_all_account_info()
→ build_accounts_prompt(accounts) ← new function in prompt_composer.py
→ GraphExecutor(accounts_prompt=...) ← new init param
→ NodeContext(accounts_prompt=...) ← new field
→ compose_system_prompt(..., accounts_prompt=...) ← new param
```
**Why it flows through `NodeContext`**: For the first node in a graph (or an isolated
`EventLoopNode`), the system prompt is built in `EventLoopNode.execute()`, not through
the continuous transition path. `NodeContext.accounts_prompt` carries the data to
both paths:
- **Continuous transition**: `compose_system_prompt()` in the executor uses
`self.accounts_prompt` directly
- **First node / isolated node**: `EventLoopNode.execute()` reads `ctx.accounts_prompt`
and appends it to the system prompt
**Scenario — no credentials**: An agent with no external integrations (pure LLM
reasoning, no tools). `get_all_account_info()` returns `[]`. `build_accounts_prompt([])`
returns `""`. The accounts block is omitted from the system prompt. Zero impact.
**Scenario — single account**: One Google account. System prompt shows
`"Connected accounts:\n- google: alice@company.com (email: alice@company.com)"`.
The LLM knows who it's operating as.
**Scenario — unknown identity**: A credential exists but has no `_identity_*` keys
(maybe Aden didn't provide metadata and health checks haven't run yet). `identity.label`
returns `"unknown"`. The prompt shows `"- google: unknown"`. Better than nothing —
the LLM knows Google is connected, just not whose account.
```python
def build_accounts_prompt(accounts: list[dict[str, Any]]) -> str:
if not accounts:
return ""
lines = ["Connected accounts:"]
for acct in accounts:
provider = acct.get("provider", "unknown")
label = acct.get("label", "unknown")
identity = acct.get("identity", {})
detail_parts = [f"{k}: {v}" for k, v in identity.items() if v]
detail = f" ({', '.join(detail_parts)})" if detail_parts else ""
lines.append(f"- {provider}: {label}{detail}")
return "\n".join(lines)
```
---
### Step 8: `get_account_info` MCP tool
**New directory**: `tools/src/aden_tools/tools/account_info_tool/`
**Business logic**: Step 7 gives the LLM passive awareness (system prompt). Step 8
gives the LLM active introspection — it can call `get_account_info()` to query
connected accounts at runtime, even mid-conversation.
**Why both passive and active**: The system prompt provides context at conversation
start. But in long-running agents with many tools, the system prompt may get
compacted (truncated during context management). The MCP tool ensures the LLM can
always re-discover account info even after compaction.
**Use cases**:
- **User asks "what accounts are connected?"**: LLM calls `get_account_info()`,
formats the response for the user.
- **LLM needs to decide which account to use**: Before sending an email, the LLM
calls `get_account_info(provider="google")` to see which Gmail accounts are
available, then asks the user which one to send from.
- **Dynamic account discovery**: In a long-running session, accounts might be
added/revoked (Aden dashboard). The tool provides current state vs. the stale
system prompt.
- **Debugging/transparency**: The user can ask "which Slack workspace are you
connected to?" and get a precise answer.
**API design**:
```python
@mcp.tool()
def get_account_info(provider: str = "") -> dict:
"""List connected accounts and their identities.
Call with no arguments to see all connected accounts.
Call with provider="google" to filter by provider type.
Returns account IDs, provider types, and identity labels
(email, username, workspace) for each connected account.
"""
if credentials is None:
return {"accounts": [], "message": "No credential store configured"}
if provider:
accounts = credentials.list_accounts(provider)
else:
accounts = credentials.get_all_account_info()
return {"accounts": accounts, "count": len(accounts)}
```
**Response example**:
```json
{
"accounts": [
{"credential_id": "google_abc123", "provider": "google",
"identity": {"email": "alice@company.com"}, "label": "alice@company.com"},
{"credential_id": "google_def456", "provider": "google",
"identity": {"email": "bob@company.com"}, "label": "bob@company.com"},
{"credential_id": "slack_xyz", "provider": "slack",
"identity": {"workspace": "Acme Corp", "username": "hive-bot"},
"label": "Acme Corp"}
],
"count": 3
}
```
Register in `tools/src/aden_tools/tools/__init__.py` alongside existing tools.
---
### Step 9: Multi-account routing extension point (design only, no code)
**Business logic**: Steps 1-8 build the foundation. Step 9 designs (but does not
implement) the per-tool-call account selection mechanism. This is the endgame:
when the LLM calls `gmail_read_email(account="alice@co.com")`, the right token
is used.
**Why design-only in this PR**: Multi-account routing requires changes to every
tool's `_get_token()` function and introduces the `account` parameter across all
tool signatures. This is a significant surface area change that should be a
separate PR with its own testing. The foundation from Steps 1-8 makes it a
straightforward addition.
**How it will work — the full flow**:
1. **LLM discovers accounts**: Via system prompt (Step 7) or `get_account_info` tool
(Step 8), the LLM knows `alice@company.com` and `bob@company.com` are connected.
2. **User says "check alice's inbox"**: The LLM calls
`gmail_read_email(account="alice@company.com")`.
3. **Tool resolves account**: `_get_token("alice@company.com")` calls
`credentials.get_by_identity("google", "alice@company.com")`.
4. **Store resolves credential**: `get_credential_by_identity("google", "alice@company.com")`
scans all Google credentials, finds the one where `identity.label == "alice@company.com"`,
returns its access token.
5. **API call with correct token**: The tool uses Alice's token to call the Gmail API.
The right inbox is read.
**Pinned single-account agents**: For agents that should ALWAYS use a specific account
(e.g., a shared support inbox), the `account` parameter becomes a `CONTEXT_PARAM` in
`ToolRegistry`. It's stripped from the LLM schema (the LLM can't override it) and
auto-injected at call time from `NodeSpec` or `GraphSpec` configuration. This follows
the exact same pattern as `data_dir` — proven, concurrency-safe, framework-native.
**Why `CredentialIdentity.label` is the stable routing key**:
- It's human-readable (email addresses, usernames)
- It's deterministic (computed from `_identity_*` keys)
- It matches what the LLM sees in the system prompt
- It survives credential refresh (identity doesn't change when tokens rotate)
- It's unique within a provider (two Google accounts always have different emails)
---
## How This Works with Exported/Template Agents
### Agent definition (no changes)
Exported agents in `exports/` declare tools via `NodeSpec.tools` and MCP servers via
`mcp_servers.json`. They don't know about credentials — this is by design. Credential
specs (`CredentialSpec.tools`) provide the external mapping from tool name to credential.
**Scenario — same agent, different deployments**: The "Email Triage" agent template
is used by 3 customers. Customer A has 1 Gmail account. Customer B has 5. Customer C
has 3 Gmail and 2 Outlook. The agent definition is identical for all three. Only
the Aden API key (and thus the available credentials) differs.
### Agent instance (accounts_prompt injection)
When `AgentRunner.load()` instantiates an agent:
1. `validate_agent_credentials()` runs — syncs Aden, checks presence/health
2. Identity is persisted during validation (Step 5)
3. `_setup()` collects `accounts_prompt` via `CredentialStoreAdapter.get_all_account_info()`
4. Passes to `GraphExecutor(accounts_prompt=...)``compose_system_prompt()`
The agent definition doesn't need to change. Identity flows through the existing
runtime wiring.
### MCP subprocess (independent adapter)
The MCP subprocess creates its own `CredentialStoreAdapter.default()` at startup.
This triggers an independent `sync_all()` from Aden. With the index fix (Step 2),
all accounts are preserved. The adapter's new methods (`list_accounts()`,
`get_all_account_info()`, `get_by_identity()`) are available to tools in the subprocess.
**Why independent sync is correct**: The MCP subprocess runs in a separate process
with its own memory space. It cannot share the parent's `CredentialStore`. Both
processes sync from the same Aden server (same API key), so they see the same
credentials. The disk-level `EncryptedFileStorage` handles concurrent access safely
(each read is atomic file read, writes use temp+rename).
### ToolRegistry bridge (future routing)
When multi-account routing is implemented (Step 9), the `account` parameter will be
added to `CONTEXT_PARAMS`. `ToolRegistry._convert_mcp_tool_to_framework_tool()` will
strip it from LLM schema (line 467). `make_mcp_executor()` will inject it at call time
(line 421). This follows the exact same pattern as `data_dir`.
---
## Files Modified (Summary)
| # | File | Changes |
|---|------|---------|
| 1 | `core/framework/credentials/models.py` | `CredentialIdentity`, `identity` property, `set_identity()`, `provider_type` |
| 2 | `core/framework/credentials/aden/storage.py` | `_provider_index: dict[str, list[str]]`, `load_all_for_provider()`, fix `exists()`, `rebuild_provider_index()` |
| 3 | `core/framework/credentials/aden/provider.py` | Persist `metadata` as `_identity_*` keys in both `_aden_response_to_credential` and `_update_credential_from_aden` |
| 4 | `tools/src/aden_tools/credentials/health_check.py` | `_extract_identity()` hook on `OAuthBearerHealthChecker`, overrides per checker, `identity` key in standalone checker `details` |
| 5 | `core/framework/credentials/validation.py` | Persist identity from health check `details["identity"]` via `set_identity()` |
| 6 | `core/framework/credentials/store.py` | `list_accounts()`, `get_credential_by_identity()` |
| 7 | `tools/src/aden_tools/credentials/store_adapter.py` | `get_identity()`, `list_accounts()`, `get_all_account_info()`, `get_by_identity()` |
| 8 | `core/framework/graph/prompt_composer.py` | `build_accounts_prompt()`, `accounts_prompt` param on `compose_system_prompt()` |
| 9 | `core/framework/graph/node.py` | `accounts_prompt: str = ""` on `NodeContext` |
| 10 | `core/framework/graph/executor.py` | `accounts_prompt` init param, pass to `compose_system_prompt()` and `_build_context()` |
| 11 | `core/framework/graph/event_loop_node.py` | Append `accounts_prompt` for first node system prompt |
| 12 | `core/framework/runner/runner.py` | Collect accounts info in `_setup()`, pass to executor |
| 13 | `tools/src/aden_tools/tools/account_info_tool/` | New `get_account_info` MCP tool |
| 14 | `tools/src/aden_tools/tools/__init__.py` | Register account info tool |
---
## Verification
1. **Multi-index**: Sync 2 Google accounts → both in `_provider_index["google"]` (not overwritten)
2. **Identity model**: `cred.identity.email` returns email, `cred.identity.label` returns best label
3. **Health check identity**: `GoogleGmailHealthChecker.check(token)``result.details["identity"]["email"]`
4. **Persistence**: After validation, credential on disk has `_identity_email` key
5. **Account listing**: `adapter.list_accounts("google")` → 2 accounts with distinct identities
6. **System prompt**: `compose_system_prompt(accounts_prompt=...)` includes "Connected accounts"
7. **MCP tool**: `get_account_info(provider="google")` returns both accounts with labels
8. **Backward compat**: `credentials.get("google")` still returns single token string
9. **Existing tests**: `PYTHONPATH=core:tools/src python -m pytest tools/tests/ -x -q -k "credential"`
File diff suppressed because it is too large Load Diff
-992
View File
@@ -1,992 +0,0 @@
# Credential Store Usage Guide
This guide covers how to use the Hive credential store for managing API keys, OAuth2 tokens, and custom credentials in your agents and tools.
## Table of Contents
- [Quick Start](#quick-start)
- [Core Concepts](#core-concepts)
- [Basic Usage](#basic-usage)
- [Template Resolution](#template-resolution)
- [Storage Backends](#storage-backends)
- [Using OAuth2 Provider](#using-oauth2-provider)
- [Implementing Custom Providers](#implementing-custom-providers)
- [Testing with Credentials](#testing-with-credentials)
- [Migration from CredentialManager](#migration-from-credentialmanager)
- [Security Best Practices](#security-best-practices)
---
## Quick Start
```python
from core.framework.credentials import CredentialStore, InMemoryStorage
# Create a store with in-memory storage (for development)
store = CredentialStore(storage=InMemoryStorage())
# Save a simple API key
store.save_api_key("brave_search", "your-api-key-here")
# Retrieve the credential
api_key = store.get("brave_search")
# Use template resolution for HTTP headers
headers = store.resolve_headers({
"X-Subscription-Token": "{{brave_search.api_key}}"
})
# Result: {"X-Subscription-Token": "your-api-key-here"}
```
---
## Core Concepts
### Key-Vault Structure
Credentials are stored as **objects** containing one or more **keys**:
```
brave_search (CredentialObject)
├── api_key: "BSAKxxxxx"
github_oauth (CredentialObject)
├── access_token: "ghp_xxxxx"
├── refresh_token: "ghr_xxxxx"
└── expires_at: 2024-01-15T10:00:00Z
```
### Bipartisan Model
The credential store follows a **bipartisan model**:
- **Store**: Only stores credential values
- **Tools**: Define how credentials are used (headers, query params, etc.)
This separation keeps the store simple and lets each tool specify its exact requirements.
### Components
| Component | Purpose |
|-----------|---------|
| `CredentialStore` | Main orchestrator for all credential operations |
| `CredentialObject` | A credential with one or more keys |
| `CredentialKey` | A single key-value pair with optional expiration |
| `CredentialStorage` | Backend for persisting credentials |
| `CredentialProvider` | Handles credential lifecycle (refresh, validate) |
| `TemplateResolver` | Resolves `{{cred.key}}` patterns |
---
## Basic Usage
### Creating a Credential Store
```python
from core.framework.credentials import (
CredentialStore,
EncryptedFileStorage,
EnvVarStorage,
InMemoryStorage,
)
# Option 1: Encrypted file storage (recommended for production)
store = CredentialStore.with_encrypted_storage("~/.hive/credentials")
# Option 2: Environment variable storage (backward compatible)
store = CredentialStore.with_env_storage({
"brave_search": "BRAVE_SEARCH_API_KEY",
"openai": "OPENAI_API_KEY",
})
# Option 3: In-memory storage (for testing/development)
store = CredentialStore(storage=InMemoryStorage())
# Option 4: Custom storage configuration
storage = EncryptedFileStorage(
base_path="~/.hive/credentials",
key_env_var="HIVE_CREDENTIAL_KEY" # Encryption key from env
)
store = CredentialStore(storage=storage)
```
### Saving Credentials
```python
# Simple API key
store.save_api_key("brave_search", "your-api-key")
# Multi-key credential (e.g., OAuth2)
from core.framework.credentials import CredentialObject, CredentialKey, CredentialType
from pydantic import SecretStr
from datetime import datetime, timedelta, timezone
credential = CredentialObject(
id="github_oauth",
credential_type=CredentialType.OAUTH2,
keys={
"access_token": CredentialKey(
name="access_token",
value=SecretStr("ghp_xxxxxxxxxxxx"),
expires_at=datetime.now(timezone.utc) + timedelta(hours=1)
),
"refresh_token": CredentialKey(
name="refresh_token",
value=SecretStr("ghr_xxxxxxxxxxxx")
),
},
provider_id="oauth2",
auto_refresh=True,
)
store.save_credential(credential)
```
### Retrieving Credentials
```python
# Get the default key value (api_key, access_token, or first key)
api_key = store.get("brave_search")
# Get a specific key
access_token = store.get_key("github_oauth", "access_token")
refresh_token = store.get_key("github_oauth", "refresh_token")
# Get the full credential object
credential = store.get_credential("github_oauth")
if credential:
print(f"Type: {credential.credential_type}")
print(f"Keys: {list(credential.keys.keys())}")
print(f"Auto-refresh: {credential.auto_refresh}")
# Check if credential exists and is available
if store.is_available("brave_search"):
# Use the credential
pass
```
### Deleting Credentials
```python
# Delete a credential
deleted = store.delete_credential("old_api_key")
if deleted:
print("Credential deleted")
```
---
## Template Resolution
The credential store supports template patterns for injecting credentials into HTTP requests.
### Syntax
```
{{credential_id}} -> Returns default key
{{credential_id.key_name}} -> Returns specific key
```
### Resolving Headers
```python
# Define headers with credential templates
header_templates = {
"Authorization": "Bearer {{github_oauth.access_token}}",
"X-API-Key": "{{brave_search.api_key}}",
"X-Custom": "{{custom_cred.token}}"
}
# Resolve to actual values
headers = store.resolve_headers(header_templates)
# Result: {
# "Authorization": "Bearer ghp_xxxxxxxxxxxx",
# "X-API-Key": "BSAKxxxxxxxxxxxx",
# "X-Custom": "actual-token-value"
# }
# Use with httpx/requests
import httpx
response = httpx.get("https://api.example.com/data", headers=headers)
```
### Resolving Query Parameters
```python
params = store.resolve_params({
"api_key": "{{brave_search.api_key}}",
"client_id": "{{oauth_app.client_id}}"
})
```
### Resolving Arbitrary Strings
```python
# Resolve any string containing templates
url = store.resolve("https://api.example.com?key={{api_cred.key}}")
```
### Handling Missing Credentials
```python
# By default, missing credentials raise an error
try:
headers = store.resolve_headers({"Auth": "{{missing.key}}"})
except CredentialNotFoundError as e:
print(f"Missing credential: {e}")
# Use fail_on_missing=False to leave templates unresolved
headers = store.resolve_headers(
{"Auth": "{{missing.key}}"},
fail_on_missing=False
)
# Result: {"Auth": "{{missing.key}}"}
```
---
## Storage Backends
### EncryptedFileStorage (Recommended)
Encrypts credentials at rest using Fernet (AES-128-CBC + HMAC).
```python
from core.framework.credentials import EncryptedFileStorage
# The encryption key is read from HIVE_CREDENTIAL_KEY env var
storage = EncryptedFileStorage("~/.hive/credentials")
# Or provide the key directly (32-byte Fernet key)
storage = EncryptedFileStorage(
base_path="~/.hive/credentials",
encryption_key=b"your-32-byte-fernet-key-here..."
)
```
**Directory structure:**
```
~/.hive/credentials/
├── credentials/
│ ├── brave_search.enc # Encrypted credential JSON
│ └── github_oauth.enc
└── metadata/
└── index.json # Unencrypted index
```
**Generate an encryption key:**
```python
from cryptography.fernet import Fernet
key = Fernet.generate_key()
print(f"HIVE_CREDENTIAL_KEY={key.decode()}")
```
### EnvVarStorage (Backward Compatible)
Reads credentials from environment variables. **Read-only** - cannot save credentials.
```python
from core.framework.credentials import EnvVarStorage
storage = EnvVarStorage(
env_mapping={
"brave_search": "BRAVE_SEARCH_API_KEY",
"openai": "OPENAI_API_KEY",
}
)
# Credentials are read from environment
# export BRAVE_SEARCH_API_KEY=your-key
```
### CompositeStorage (Layered)
Combines multiple storage backends with fallback.
```python
from core.framework.credentials import CompositeStorage, EncryptedFileStorage, EnvVarStorage
storage = CompositeStorage(
primary=EncryptedFileStorage("~/.hive/credentials"),
fallbacks=[
EnvVarStorage({"brave_search": "BRAVE_SEARCH_API_KEY"})
]
)
# Writes go to primary (encrypted files)
# Reads check primary first, then fallbacks (env vars)
```
### HashiCorp Vault Storage
For enterprise deployments with HashiCorp Vault.
```python
from core.framework.credentials.vault import HashiCorpVaultStorage
storage = HashiCorpVaultStorage(
vault_url="https://vault.example.com",
token="hvs.xxxxx", # Or use VAULT_TOKEN env var
mount_point="secret",
path_prefix="hive/credentials"
)
```
---
## Using OAuth2 Provider
The OAuth2 provider handles token lifecycle including automatic refresh.
### Setup
```python
from core.framework.credentials import CredentialStore, InMemoryStorage
from core.framework.credentials.oauth2 import BaseOAuth2Provider, OAuth2Config
# Configure OAuth2
config = OAuth2Config(
token_url="https://oauth.example.com/token",
authorization_url="https://oauth.example.com/authorize", # Optional
client_id="your-client-id",
client_secret="your-client-secret",
default_scopes=["read", "write"],
)
# Create provider
provider = BaseOAuth2Provider(config)
# Create store with provider
store = CredentialStore(
storage=InMemoryStorage(),
providers=[provider],
)
```
### Client Credentials Flow (Server-to-Server)
```python
# Get a token using client credentials
token = provider.client_credentials_grant(scopes=["api.read"])
# Save to store
from core.framework.credentials import CredentialObject, CredentialKey, CredentialType
from pydantic import SecretStr
credential = CredentialObject(
id="service_account",
credential_type=CredentialType.OAUTH2,
keys={
"access_token": CredentialKey(
name="access_token",
value=SecretStr(token.access_token),
expires_at=token.expires_at
),
},
provider_id="oauth2",
auto_refresh=True,
)
store.save_credential(credential)
```
### Refresh Token Flow
```python
# Save credential with refresh token
credential = CredentialObject(
id="user_oauth",
credential_type=CredentialType.OAUTH2,
keys={
"access_token": CredentialKey(
name="access_token",
value=SecretStr("ghp_xxxx"),
expires_at=datetime.now(timezone.utc) + timedelta(hours=1)
),
"refresh_token": CredentialKey(
name="refresh_token",
value=SecretStr("ghr_xxxx")
),
},
provider_id="oauth2",
auto_refresh=True,
)
store.save_credential(credential)
# When you retrieve the credential, it auto-refreshes if expired
token = store.get("user_oauth") # Automatically refreshed if needed
# Or manually refresh
store.refresh_credential("user_oauth")
```
### Token Lifecycle Manager
For more control over token lifecycle:
```python
from core.framework.credentials.oauth2 import TokenLifecycleManager
from datetime import timedelta
manager = TokenLifecycleManager(
credential_id="my_oauth",
provider=provider,
store=store,
refresh_buffer=timedelta(minutes=5), # Refresh 5 min before expiry
)
# Acquire token (refreshes if needed)
token = await manager.acquire_token()
# Use the token
headers = {"Authorization": f"Bearer {token.access_token}"}
```
---
## Implementing Custom Providers
Custom providers let you integrate with proprietary authentication systems.
### Provider Interface
```python
from abc import ABC, abstractmethod
from typing import List
from core.framework.credentials import CredentialObject, CredentialType
class CredentialProvider(ABC):
"""Abstract base for credential providers."""
@property
@abstractmethod
def provider_id(self) -> str:
"""Unique identifier for this provider."""
pass
@property
@abstractmethod
def supported_types(self) -> List[CredentialType]:
"""Credential types this provider handles."""
pass
@abstractmethod
def refresh(self, credential: CredentialObject) -> CredentialObject:
"""Refresh the credential and return updated version."""
pass
@abstractmethod
def validate(self, credential: CredentialObject) -> bool:
"""Check if credential is still valid."""
pass
def should_refresh(self, credential: CredentialObject) -> bool:
"""Determine if credential needs refresh (optional override)."""
# Default: check expiration with 5-minute buffer
...
def revoke(self, credential: CredentialObject) -> bool:
"""Revoke credential (optional, default returns False)."""
return False
```
### Example: Custom API Provider
```python
from datetime import datetime, timedelta, timezone
from typing import List
from pydantic import SecretStr
from core.framework.credentials import (
CredentialKey,
CredentialObject,
CredentialProvider,
CredentialRefreshError,
CredentialType,
)
class MyCustomProvider(CredentialProvider):
"""
Custom provider for MyService API tokens.
MyService issues tokens that expire after 24 hours and can be
refreshed using the original API key.
"""
def __init__(self, base_url: str = "https://api.myservice.com"):
self.base_url = base_url
@property
def provider_id(self) -> str:
return "myservice"
@property
def supported_types(self) -> List[CredentialType]:
return [CredentialType.CUSTOM]
def refresh(self, credential: CredentialObject) -> CredentialObject:
"""Refresh the access token using the API key."""
import httpx
api_key = credential.get_key("api_key")
if not api_key:
raise CredentialRefreshError(
f"Credential '{credential.id}' missing api_key for refresh"
)
# Call MyService API to get new token
try:
response = httpx.post(
f"{self.base_url}/auth/token",
headers={"X-API-Key": api_key},
timeout=30,
)
response.raise_for_status()
data = response.json()
except httpx.HTTPError as e:
raise CredentialRefreshError(f"Token refresh failed: {e}") from e
# Update credential with new token
credential.set_key(
"access_token",
data["access_token"],
expires_at=datetime.now(timezone.utc) + timedelta(hours=24),
)
credential.last_refreshed = datetime.now(timezone.utc)
return credential
def validate(self, credential: CredentialObject) -> bool:
"""Check if access_token exists and is not expired."""
access_key = credential.keys.get("access_token")
if access_key is None:
return False
return not access_key.is_expired
def should_refresh(self, credential: CredentialObject) -> bool:
"""Refresh if token expires within 1 hour."""
access_key = credential.keys.get("access_token")
if access_key is None or access_key.expires_at is None:
return False
buffer = timedelta(hours=1)
return datetime.now(timezone.utc) >= (access_key.expires_at - buffer)
def revoke(self, credential: CredentialObject) -> bool:
"""Revoke the access token."""
import httpx
access_token = credential.get_key("access_token")
if not access_token:
return False
try:
response = httpx.post(
f"{self.base_url}/auth/revoke",
headers={"Authorization": f"Bearer {access_token}"},
timeout=30,
)
return response.status_code == 200
except httpx.HTTPError:
return False
```
### Registering Custom Providers
```python
from core.framework.credentials import CredentialStore, InMemoryStorage
# Create store with custom provider
provider = MyCustomProvider(base_url="https://api.myservice.com")
store = CredentialStore(
storage=InMemoryStorage(),
providers=[provider],
)
# Or register after creation
store.register_provider(provider)
# Save a credential that uses this provider
credential = CredentialObject(
id="myservice_prod",
credential_type=CredentialType.CUSTOM,
keys={
"api_key": CredentialKey(
name="api_key",
value=SecretStr("my-permanent-api-key")
),
},
provider_id="myservice", # Links to our custom provider
auto_refresh=True,
)
store.save_credential(credential)
# The store will use MyCustomProvider for refresh/validate
token = store.get("myservice_prod") # Auto-refreshes if needed
```
### Example: Extending OAuth2 for a Specific Service
```python
from core.framework.credentials.oauth2 import BaseOAuth2Provider, OAuth2Config, OAuth2Token
class GitHubOAuth2Provider(BaseOAuth2Provider):
"""GitHub-specific OAuth2 provider with custom scopes handling."""
def __init__(self, client_id: str, client_secret: str):
config = OAuth2Config(
token_url="https://github.com/login/oauth/access_token",
authorization_url="https://github.com/login/oauth/authorize",
client_id=client_id,
client_secret=client_secret,
default_scopes=["repo", "read:user"],
)
super().__init__(config)
@property
def provider_id(self) -> str:
return "github_oauth2"
def _parse_token_response(self, response_data: dict) -> OAuth2Token:
"""GitHub returns scope as space-separated string."""
token = super()._parse_token_response(response_data)
# GitHub-specific: tokens don't expire unless revoked
# But we set a reasonable refresh interval
if token.expires_at is None:
token.expires_at = datetime.now(timezone.utc) + timedelta(days=30)
return token
def validate(self, credential: CredentialObject) -> bool:
"""Validate by making a test API call to GitHub."""
import httpx
access_token = credential.get_key("access_token")
if not access_token:
return False
try:
response = httpx.get(
"https://api.github.com/user",
headers={
"Authorization": f"Bearer {access_token}",
"Accept": "application/vnd.github+json",
},
timeout=10,
)
return response.status_code == 200
except httpx.HTTPError:
return False
```
---
## Testing with Credentials
### Using the Testing Factory
```python
from core.framework.credentials import CredentialStore
# Create a test store with mock credentials
store = CredentialStore.for_testing({
"brave_search": {"api_key": "test-brave-key"},
"github_oauth": {
"access_token": "test-github-token",
"refresh_token": "test-refresh-token",
},
})
# Use in tests
def test_my_tool():
api_key = store.get("brave_search")
assert api_key == "test-brave-key"
headers = store.resolve_headers({
"Authorization": "Bearer {{github_oauth.access_token}}"
})
assert headers["Authorization"] == "Bearer test-github-token"
```
### Using with CredentialStoreAdapter (Backward Compatible)
```python
from aden_tools.credentials import CredentialStoreAdapter
# For testing existing tools
credentials = CredentialStoreAdapter.for_testing({
"brave_search": "test-key",
"openai": "test-openai-key",
})
# Existing API works
assert credentials.get("brave_search") == "test-key"
credentials.validate_for_tools(["web_search"]) # No error
```
### Mocking in Unit Tests
```python
import pytest
from unittest.mock import MagicMock, patch
def test_tool_with_mocked_store():
# Create a mock store
mock_store = MagicMock()
mock_store.get.return_value = "mocked-api-key"
mock_store.resolve_headers.return_value = {
"Authorization": "Bearer mocked-token"
}
# Inject into your tool
with patch("my_tool.credential_store", mock_store):
result = my_tool.make_api_call()
mock_store.get.assert_called_once_with("api_credential")
```
---
## Migration from CredentialManager
If you're using the existing `CredentialManager`, migration is straightforward.
### Option 1: Use the Adapter (No Code Changes)
```python
# Before
from aden_tools.credentials import CredentialManager
credentials = CredentialManager()
# After - using adapter with new store backend
from aden_tools.credentials import CredentialStoreAdapter
from core.framework.credentials import CredentialStore
store = CredentialStore.with_encrypted_storage("~/.hive/credentials")
credentials = CredentialStoreAdapter(store)
# All existing code works unchanged
api_key = credentials.get("brave_search")
credentials.validate_for_tools(["web_search"])
```
### Option 2: Use Environment Storage (Identical Behavior)
```python
from aden_tools.credentials import CredentialStoreAdapter
# Creates adapter backed by environment variables
credentials = CredentialStoreAdapter.with_env_storage()
# Behaves exactly like original CredentialManager
api_key = credentials.get("brave_search")
```
### Option 3: Gradual Migration
```python
from aden_tools.credentials import CredentialStoreAdapter
from core.framework.credentials import CredentialStore, CompositeStorage, EncryptedFileStorage, EnvVarStorage
# Use encrypted storage as primary, env vars as fallback
storage = CompositeStorage(
primary=EncryptedFileStorage("~/.hive/credentials"),
fallbacks=[EnvVarStorage({"brave_search": "BRAVE_SEARCH_API_KEY"})]
)
store = CredentialStore(storage=storage)
credentials = CredentialStoreAdapter(store)
# New credentials go to encrypted storage
# Old env var credentials still work as fallback
```
---
## Security Best Practices
### 1. Use Encrypted Storage in Production
```python
# Always use EncryptedFileStorage for production
store = CredentialStore.with_encrypted_storage("~/.hive/credentials")
```
### 2. Protect the Encryption Key
```bash
# Set encryption key as environment variable
export HIVE_CREDENTIAL_KEY="your-fernet-key"
# Or use a secrets manager
export HIVE_CREDENTIAL_KEY=$(vault kv get -field=key secret/hive/credential-key)
```
### 3. Use SecretStr for Values
```python
from pydantic import SecretStr
# SecretStr prevents accidental logging
key = CredentialKey(
name="api_key",
value=SecretStr("sensitive-value") # Won't appear in logs
)
# Explicitly extract when needed
actual_value = key.get_secret_value()
```
### 4. Set Appropriate Expiration
```python
# Always set expiration for tokens
credential.set_key(
"access_token",
token_value,
expires_at=datetime.now(timezone.utc) + timedelta(hours=1)
)
```
### 5. Enable Auto-Refresh
```python
credential = CredentialObject(
id="my_oauth",
auto_refresh=True, # Automatically refresh before expiry
provider_id="oauth2",
...
)
```
### 6. Validate Before Use
```python
# Check credential validity before making API calls
if not store.is_available("api_credential"):
raise RuntimeError("Required credential not available")
# Or use validation
errors = store.validate_for_usage("api_credential")
if errors:
raise RuntimeError(f"Credential validation failed: {errors}")
```
### 7. Use Template Resolution
```python
# Don't interpolate secrets manually
# Bad:
headers = {"Authorization": f"Bearer {store.get('token')}"}
# Good - uses template resolution which handles errors gracefully:
headers = store.resolve_headers({
"Authorization": "Bearer {{my_oauth.access_token}}"
})
```
---
## API Reference
### CredentialStore
| Method | Description |
|--------|-------------|
| `get(credential_id)` | Get default key value |
| `get_key(credential_id, key_name)` | Get specific key value |
| `get_credential(credential_id)` | Get full credential object |
| `save_credential(credential)` | Save credential to storage |
| `save_api_key(id, value)` | Convenience for simple API keys |
| `delete_credential(credential_id)` | Delete a credential |
| `is_available(credential_id)` | Check if credential exists and has value |
| `resolve(template)` | Resolve template string |
| `resolve_headers(headers)` | Resolve templates in headers dict |
| `resolve_params(params)` | Resolve templates in params dict |
| `refresh_credential(credential_id)` | Manually refresh a credential |
| `register_provider(provider)` | Register a custom provider |
| `for_testing(credentials)` | Create test store with mock data |
| `with_encrypted_storage(path)` | Create store with encrypted files |
| `with_env_storage(mapping)` | Create store with env var backend |
### CredentialObject
| Property/Method | Description |
|-----------------|-------------|
| `id` | Unique identifier |
| `credential_type` | Type (API_KEY, OAUTH2, etc.) |
| `keys` | Dict of CredentialKey objects |
| `get_key(name)` | Get key value by name |
| `set_key(name, value, ...)` | Set or update a key |
| `has_key(name)` | Check if key exists |
| `get_default_key()` | Get default key value |
| `needs_refresh` | True if any key is expired |
| `is_valid` | True if has valid, non-expired key |
| `auto_refresh` | Whether to auto-refresh |
| `provider_id` | ID of provider for lifecycle |
### CredentialProvider
| Method | Description |
|--------|-------------|
| `provider_id` | Unique identifier (property) |
| `supported_types` | List of supported CredentialTypes (property) |
| `refresh(credential)` | Refresh and return updated credential |
| `validate(credential)` | Check if credential is valid |
| `should_refresh(credential)` | Check if refresh is needed |
| `revoke(credential)` | Revoke credential (optional) |
---
## Troubleshooting
### "Unknown credential" Error
```python
# Error: KeyError: "Unknown credential 'my_cred'"
# Solution: Check if credential exists
if store.get_credential("my_cred") is None:
print("Credential not found - need to save it first")
```
### "Credential not found" in Templates
```python
# Error: CredentialNotFoundError when resolving templates
# Solution 1: Ensure credential is saved
store.save_api_key("my_cred", "value")
# Solution 2: Use fail_on_missing=False
headers = store.resolve_headers(templates, fail_on_missing=False)
```
### Encryption Key Issues
```python
# Error: "Failed to decrypt credential"
# Solution: Ensure HIVE_CREDENTIAL_KEY matches what was used to encrypt
# If key is lost, credentials must be re-created
```
### Provider Not Found
```python
# Warning: "No provider found for credential 'x'"
# Solution: Register the provider or set provider_id=None for static credentials
store.register_provider(MyProvider())
# Or use static provider (default)
credential.provider_id = "static" # or None
```
---
## Further Reading
- [Credential Store Design Document](credential-store-design.md)
- [OAuth2 RFC 6749](https://datatracker.ietf.org/doc/html/rfc6749)
- [Fernet Encryption](https://cryptography.io/en/latest/fernet/)
-552
View File
@@ -1,552 +0,0 @@
# Credential System: Complete Code Path Analysis
## Architecture Overview
```
┌──────────────┐
│ AgentRunner │ runner.py:_validate_credentials()
└──────┬───────┘
┌──────▼───────┐
│ validation │ validate_agent_credentials()
│ (2-phase) │ Phase 1: presence Phase 2: health check
└──────┬───────┘
┌─────────────▼─────────────┐
│ CredentialStore │ store.py
│ (cache + provider mgmt) │
└─────────────┬─────────────┘
┌───────────────────┼───────────────────┐
│ │ │
┌──────▼──────┐ ┌──────▼──────┐ ┌───────▼───────┐
│ EnvVarStorage│ │ Encrypted │ │ AdenCached │
│ (primary) │ │ FileStorage │ │ Storage │
└─────────────┘ │ (fallback) │ │ (Aden sync) │
└─────────────┘ └───────┬───────┘
┌───────▼───────┐
│AdenSyncProvider│
│+ AdenClient │
└───────────────┘
```
### Key Files
| Layer | File | Purpose |
|-------|------|---------|
| Models | `core/framework/credentials/models.py` | `CredentialObject`, `CredentialKey`, exception hierarchy |
| Storage | `core/framework/credentials/storage.py` | `EncryptedFileStorage`, `EnvVarStorage`, `CompositeStorage` |
| Store | `core/framework/credentials/store.py` | `CredentialStore` — cache, providers, refresh |
| Validation | `core/framework/credentials/validation.py` | `validate_agent_credentials()` — two-phase pre-flight check |
| Setup | `core/framework/credentials/setup.py` | `CredentialSetupSession` — interactive credential collection |
| Aden client | `core/framework/credentials/aden/client.py` | `AdenCredentialClient` — HTTP calls to api.adenhq.com |
| Aden provider | `core/framework/credentials/aden/provider.py` | `AdenSyncProvider` — refresh, sync, fetch |
| Aden storage | `core/framework/credentials/aden/storage.py` | `AdenCachedStorage` — local cache + Aden fallback |
| Specs | `tools/src/aden_tools/credentials/` | `CredentialSpec` per integration (env_var, health check, etc.) |
| Runner | `core/framework/runner/runner.py` | `_validate_credentials()` — agent startup gate |
| TUI | `core/framework/tui/screens/credential_setup.py` | `CredentialSetupScreen` — modal credential form |
| TUI app | `core/framework/tui/app.py` | `_show_credential_setup()`, `_load_and_switch_agent()` |
### Exception Hierarchy
```
CredentialError ← base (caught by runner + TUI)
├── CredentialDecryptionError ← corrupted/wrong-key .enc files
├── CredentialKeyNotFoundError ← key name not in credential
├── CredentialNotFoundError ← credential ID not found
├── CredentialRefreshError ← refresh failed (e.g., revoked OAuth)
└── CredentialValidationError ← schema/format invalid
```
---
## Scenario 1: User Supplies Correct Credential
### Flow
```
AgentRunner._setup()
→ _ensure_credential_key_env() # validation.py:16
│ Loads HIVE_CREDENTIAL_KEY, ADEN_API_KEY from shell config into os.environ
→ _validate_credentials() # runner.py:418
→ validate_agent_credentials(nodes) # validation.py:94
│ Phase 0: Aden pre-sync (if ADEN_API_KEY set)
│ → _presync_aden_tokens() # validation.py:50
│ → CredentialStore.with_aden_sync(auto_sync=True)
│ → For each aden_supported spec: get_key() → set os.environ
│ Build store:
│ EnvVarStorage (primary) + EncryptedFileStorage (fallback if HIVE_CREDENTIAL_KEY set)
│ Phase 1: Presence check
│ → store.is_available(cred_id)
│ → EnvVarStorage.load() → os.environ[env_var] → CredentialObject ✓
│ Result: NOT in missing list
│ Phase 2: Health check (if spec.health_check_endpoint set)
│ → check_credential_health(cred_name, value)
│ e.g., Anthropic: POST /v1/messages → 400 (key valid, request malformed) → valid=True
│ e.g., Brave: GET /search?q=test → 200 → valid=True
│ Result: NOT in invalid list
│ errors = [] → returns normally ✓
```
### What Happens
- Validation passes silently
- Agent loads and runs
- No files written, no user-visible output
- `CredentialStore._cache` populated (5-min TTL)
---
## Scenario 2: User Supplies Wrong Credential
### Flow
```
validate_agent_credentials(nodes)
│ Phase 1: Presence check
│ → store.is_available("anthropic")
│ → EnvVarStorage.load() → os.environ["ANTHROPIC_API_KEY"] = "wrong-key"
│ → Returns CredentialObject ✓ (value exists, content not validated)
│ Result: passes presence check, added to to_verify list
│ Phase 2: Health check
│ → check_credential_health("anthropic", credential_object)
│ → AnthropicHealthChecker: POST /v1/messages with x-api-key: "wrong-key"
│ → Response: 401 Unauthorized
│ → HealthCheckResult(valid=False, message="API key is invalid")
│ → Added to invalid list, cred_name added to failed_cred_names
│ CredentialError raised:
│ "Invalid or expired credentials:
│ ANTHROPIC_API_KEY for event_loop nodes — Anthropic API key is invalid
│ Get a new key at: https://console.anthropic.com/settings/keys"
│ exc.failed_cred_names = ["anthropic"]
```
### TUI Path (non-interactive)
```
_load_and_switch_agent() # app.py:356
except CredentialError as e: # app.py:382
→ _show_credential_setup(agent_path, e) # app.py:404
→ build_setup_session_from_error(e) # validation.py:253
→ failed_cred_names = ["anthropic"]
→ Creates MissingCredential for anthropic
→ push_screen(CredentialSetupScreen)
```
### CLI Path (interactive with TTY)
```
_validate_credentials() # runner.py:418
except CredentialError as e: # runner.py:440
→ print(str(e), file=sys.stderr)
→ session = build_setup_session_from_error(e)
→ session.run_interactive() # Terminal prompts
→ validate_agent_credentials(nodes) # Re-validate
```
### What User Sees
- TUI: Credential setup modal with the invalid credential's input field
- CLI: Error message printed, interactive prompts
### Silent Failure Risk
If `check_credential_health()` itself throws (network timeout, DNS failure, import error),
it's caught at `validation.py:231`:
```python
except Exception as exc:
logger.debug("Health check for %s failed: %s", cred_name, exc)
```
The credential is NOT added to `invalid`. **Agent starts with a bad key.** Only `logger.debug`
records the issue.
---
## Scenario 3: Credential Expired But Can Be Refreshed
Applies to OAuth2 credentials (Google, HubSpot, etc.) managed via AdenSyncProvider.
### Flow: Token Refresh During Runtime
```
CredentialStore.get_credential(cred_id, refresh_if_needed=True) # store.py:176
│ Check cache → cached credential found
│ → _should_refresh(cached) # store.py:442
│ → AdenSyncProvider.should_refresh(credential) # provider.py:238
│ → access_key = credential.keys["access_token"]
│ → datetime.now(UTC) >= (expires_at - 5min buffer)
│ → Returns True (within refresh window)
│ → _refresh_credential(cached) # store.py:456
│ → AdenSyncProvider.refresh(credential) # provider.py:151
│ → client.request_refresh(credential.id) # client.py:356
│ → POST /v1/credentials/{id}/refresh
│ → Server refreshes OAuth token, returns new access_token
│ → _update_credential_from_aden(credential, response)
│ → Updates access_token value + expires_at
│ → storage.save(refreshed) # Writes new .enc file
│ → _add_to_cache(refreshed) # Updates in-memory cache
│ → Returns refreshed credential ✓
```
### Flow: Expired Token Caught During Validation
```
validate_agent_credentials(nodes)
│ Phase 0: _presync_aden_tokens()
│ → CredentialStore.with_aden_sync(auto_sync=True)
│ → provider.sync_all() fetches fresh tokens from Aden
│ → Fresh token set in os.environ ✓
│ Phase 2: Health check with fresh token → valid=True ✓
```
### What Happens
- Refresh is transparent to the user
- New token written to `~/.hive/credentials/credentials/{id}.enc`
- In-memory cache updated
- Logged: `INFO: Refreshed credential '{id}' via Aden server`
---
## Scenario 4: Credential Expired and Cannot Be Refreshed
OAuth refresh token is revoked (user disconnected integration on hive.adenhq.com, or
the refresh token itself expired).
### Flow: Refresh Attempt
```
AdenSyncProvider.refresh(credential) # provider.py:151
→ client.request_refresh(credential.id) # client.py:356
→ POST /v1/credentials/{id}/refresh
→ Response: 400 {"error": "refresh_failed",
│ "requires_reauthorization": true,
│ "reauthorization_url": "https://..."}
→ AdenRefreshError raised # client.py:297
except AdenRefreshError as e: # provider.py:186
→ logger.error("Aden refresh failed for '{id}': ...")
→ raise CredentialRefreshError(
"Integration '{id}' requires re-authorization. Visit: ..."
)
```
### What CredentialStore Does
```
CredentialStore._refresh_credential(credential) # store.py:456
except CredentialRefreshError as e: # store.py:474
→ logger.error("Failed to refresh credential '{id}': ...")
→ return credential ← RETURNS STALE/EXPIRED CREDENTIAL!
```
**BUG: Silent failure.** The store returns the expired credential without raising.
The caller gets an expired token. Downstream API calls fail with 401.
### During Validation
If validation runs health check on the expired token:
```
check_credential_health() → 401 → valid=False
→ Added to invalid list → CredentialError raised
→ TUI shows credential setup screen
```
### Gap: Token Expires After Validation
If the token expires **during agent execution** (after validation passed):
- Refresh fails silently (returns stale credential)
- Tool call gets 401 from downstream API
- LLM sees tool error, no framework-level recovery
---
## Scenario 5: Credential Store File Sabotaged (Wrong Content)
File `~/.hive/credentials/credentials/{id}.enc` replaced with valid Fernet-encrypted
content encoding wrong JSON (e.g., `{"bad": "data"}`).
### Flow
```
EncryptedFileStorage.load(credential_id) # storage.py:193
→ fernet.decrypt(encrypted) # Succeeds (valid Fernet)
→ json.loads(decrypted) # Succeeds (valid JSON)
→ _deserialize_credential(data) # storage.py:252
→ CredentialObject.model_validate({"bad": "data"})
```
### Sub-case A: Missing `id` field
```
CredentialObject.model_validate({"bad": "data"})
→ Pydantic ValidationError: "id - Field required"
→ NOT caught by EncryptedFileStorage's try/except (only covers decrypt + json.loads)
→ Propagates up uncaught
```
**TUI**: Caught by generic `except Exception` in `_load_and_switch_agent()` (app.py:389):
```
self.notify("Failed to load agent: 1 validation error for CredentialObject...", severity="error")
```
User sees generic error notification. NOT a credential setup screen. **Not actionable.**
**CLI**: Unhandled traceback.
### Sub-case B: Valid `id` but wrong/empty keys
```
CredentialObject.model_validate({"id": "my_cred", "keys": {}})
→ Valid CredentialObject with keys={} (Pydantic extra="allow", keys defaults to {})
→ store.is_available() → get_credential() returns CredentialObject
→ But get() / get_key() returns None → is_available returns False
→ Treated as "missing" credential
```
User sees credential setup screen as if the credential was never configured.
**The actual cause (sabotaged file) is hidden.**
---
## Scenario 6: Credential Store File Corrupted (Binary Garbage)
File `~/.hive/credentials/credentials/{id}.enc` contains random binary data.
### Flow
```
EncryptedFileStorage.load(credential_id) # storage.py:193
→ fernet.decrypt(binary_garbage)
→ Raises cryptography.fernet.InvalidToken
→ Caught by except Exception: # storage.py:210
→ raise CredentialDecryptionError(
"Failed to decrypt credential '{id}': InvalidToken"
)
```
### Propagation
```
CredentialDecryptionError (subclass of CredentialError)
→ CompositeStorage.load(): NOT caught → propagates
→ CredentialStore.get_credential(): NOT caught → propagates
→ validate_agent_credentials() → propagates out entirely
```
**TUI** (app.py:382):
```python
except CredentialError as e: # CATCHES CredentialDecryptionError
self._show_credential_setup(str(agent_path), credential_error=e)
```
Shows credential setup screen! But `CredentialDecryptionError` does NOT have
`failed_cred_names` attribute → `getattr(e, "failed_cred_names", [])` returns `[]`
→ session falls back to `from_agent_path()` detection.
User sees credential setup screen as if credential is missing.
**Corruption is hidden.** Re-entering the credential overwrites the corrupted file.
### CompositeStorage Bug
If `CompositeStorage(primary=EnvVarStorage, fallbacks=[EncryptedFileStorage])` is used,
the storage tries primary first. But if `EncryptedFileStorage` is a fallback and
the .enc file is corrupted:
```
CompositeStorage.load()
→ primary (EnvVarStorage) → env var IS set → returns CredentialObject ✓
```
The corrupted fallback is never touched. **This case works fine.**
But if the storage order is reversed (encrypted primary, env fallback):
```
CompositeStorage.load()
→ primary (EncryptedFileStorage) → CredentialDecryptionError
→ NOT caught → propagates ← BUG: fallback never tried
```
The exception from primary propagates BEFORE checking the fallback.
**A corrupted .enc file blocks access even when the env var has a valid value.**
---
## Scenario 7: ADEN_API_KEY Set But Vendor OAuth Not Authorized
User has valid `ADEN_API_KEY`. Agent needs HubSpot/Google. User has NOT connected
that integration on hive.adenhq.com.
### Flow
```
validate_agent_credentials(nodes)
│ Phase 0: _presync_aden_tokens()
│ → CredentialStore.with_aden_sync(auto_sync=True)
│ → provider.sync_all(store)
│ → client.list_integrations() # GET /v1/credentials
│ → HubSpot NOT in response (never connected)
│ → Only connected integrations synced
│ → For hubspot spec: get_key("hubspot", "access_token")
│ → AdenCachedStorage.load("hubspot")
│ → _provider_index.get("hubspot") → None (not synced)
│ → _load_by_id("hubspot")
│ → local: None (not cached)
│ → aden: fetch_from_aden("hubspot")
│ → GET /v1/credentials/hubspot → 404
│ → AdenNotFoundError caught → returns None
│ → Returns None
│ → get_key returns None
│ → os.environ["HUBSPOT_ACCESS_TOKEN"] NOT set
│ Phase 1: Presence check
│ → _check_credential(hubspot_spec, "hubspot", "hubspot tools")
│ → store.is_available("hubspot") → False
│ → has_aden_key=True, aden_supported=True, direct_api_key_supported=False
│ → Goes into aden_not_connected list (NOT failed_cred_names)
│ CredentialError raised:
│ "Aden integrations not connected (ADEN_API_KEY is set but OAuth tokens unavailable):
│ HUBSPOT_ACCESS_TOKEN for hubspot tools
│ Connect this integration at hive.adenhq.com first."
│ exc.failed_cred_names = [] ← empty!
```
### TUI Behavior
```
_show_credential_setup(agent_path, credential_error=e)
→ build_setup_session_from_error(e)
→ failed_cred_names = [] → falls back to from_agent_path()
→ detect_missing_credentials_from_nodes() finds hubspot missing
→ session.missing = [MissingCredential(hubspot, aden_supported=True, ...)]
→ NOT empty → CredentialSetupScreen pushed
```
Setup screen shows ADEN_API_KEY input (already set). User clicks "Save & Continue":
```
_save_credentials()
→ ADEN_API_KEY already in env → configured += 1
→ _sync_aden_credentials()
→ provider.sync_all() → hubspot still not connected → synced=0
→ Notification: "No active integrations found in Aden."
→ For hubspot: store.is_available("hubspot") → False
→ Notification: "hubspot (id='hubspot') not found in Aden."
→ configured > 0 → dismiss(True)
```
TUI retries `_do_load_agent()` → validation fails again → **LOOP.**
### What User Sees
1. Setup screen appears, ADEN_API_KEY field shown
2. User clicks Save
3. Warning: "hubspot not found in Aden. Connect this integration at hive.adenhq.com first."
4. Screen dismisses (configured=1 from ADEN_API_KEY)
5. Agent reload fails → setup screen appears again
6. Repeat forever
### Root Cause
`configured += 1` fires when ADEN_API_KEY is saved, even though the actual needed
credential (hubspot OAuth token) was NOT obtained. The screen dismisses with "success"
but the agent still can't load.
---
## Known Silent Failure Points
| # | Location | What Happens | Risk |
|---|----------|-------------|------|
| 1 | `validation.py:231` | `check_credential_health()` throws → `logger.debug()` → credential treated as valid | Agent starts with bad key |
| 2 | `store.py:474-476` | `CredentialRefreshError` caught → returns stale credential | Tool calls fail with 401 at runtime |
| 3 | `store.py:706-708` | `with_aden_sync()` catches all Exception → falls back to local-only store silently | Aden sync failure invisible |
| 4 | `provider.py:312-313` | Individual integration sync fails → `logger.warning()` → skipped | Integration silently missing |
| 5 | `credential_setup.py:262-263` | `_persist_to_local_store()``except Exception: pass` | Credential lost on restart |
| 6 | `storage.py:489-501` | `CompositeStorage.load()` doesn't catch primary storage exceptions | Corrupted .enc blocks env var fallback |
| 7 | `validation.py:63-65` | `_presync_aden_tokens()` catches all Exception → `logger.warning()` | Aden tokens not refreshed, stale values used |
---
## Storage Priority Order
### During Validation (`validate_agent_credentials`)
```
1. os.environ (via EnvVarStorage) ← WINS if set
2. ~/.hive/credentials/credentials/*.enc ← fallback (only if HIVE_CREDENTIAL_KEY set)
```
### During Runtime (`CredentialStoreAdapter.default()`)
```
1. EncryptedFileStorage ← primary (if HIVE_CREDENTIAL_KEY set)
2. EnvVarStorage ← fallback
3. AdenSyncProvider ← if ADEN_API_KEY set, auto-refresh on access
```
**Note: validation and runtime use DIFFERENT storage priority orders.** Validation
prefers env vars; runtime prefers encrypted store. This means a credential can pass
validation (from env) but fail at runtime (encrypted store has stale value and env
var was only set in the validation process, not persisted).
### During TUI Credential Setup (`_sync_aden_credentials`)
```
1. AdenSyncProvider.sync_all() ← fetches from Aden API
2. AdenCachedStorage ← local encrypted cache
(no EnvVarStorage in this path)
```
---
## File Locations on Disk
```
~/.hive/
credentials/
credentials/ # EncryptedFileStorage base
{credential_id}.enc # Fernet-encrypted JSON
key.txt # HIVE_CREDENTIAL_KEY (generated if missing)
configuration.json # Global config
```
### .enc File Format (decrypted)
```json
{
"id": "hubspot",
"credential_type": "oauth2",
"keys": {
"access_token": {
"name": "access_token",
"value": "ya29.a0ARrdaM...",
"expires_at": "2025-01-15T12:00:00+00:00"
},
"_aden_managed": {
"name": "_aden_managed",
"value": "true"
},
"_integration_type": {
"name": "_integration_type",
"value": "hubspot"
}
},
"provider_id": "aden_sync",
"auto_refresh": true
}
```
The `_integration_type` key is used by `AdenCachedStorage._index_provider()` to map
provider names (e.g., "hubspot") to hash-based credential IDs from Aden.
+12 -3
View File
@@ -74,8 +74,9 @@ The setup script performs these actions:
1. Checks Python version (3.11+)
2. Installs `framework` package from `/core` (editable mode)
3. Installs `aden_tools` package from `/tools` (editable mode)
4. Fixes package compatibility (upgrades openai for litellm)
5. Verifies all installations
4. Prompts for a default LLM provider, including Hive LLM and OpenRouter
5. Fixes package compatibility (upgrades openai for litellm)
6. Verifies all installations
### API Keys (Optional)
@@ -85,6 +86,8 @@ For running agents with real LLMs:
# Add to your shell profile (~/.bashrc, ~/.zshrc, etc.)
export ANTHROPIC_API_KEY="your-key-here"
export OPENAI_API_KEY="your-key-here" # Optional
export OPENROUTER_API_KEY="your-key-here" # Optional, for OpenRouter models
export HIVE_API_KEY="your-key-here" # Optional, for Hive LLM
export BRAVE_SEARCH_API_KEY="your-key-here" # Optional, for web search tool
```
@@ -92,8 +95,12 @@ Get API keys:
- **Anthropic**: [console.anthropic.com](https://console.anthropic.com/)
- **OpenAI**: [platform.openai.com](https://platform.openai.com/)
- **OpenRouter**: [openrouter.ai/keys](https://openrouter.ai/keys)
- **Hive LLM**: [Hive Discord](https://discord.com/invite/hQdU7QDkgR)
- **Brave Search**: [brave.com/search/api](https://brave.com/search/api/)
For OpenRouter and Hive LLM configuration snippets, see [configuration.md](./configuration.md).
### Install Claude Code Skills
```bash
@@ -177,7 +184,7 @@ hive/ # Repository root
│ │ ├── builder/ # Agent builder utilities
│ │ ├── credentials/ # Credential management
│ │ ├── graph/ # GraphExecutor - executes node graphs
│ │ ├── llm/ # LLM provider integrations (Anthropic, OpenAI, etc.)
│ │ ├── llm/ # LLM provider integrations (Anthropic, OpenAI, OpenRouter, Hive, etc.)
│ │ ├── mcp/ # MCP server integration
│ │ ├── runner/ # AgentRunner - loads and runs agents
| | ├── observability/ # Structured logging - human-readable and machine-parseable tracing
@@ -633,6 +640,8 @@ def my_custom_tool(param1: str, param2: int) -> Dict[str, Any]:
# Add to your shell profile (~/.bashrc, ~/.zshrc, etc.)
export ANTHROPIC_API_KEY="your-key-here"
export OPENAI_API_KEY="your-key-here"
export OPENROUTER_API_KEY="your-key-here"
export HIVE_API_KEY="your-key-here"
export BRAVE_SEARCH_API_KEY="your-key-here"
# Or create .env file (not committed to git)
+4 -30
View File
@@ -217,36 +217,6 @@ Follow the prompts to:
This step establishes the core concepts and rules needed before building an agent.
### 4. Apply Agent Patterns
```
claude> pattern guidance
```
Follow the prompts to:
1. Apply best-practice agent design patterns
2. Add pause/resume flows for multi-turn interactions
3. Improve robustness with routing, fallbacks, and retries
4. Avoid common anti-patterns during agent construction
This step helps optimize agent design before final testing.
### 5. Test Your Agent
```
claude> test workflow
```
Follow the prompts to:
1. Generate test guidelines for constraints and success criteria
2. Write agent tests directly under `exports/{agent}/tests/`
3. Run goal-based evaluation tests
4. Debug failing tests and iterate on agent improvements
This step verifies that the agent meets its goals before production use.
## Troubleshooting
### "externally-managed-environment" error (PEP 668)
@@ -515,8 +485,12 @@ Add to `.vscode/settings.json`:
```bash
export ANTHROPIC_API_KEY="sk-ant-..."
export OPENROUTER_API_KEY="your-openrouter-key" # Optional
export HIVE_API_KEY="your-hive-key" # Optional
```
Quickstart also supports selecting OpenRouter and Hive LLM interactively. See [configuration.md](./configuration.md) for the full configuration examples.
### Optional Configuration
```bash
+11 -3
View File
@@ -17,7 +17,7 @@ The fastest way to get started:
```bash
# 1. Clone the repository
git clone https://github.com/adenhq/hive.git
git clone https://github.com/aden-hive/hive.git
cd hive
# 2. Run automated setup
@@ -31,7 +31,7 @@ uv run python -c "import framework; import aden_tools; print('✓ Setup complete
```powershell
# 1. Clone the repository
git clone https://github.com/adenhq/hive.git
git clone https://github.com/aden-hive/hive.git
cd hive
# 2. Run automated setup
@@ -166,6 +166,8 @@ For running agents with real LLMs:
# Add to your shell profile (~/.bashrc, ~/.zshrc, etc.)
export ANTHROPIC_API_KEY="your-key-here"
export OPENAI_API_KEY="your-key-here" # Optional
export OPENROUTER_API_KEY="your-key-here" # Optional, for OpenRouter models
export HIVE_API_KEY="your-key-here" # Optional, for Hive LLM
export BRAVE_SEARCH_API_KEY="your-key-here" # Optional, for web search
```
@@ -173,8 +175,12 @@ Get your API keys:
- **Anthropic**: [console.anthropic.com](https://console.anthropic.com/)
- **OpenAI**: [platform.openai.com](https://platform.openai.com/)
- **OpenRouter**: [openrouter.ai/keys](https://openrouter.ai/keys)
- **Hive LLM**: [Hive Discord](https://discord.com/invite/hQdU7QDkgR)
- **Brave Search**: [brave.com/search/api](https://brave.com/search/api/)
Quickstart can configure OpenRouter and Hive LLM for you interactively. See [configuration.md](./configuration.md) for the full configuration examples.
## Testing Your Agent
```bash
@@ -218,6 +224,8 @@ uv pip install -e .
```bash
# Verify API key is set
echo $ANTHROPIC_API_KEY
echo $OPENROUTER_API_KEY
echo $HIVE_API_KEY
```
@@ -232,6 +240,6 @@ pip uninstall -y framework tools
## Getting Help
- **Documentation**: Check the `/docs` folder
- **Issues**: [github.com/adenhq/hive/issues](https://github.com/adenhq/hive/issues)
- **Issues**: [github.com/adenhq/hive/issues](https://github.com/aden-hive/hive/issues)
- **Discord**: [discord.com/invite/MXE49hrKDk](https://discord.com/invite/MXE49hrKDk)
- **Build Agents**: Use the coder-tools workflow to create agents
-110
View File
@@ -1,110 +0,0 @@
# Hive Coder: Meta-Agent Integration Plan
## Problem
The hive_coder agent currently has 7 file I/O tools (`read_file`, `write_file`, `edit_file`, `list_directory`, `search_files`, `run_command`, `undo_changes`) in `tools/coder_tools_server.py`. It can write agent packages but is **not integrated into the Hive ecosystem**:
1. **No dynamic tool discovery** — It references a static list of hive-tools in `reference/framework_guide.md`. It can't discover what MCP tools are actually available or what parameters they accept.
2. **No runtime observability** — It can't inspect sessions, checkpoints, or logs from agents it builds. When something goes wrong, the user has to manually dig through files.
3. **No test execution** — It can't run an agent's test suite structurally (it could use `run_command` with raw pytest, but has no structured test parsing).
## Solution
Add 8 new tools to `coder_tools_server.py` that give hive_coder deep integration with the Hive framework. Update the system prompt to teach the LLM when and how to use these meta-agent capabilities.
---
## New Tools
### 1. Tool Discovery
**`discover_mcp_tools(server_config_path?)`**
Connect to any MCP server and list all available tools with full schemas. Uses `framework.runner.mcp_client.MCPClient` — the same client the runtime uses. Reads a `mcp_servers.json` file (defaults to hive-tools), connects to each server, calls `list_tools()`, returns tool names + descriptions + input schemas, then disconnects.
This replaces the static tools reference. The LLM now discovers tools dynamically before designing an agent.
### 2. Agent Inventory
**`list_agents()`**
Scan `exports/` for agent packages and `~/.hive/agents/` for runtime data. Returns agent names, descriptions (from `__init__.py`), and session counts. Gives the LLM awareness of what already exists.
### 3-7. Session & Checkpoint Inspection
Ported from the former `agent_builder_server.py`. Pure filesystem reads — JSON + pathlib, zero framework imports.
| Tool | Purpose |
|------|---------|
| `list_agent_sessions(agent_name, status?, limit?)` | List sessions, filterable by status |
| `list_agent_checkpoints(agent_name, session_id)` | List checkpoints for debugging |
| `get_agent_checkpoint(agent_name, session_id, checkpoint_id?)` | Load a checkpoint's full state |
**Key difference from the old agent-builder server:** These tools accept `agent_name` (e.g. `"deep_research_agent"`) instead of raw `agent_work_dir` paths. They resolve to `~/.hive/agents/{agent_name}/` internally. Friendlier for the LLM.
### 8. Test Execution
**`run_agent_tests(agent_name, test_types?, fail_fast?)`**
Ported from the former `agent_builder_server.py`. Runs pytest on an agent's test suite, sets PYTHONPATH automatically, parses output into structured results (passed/failed/skipped counts, per-test status, failure details).
---
## Files to Modify
### `tools/coder_tools_server.py` (~400 new lines)
Add all 8 tools after the existing `undo_changes` tool:
```
# ── Meta-agent: Tool discovery ────────────────────────────────
# discover_mcp_tools()
# ── Meta-agent: Agent inventory ───────────────────────────────
# list_agents()
# ── Meta-agent: Session & checkpoint inspection ───────────────
# _resolve_hive_agent_path(), _read_session_json(), _scan_agent_sessions(), _truncate_value()
# list_agent_sessions(), list_agent_checkpoints(), get_agent_checkpoint()
# list_agent_checkpoints(), get_agent_checkpoint()
# ── Meta-agent: Test execution ────────────────────────────────
# run_agent_tests()
```
### `exports/hive_coder/nodes/__init__.py`
- Add 8 new tool names to the `tools` list
- Rewrite system prompt "Tools Available" section with meta-agent tools
- Add "Meta-Agent Capabilities" section teaching:
- Tool discovery before designing agents
- Post-build test execution
- Debugging via session/checkpoint inspection
- Agent awareness via `list_agents()`
### `exports/hive_coder/agent.py`
- Update `identity_prompt` to mention dynamic tool discovery and runtime observability
- Add `dynamic-tool-discovery` constraint to the goal
### `exports/hive_coder/reference/framework_guide.md`
Replace static tools list with a note to use `discover_mcp_tools()` instead.
---
## What's NOT in Scope (deferred to v2)
- **Agent notifications / webhook listener** — Requires always-on listener architecture
- **`compare_agent_checkpoints`** — LLM can compare by reading two checkpoints sequentially
- **Runtime log query tools** — Available in hive-tools MCP; `run_command` can access them now
---
## Verification
1. MCP server starts with all 15 tools (7 existing + 8 new)
2. `discover_mcp_tools()` connects to hive-tools and returns real tool schemas
3. Agent validation passes (`default_agent.validate()`)
4. Session tools work against existing data in `~/.hive/agents/`
5. Smoke test: launch in TUI, ask it to discover tools
-37
View File
@@ -1,37 +0,0 @@
# Local API key credentials lack feature parity with Aden OAuth credentials
## Summary
The credential tester only surfaces accounts synced via Aden OAuth (requires `ADEN_API_KEY`). Users who authenticate services with a direct API key — Brave Search, GitHub, Exa, Google Maps, Stripe, Telegram, and many others — have no way to list, manage, or test those credentials through the same interface.
## Problem
Local API key credentials are completely flat today:
- **No namespace** — one env var per service (`BRAVE_SEARCH_API_KEY`), no aliases, no multi-account support
- **No identity metadata** — no way to record who owns a key (email, username, workspace)
- **No status tracking** — no "active / failed / unknown" state
- **Not visible in credential tester** — the account picker only calls the Aden API; it silently shows nothing if `ADEN_API_KEY` is absent
- **No management surface** — no list/add/delete/validate flow for API keys
Aden credentials have all of this: `integration_id`, alias, identity, status, health-check-on-sync, and a full listing API.
## Affected credentials (local-only by default)
Brave Search, Exa Search, Google Search (CSE), SerpAPI, GitHub, Google Maps, Telegram, Apollo, Stripe, Razorpay, Cal.com, BigQuery, GCP Vision, Resend, and more.
## Expected behavior
- Running the credential tester should surface **all** configured credentials — Aden-synced and local API keys together, in the same account picker
- Local API key accounts should support aliases (`work`, `personal`) so users can store multiple keys per service
- Identity metadata (username, email, workspace) should be extracted automatically via health check when a key is saved
- A status badge (`active` / `failed` / `unknown`) should indicate whether the key was last verified successfully
- The TUI should provide an "Add Local Credential" screen with a live health check
- The MCP `store_credential` / `list_stored_credentials` / `delete_stored_credential` tools should support aliases; a new `validate_credential` tool should allow re-checking a stored key at any time
## Root cause (bonus bug)
Even credentials configured with the existing `store_credential` MCP tool are invisible in the credential tester because:
1. `_list_env_fallback_accounts()` only checked env vars — it missed credentials stored in `EncryptedFileStorage` using the old flat format (`brave_search`, no alias)
2. `_activate_local_account()` early-returned for `alias == "default"`, assuming the env var was already set — but old flat encrypted credentials are not in `os.environ`
-594
View File
@@ -1,594 +0,0 @@
# Server & CLI Architecture: Shared Runtime Primitives
## Executive Summary
The `hive serve` HTTP server and the CLI commands (`hive run`, `hive shell`, `hive tui`) are two access layers built on top of the **same runtime primitives**. There is no separate "server runtime" — the HTTP server is a thin REST/SSE translation layer that delegates every operation to the same `AgentRunner`, `AgentRuntime`, `GraphExecutor`, and storage subsystems that the CLI uses directly.
---
## Architecture Overview
```mermaid
flowchart TB
subgraph Access["Access Layer"]
direction LR
subgraph CLI["CLI Access"]
Run["hive run"]
Shell["hive shell"]
TUI["hive tui"]
end
subgraph HTTP["HTTP Access (hive serve)"]
REST["REST Endpoints<br/>(aiohttp routes)"]
SSE["SSE Event Stream"]
SPA["Frontend SPA"]
end
end
subgraph Bridge["Server Bridge Layer"]
AM["AgentManager<br/>Multi-agent slot lifecycle"]
end
subgraph Core["Shared Runtime Core"]
AR["AgentRunner<br/>Load, validate, run agents"]
ART["AgentRuntime<br/>Multi-entry-point orchestration"]
GE["GraphExecutor<br/>Node execution, edge traversal"]
end
subgraph Storage["Shared Storage"]
SS["SessionStore"]
CS["CheckpointStore"]
RL["RuntimeLogger<br/>L1/L2/L3 logs"]
SM["SharedMemory"]
end
Run --> AR
Shell --> AR
TUI --> AR
REST --> AM
SSE --> AM
AM --> AR
AR --> ART
ART --> GE
GE --> SS
GE --> CS
GE --> RL
GE --> SM
```
### Key Insight
The only component unique to the HTTP server is `AgentManager` — a thin lifecycle wrapper that holds multiple `AgentSlot` instances concurrently. Each slot contains the **exact same objects** the CLI creates:
```python
@dataclass
class AgentSlot:
id: str
agent_path: Path
runner: AgentRunner # Same as CLI
runtime: AgentRuntime # Same as CLI
info: AgentInfo # Same as CLI
loaded_at: float
```
---
## The Shared Runtime Stack
### Layer 1: AgentRunner
The entry point for loading and running any agent, regardless of access mode.
```python
# CLI usage (hive run)
runner = AgentRunner.load("exports/my-agent", model="claude-sonnet-4-6")
result = await runner.run(input_data={"query": "hello"})
# Server usage (identical call inside AgentManager.load_agent)
runner = AgentRunner.load(agent_path, model=model, interactive=False)
```
**Responsibilities:**
- Load agents from `agent.json` or `agent.py`
- Discover tools from `tools.py` and `mcp_servers.json`
- Validate credentials before execution
- Provide `AgentInfo` and `ValidationResult` inspection
### Layer 2: AgentRuntime
The orchestrator for concurrent, multi-entry-point execution.
```python
# Both CLI (TUI/shell) and server use the same runtime
runtime = runner._agent_runtime
await runtime.start()
# Triggering execution — identical call in both modes
exec_id = await runtime.trigger("default", {"query": "hello"})
# Injecting user input — identical call in both modes
await runtime.inject_input(node_id="chat", content="user message")
# Subscribing to events — CLI uses for TUI, server uses for SSE
sub_id = runtime.subscribe_to_events([EventType.CLIENT_OUTPUT_DELTA], handler)
```
### Layer 3: GraphExecutor
Executes the agent graph node-by-node. Completely unaware of whether it was invoked from CLI or HTTP.
**Responsibilities:**
- Node execution following `GraphSpec` edges
- Edge condition evaluation and routing
- `SharedMemory` management across nodes
- Checkpoint creation for resumability
- HITL pause points at `client_facing` nodes
### Layer 4: Storage
All storage subsystems are shared — sessions, checkpoints, and logs written via CLI are readable via the HTTP server and vice versa.
```
~/.hive/agents/{agent_name}/
├── sessions/ # SessionStore
│ └── session_YYYYMMDD_HHMMSS_{uuid}/
│ ├── state.json # Session state
│ ├── conversations/ # Per-node EventLoop state
│ ├── artifacts/ # Large outputs
│ └── logs/ # L1/L2/L3 observability
│ ├── summary.json
│ ├── details.jsonl
│ └── tool_logs.jsonl
├── runtime_logs/ # RuntimeLogger
└── artifacts/ # Fallback storage
```
---
## HTTP Endpoint to Runtime Primitive Mapping
Every HTTP endpoint is a direct, thin delegation to a shared runtime method. No execution logic lives in the route handlers.
### Agent Lifecycle
| HTTP Endpoint | Method | Runtime Primitive |
|---|---|---|
| `POST /api/agents` | Load agent | `AgentRunner.load()``runtime.start()` |
| `DELETE /api/agents/{id}` | Unload agent | `runner.cleanup_async()` |
| `GET /api/agents/{id}` | Agent info | `runner.info()``AgentInfo` |
| `GET /api/agents/{id}/stats` | Statistics | Runtime metrics collection |
| `GET /api/agents/{id}/entry-points` | Entry points | `runtime.get_entry_points()` |
| `GET /api/agents/{id}/graphs` | List graphs | `runtime.list_graphs()` |
| `GET /api/discover` | Discover agents | Filesystem scan (same as `hive list`) |
### Execution Control
| HTTP Endpoint | Method | Runtime Primitive |
|---|---|---|
| `POST /api/agents/{id}/trigger` | Start execution | `runtime.trigger(entry_point_id, input_data)` |
| `POST /api/agents/{id}/chat` | Auto-route | `runtime.inject_input()` or `runtime.trigger()` |
| `POST /api/agents/{id}/inject` | Send user input | `runtime.inject_input(node_id, content)` |
| `POST /api/agents/{id}/resume` | Resume session | `runtime.trigger()` with `session_state` |
| `POST /api/agents/{id}/stop` | Pause execution | Cancels the execution task |
| `POST /api/agents/{id}/replay` | Replay checkpoint | Checkpoint restore → `runtime.trigger()` |
| `GET /api/agents/{id}/goal-progress` | Goal progress | `runtime.get_goal_progress()` |
### Event Streaming
| HTTP Endpoint | Method | Runtime Primitive |
|---|---|---|
| `GET /api/agents/{id}/events` | SSE stream | `runtime.subscribe_to_events()` |
Default event types streamed: `CLIENT_OUTPUT_DELTA`, `CLIENT_INPUT_REQUESTED`, `LLM_TEXT_DELTA`, `TOOL_CALL_STARTED`, `TOOL_CALL_COMPLETED`, `EXECUTION_STARTED`, `EXECUTION_COMPLETED`, `EXECUTION_FAILED`, `EXECUTION_PAUSED`, `NODE_LOOP_STARTED`, `NODE_LOOP_COMPLETED`, `EDGE_TRAVERSED`, `GOAL_PROGRESS`.
### Session Management
| HTTP Endpoint | Method | Runtime Primitive |
|---|---|---|
| `GET /api/agents/{id}/sessions` | List sessions | `SessionStore.list_sessions()` |
| `GET /api/agents/{id}/sessions/{sid}` | Session details | `SessionStore.read_state()` |
| `DELETE /api/agents/{id}/sessions/{sid}` | Delete session | `SessionStore.delete_session()` |
| `GET /api/agents/{id}/sessions/{sid}/checkpoints` | List checkpoints | `CheckpointStore.list_checkpoints()` |
| `POST /api/agents/{id}/sessions/{sid}/checkpoints/{cid}/restore` | Restore checkpoint | Checkpoint load → `runtime.trigger()` |
| `GET /api/agents/{id}/sessions/{sid}/messages` | Chat history | `ConversationStore` reads |
### Graph Inspection
| HTTP Endpoint | Method | Runtime Primitive |
|---|---|---|
| `GET /api/agents/{id}/graphs/{gid}/nodes` | List nodes | `GraphSpec` inspection |
| `GET /api/agents/{id}/graphs/{gid}/nodes/{nid}` | Node details | `GraphSpec` node lookup |
| `GET /api/agents/{id}/graphs/{gid}/nodes/{nid}/criteria` | Success criteria | Node criteria + judge verdicts |
### Logging
| HTTP Endpoint | Method | Runtime Primitive |
|---|---|---|
| `GET /api/agents/{id}/logs` | Agent logs | `RuntimeLogger` queries |
| `GET /api/agents/{id}/graphs/{gid}/nodes/{nid}/logs` | Node logs | `RuntimeLogger` node-scoped queries |
---
## What Differs Between CLI and HTTP
The differences are in the **access pattern**, not the runtime behavior.
| Concern | CLI | HTTP Server |
|---|---|---|
| **Multi-agent** | One runner per process | `AgentManager` holds N slots concurrently |
| **User input** | stdin (shell) / TUI widget | `POST /inject` or `POST /chat` |
| **Event streaming** | `subscribe_to_events()` → TUI update | Same subscription → SSE stream |
| **HITL approval** | `set_approval_callback()` + stdin | `CLIENT_INPUT_REQUESTED` event → `/inject` |
| **Agent lifecycle** | Process start → run → exit | Dynamic load/unload via REST calls |
| **Concurrency** | Sequential (one run at a time) | Async — multiple triggers, multiple agents |
| **Agent discovery** | `hive list` scans dirs | `GET /api/discover` scans dirs (same logic) |
| **Frontend** | Terminal / Textual TUI | React SPA served from `frontend/dist/` |
---
## The AgentManager Bridge
The only component unique to the HTTP server. It manages the lifecycle of multiple loaded agents within a single process.
```mermaid
flowchart LR
subgraph AgentManager
S1["Slot: support-agent<br/>runner + runtime + info"]
S2["Slot: research-agent<br/>runner + runtime + info"]
S3["Slot: code-agent<br/>runner + runtime + info"]
end
Load["POST /api/agents"] -->|"load_agent()"| AgentManager
Unload["DELETE /api/agents/{id}"] -->|"unload_agent()"| AgentManager
List["GET /api/agents"] -->|"list_agents()"| AgentManager
Get["GET /api/agents/{id}"] -->|"get_agent()"| AgentManager
Shutdown["Server shutdown"] -->|"shutdown_all()"| AgentManager
```
**Key design choices:**
- **Thread-safe** via `asyncio.Lock` — no race conditions during load/unload
- **Blocking I/O offloaded**`AgentRunner.load()` runs in `run_in_executor` to avoid blocking the event loop
- **Same pattern as TUI** — the comment in source explicitly notes this: `# Blocking I/O — load in executor (same as tui/app.py:362-368)`
---
## How the `/chat` Endpoint Auto-Routes
The `/chat` endpoint demonstrates the thin-wrapper pattern. It checks runtime state and delegates:
```
POST /api/agents/{id}/chat { "message": "hello" }
Is any node waiting for input?
│ │
YES NO
│ │
▼ ▼
runtime.inject_input() runtime.trigger()
│ │
▼ ▼
{ "status": "injected", { "status": "started",
"node_id": "..." } "execution_id": "..." }
```
This is the same decision a human makes in the shell — if the agent is waiting for input, provide it; otherwise start a new execution.
---
## Concurrent Judge & Queen: Multi-Graph Monitoring Primitives
The Worker Health Judge and Queen triage system introduce **secondary graphs** that run alongside a primary worker graph within the same `AgentRuntime`. They share the runtime's `EventBus` but have fully isolated storage. This section documents the new runtime primitives, EventBus events, data models, and storage layout they introduce.
### Architecture
```
One AgentRuntime (shared EventBus)
|
+-- Worker Graph (primary) trigger_type: manual
| Entry point: "start" -> worker node (event_loop, client_facing)
|
+-- Health Judge Graph (secondary) trigger_type: timer (2 min)
| Entry point: "health_check" -> judge node (event_loop, autonomous)
| isolation_level: isolated
| conversation_mode: continuous
|
+-- Queen Graph (secondary) trigger_type: event (worker_escalation_ticket)
Entry point: "ticket_receiver" -> ticket_triage node (event_loop)
isolation_level: isolated
```
### GraphScopedEventBus and Event Identity Fields
Every event carries four identity fields: `(graph_id, stream_id, node_id, execution_id)`.
- **`graph_id`** — Set automatically by `GraphScopedEventBus`, a public subclass of `EventBus` that stamps `graph_id` on every `publish()` call. All three components (worker, judge, queen) use a scoped bus so their events are distinguishable.
- **`stream_id`** — The entry point pipeline. Flows from `EntryPointSpec.id` through `ExecutionStream``GraphExecutor``NodeContext``EventLoopNode`.
- **`node_id`** — The graph node emitting the event.
- **`execution_id`** — UUID for a specific execution run, set by `ExecutionStream` and wired through `GraphExecutor``EventLoopNode` → all `emit_*` calls.
See [EVENT_TYPES.md](../core/framework/runtime/EVENT_TYPES.md) for the complete event type and schema reference.
### New EventBus Event Types
Two new events added to `EventType` enum:
#### `WORKER_ESCALATION_TICKET`
Emitted by the health judge's `emit_escalation_ticket` tool when the judge detects a sustained degradation pattern in the worker.
```python
EventBus.emit_worker_escalation_ticket(
stream_id: str, # Judge's stream ID (e.g. "judge")
node_id: str, # "judge"
ticket: dict, # Full EscalationTicket (see data model below)
execution_id: str | None = None,
)
```
**Event data payload:**
```json
{
"ticket": {
"ticket_id": "uuid",
"created_at": "ISO8601",
"worker_agent_id": "deep_research_agent",
"worker_session_id": "session_20260223_130234_a30df858",
"worker_node_id": "researcher",
"worker_graph_id": "deep-research-graph",
"severity": "high",
"cause": "Node has produced 18 consecutive RETRY verdicts with identical error...",
"judge_reasoning": "Steps unchanged across last 3 checks. Evidence shows repeated...",
"suggested_action": "Restart node with modified prompt or escalate to human.",
"recent_verdicts": ["RETRY", "RETRY", "RETRY", "CONTINUE", "RETRY"],
"total_steps_checked": 47,
"steps_since_last_accept": 18,
"stall_minutes": null,
"evidence_snippet": "Error: API returned 429 rate limit exceeded..."
}
}
```
**Subscribers:** Queen's `ticket_receiver` entry point (event-driven). HTTP server should forward via SSE for dashboard UIs.
#### `QUEEN_INTERVENTION_REQUESTED`
Emitted by the queen's `notify_operator` tool after triaging an escalation ticket and deciding the human operator should be involved.
```python
EventBus.emit_queen_intervention_requested(
stream_id: str, # Queen's stream ID
node_id: str, # "ticket_triage"
ticket_id: str, # References the original EscalationTicket
analysis: str, # Queen's 2-3 sentence analysis
severity: str, # "low" | "medium" | "high" | "critical"
queen_graph_id: str, # "queen"
queen_stream_id: str, # "queen"
execution_id: str | None = None,
)
```
**Event data payload:**
```json
{
"ticket_id": "uuid",
"analysis": "Worker is stuck in a rate-limit retry loop for 6+ minutes. Suggest pausing and retrying with backoff.",
"severity": "high",
"queen_graph_id": "queen",
"queen_stream_id": "queen"
}
```
**Subscribers:** TUI (shows non-disruptive overlay). HTTP server should forward via SSE.
### New Data Model: EscalationTicket
```python
# core/framework/runtime/escalation_ticket.py
class EscalationTicket(BaseModel):
ticket_id: str # Auto-generated UUID
created_at: str # Auto-generated ISO8601
# Worker identification
worker_agent_id: str # Agent name (e.g. "deep_research_agent")
worker_session_id: str # Session being monitored
worker_node_id: str # Primary graph's entry node
worker_graph_id: str # Primary graph ID
# Problem characterization (LLM-generated by judge)
severity: Literal["low", "medium", "high", "critical"]
cause: str # What the judge observed
judge_reasoning: str # Why the judge decided to escalate
suggested_action: str # Recommended intervention
# Evidence
recent_verdicts: list[str] # Last N verdicts (ACCEPT/RETRY/CONTINUE/ESCALATE)
total_steps_checked: int # Total log steps seen
steps_since_last_accept: int
stall_minutes: float | None # Wall-clock since last step (None if active)
evidence_snippet: str # Truncated recent LLM output
```
### Modified AgentRuntime APIs
The following existing methods gained a `graph_id` parameter to support multi-graph routing. When `graph_id=None` (default), the method targets the **active graph** (`active_graph_id`), falling back to the primary graph. Existing callers that pass no `graph_id` are unaffected.
| Method | New parameter | Notes |
|---|---|---|
| `trigger()` | `graph_id: str \| None = None` | Routes to the named graph's stream |
| `get_entry_points()` | `graph_id: str \| None = None` | Returns entry points for the specified graph |
| `get_stream()` | `graph_id: str \| None = None` | Resolves stream via active graph first |
| `get_execution_result()` | `graph_id: str \| None = None` | Looks up result in the graph's stream |
| `cancel_execution()` | `graph_id: str \| None = None` | Cancels execution in the graph's stream |
### New AgentRuntime APIs
| Method | Signature | Description |
|---|---|---|
| `get_active_graph()` | `-> GraphSpec` | Returns the `GraphSpec` for the currently active graph (used by TUI/chat routing) |
| `active_graph_id` (property) | `str` (get/set) | The graph that receives user input. Set by TUI when switching between worker and queen views |
| `get_active_streams()` | `-> list[dict]` | Returns metadata for every stream with active executions across all graphs. Each dict contains `graph_id`, `stream_id`, `entry_point_id`, `active_execution_ids`, `is_awaiting_input`, `waiting_nodes`. |
| `get_waiting_nodes()` | `-> list[dict]` | Flat list of all nodes currently blocked waiting for client input across all graphs/streams. Each dict contains `graph_id`, `stream_id`, `node_id`, `execution_id`. |
### New ExecutionStream APIs
| Method | Signature | Description |
|---|---|---|
| `get_waiting_nodes()` | `-> list[dict]` | Returns `[{"node_id": str, "execution_id": str}]` for every `EventLoopNode` with `_awaiting_input == True`. |
| `get_injectable_nodes()` | `-> list[dict]` | Returns `[{"node_id": str, "execution_id": str}]` for every node that supports message injection (has `inject_event` method). |
### Proposed HTTP Endpoints
These endpoints are not yet implemented. They expose the new multi-graph and monitoring primitives to the HTTP access layer, following the same thin-delegation pattern as existing endpoints.
#### Multi-Graph Control
| HTTP Endpoint | Method | Runtime Primitive |
|---|---|---|
| `POST /api/agents/{id}/graphs` | Load secondary graph | `runtime.add_graph(graph_id, graph, goal, entry_points)` |
| `DELETE /api/agents/{id}/graphs/{gid}` | Unload secondary graph | `runtime.remove_graph(graph_id)` (not yet implemented) |
| `GET /api/agents/{id}/graphs/{gid}/sessions` | List graph sessions | Graph-specific `SessionStore.list_sessions()` |
| `GET /api/agents/{id}/graphs/{gid}/sessions/{sid}` | Graph session details | Graph-specific `SessionStore.read_state()` |
| `PUT /api/agents/{id}/active-graph` | Switch active graph | `runtime.active_graph_id = graph_id` |
| `GET /api/agents/{id}/active-graph` | Get active graph | `runtime.active_graph_id` |
#### Stream Introspection
| HTTP Endpoint | Method | Runtime Primitive |
|---|---|---|
| `GET /api/agents/{id}/streams` | Active streams | `runtime.get_active_streams()` — all streams with active executions |
| `GET /api/agents/{id}/waiting-nodes` | Waiting nodes | `runtime.get_waiting_nodes()` — all nodes blocked on client input |
#### Worker Health Monitoring
| HTTP Endpoint | Method | Runtime Primitive |
|---|---|---|
| `GET /api/agents/{id}/health` | Health summary | Calls `get_worker_health_summary()` tool (reads worker session logs) |
| `GET /api/agents/{id}/escalations` | List escalation tickets | Query `WORKER_ESCALATION_TICKET` events from EventBus history |
| `GET /api/agents/{id}/escalations/{tid}` | Ticket details | Lookup specific ticket by `ticket_id` |
#### Event Streaming Additions
The SSE stream (`GET /api/agents/{id}/events`) should include the two new event types in its default set:
```
Default event types: ..., WORKER_ESCALATION_TICKET, QUEEN_INTERVENTION_REQUESTED
```
Clients can subscribe selectively:
```
GET /api/agents/{id}/events?types=worker_escalation_ticket,queen_intervention_requested
```
### Isolated Session Lifecycle for Secondary Graphs
Isolated entry points (`isolation_level="isolated"`) use **persistent sessions** — a single session is created on first trigger and reused for all subsequent triggers of the same entry point. This is critical for:
- **Timer-driven** entry points (health judge): one session across all timer ticks, so `conversation_mode="continuous"` works and the judge accumulates observations in its conversation history.
- **Event-driven** entry points (queen ticket receiver): one session across all received events, so the queen can reference prior triage decisions.
The session reuse is managed by the timer/event handler closures in `AgentRuntime`, which remember the first `execution_id` returned by `stream.execute()` and pass it as `resume_session_id` on all subsequent fires. The `GraphExecutor` detects the existing conversation store, resets the cursor (clearing stale outputs), and appends a transition marker so the LLM knows a new trigger arrived while the conversation thread carries forward.
### Secondary Graph Storage Layout
Secondary graphs have fully isolated storage under `graphs/{graph_id}/` to prevent any interference with the primary worker's sessions, logs, and conversations.
```
~/.hive/agents/{agent_name}/
+-- sessions/ # Primary graph only
| +-- session_YYYYMMDD_HHMMSS_{uuid}/
| +-- state.json
| +-- conversations/
| +-- logs/
+-- graphs/
| +-- judge/ # Health judge (secondary)
| | +-- sessions/
| | | +-- session_YYYYMMDD_HHMMSS_{uuid}/ # ONE persistent session
| | | +-- state.json
| | | +-- conversations/judge/ # Continuous conversation
| | | +-- logs/
| | | +-- tool_logs.jsonl
| | | +-- details.jsonl
| | +-- runtime_logs/
| +-- queen/ # Queen triage (secondary)
| +-- sessions/
| | +-- session_YYYYMMDD_HHMMSS_{uuid}/ # ONE persistent session
| | +-- state.json
| | +-- conversations/ticket_triage/
| | +-- logs/
| +-- runtime_logs/
+-- runtime_logs/ # Primary graph runtime logs
```
Each secondary graph gets its own `SessionStore` and `RuntimeLogStore` scoped to `graphs/{graph_id}/`. This is set up in `AgentRuntime.add_graph()`:
```python
graph_base = self._session_store.base_path / subpath # e.g. .../graphs/judge
graph_session_store = SessionStore(graph_base)
graph_log_store = RuntimeLogStore(graph_base / "runtime_logs")
```
### Worker Monitoring Tools
Three tools registered via `register_worker_monitoring_tools(registry, event_bus, storage_path)`. These are bound to the worker's EventBus and storage path at registration time.
| Tool | Used by | Description |
|---|---|---|
| `get_worker_health_summary(session_id?, last_n_steps?)` | Health Judge | Reads worker's `sessions/{id}/logs/tool_logs.jsonl`. Auto-discovers active session if `session_id` omitted. Returns JSON with `worker_agent_id`, `worker_graph_id`, `session_id`, `total_steps`, `recent_verdicts`, `steps_since_last_accept`, `stall_minutes`, `evidence_snippet`. |
| `emit_escalation_ticket(ticket_json)` | Health Judge | Validates JSON against `EscalationTicket` schema (Pydantic rejects partial tickets), then calls `EventBus.emit_worker_escalation_ticket()`. |
| `notify_operator(ticket_id, analysis, urgency)` | Queen | Calls `EventBus.emit_queen_intervention_requested()` so the TUI/frontend surfaces a notification. |
### Queen Lifecycle Tools
Four tools registered via `register_queen_lifecycle_tools(registry, worker_runtime, event_bus)`. These close over the worker's `AgentRuntime` to give the Queen control over the worker agent's lifecycle.
| Tool | Description |
|---|---|
| `start_worker(task)` | Trigger the worker's default entry point with a task description. Returns an `execution_id`. |
| `stop_worker()` | Cancel all active worker executions. Returns IDs of cancelled executions. |
| `get_worker_status()` | Check if the worker is idle, running, or waiting for input. Returns execution details and waiting node ID if applicable. Uses `stream.get_waiting_nodes()` for accurate detection. |
| `inject_worker_message(content)` | Send a message to the running worker agent by finding an injectable node via `stream.get_injectable_nodes()` and calling `stream.inject_input()`. |
### New File Reference
| Component | Path |
|---|---|
| EscalationTicket model | `core/framework/runtime/escalation_ticket.py` |
| Worker Health Judge graph | `core/framework/monitoring/judge.py` |
| Worker monitoring tools | `core/framework/tools/worker_monitoring_tools.py` |
| Queen lifecycle tools | `core/framework/tools/queen_lifecycle_tools.py` |
| Monitoring package init | `core/framework/monitoring/__init__.py` |
| Event types reference | `core/framework/runtime/EVENT_TYPES.md` |
---
## File Reference
| Component | Path |
|---|---|
| CLI entry point | `core/framework/runner/cli.py` |
| HTTP app factory | `core/framework/server/app.py` |
| Agent manager | `core/framework/server/agent_manager.py` |
| Agent routes | `core/framework/server/routes_agents.py` |
| Execution routes | `core/framework/server/routes_execution.py` |
| Event routes | `core/framework/server/routes_events.py` |
| Session routes | `core/framework/server/routes_sessions.py` |
| Graph routes | `core/framework/server/routes_graphs.py` |
| Log routes | `core/framework/server/routes_logs.py` |
| SSE helper | `core/framework/server/sse.py` |
| AgentRunner | `core/framework/runner/runner.py` |
| AgentRuntime | `core/framework/runtime/agent_runtime.py` |
| GraphExecutor | `core/framework/graph/executor.py` |
| SessionStore | `core/framework/storage/session_store.py` |
| CheckpointStore | `core/framework/storage/checkpoint_store.py` |
| Runtime logger | `core/framework/runtime/core.py` |
| EventBus | `core/framework/runtime/event_bus.py` |
| ExecutionStream | `core/framework/runtime/execution_stream.py` |
| GraphScopedEventBus | `core/framework/runtime/execution_stream.py` |
| EscalationTicket | `core/framework/runtime/escalation_ticket.py` |
| Queen lifecycle tools | `core/framework/tools/queen_lifecycle_tools.py` |
| Worker monitoring tools | `core/framework/tools/worker_monitoring_tools.py` |
| Health Judge graph | `core/framework/monitoring/judge.py` |
| Event types reference | `core/framework/runtime/EVENT_TYPES.md` |
+290
View File
@@ -0,0 +1,290 @@
# Agent Skills User Guide
This guide covers how to use, create, and manage Agent Skills in the Hive framework. Agent Skills follow the open [Agent Skills standard](https://agentskills.io) — skills written for Claude Code, Cursor, or other compatible agents work in Hive unchanged.
## What are skills?
Skills are folders containing a `SKILL.md` file that teaches an agent how to perform a specific task. They can also bundle scripts, templates, and reference materials. Skills are loaded on demand — the agent sees a lightweight catalog at startup and pulls in full instructions only when relevant.
## Quick start
### Install a skill
Drop a skill folder into one of the discovery directories:
```bash
# Project-level (shared with the repo)
mkdir -p .hive/skills/my-skill
cat > .hive/skills/my-skill/SKILL.md << 'EOF'
---
name: my-skill
description: Does X when the user asks about Y.
---
# My Skill
Step-by-step instructions for the agent...
EOF
```
The agent will discover it automatically on the next session.
### List discovered skills
```bash
hive skill list
```
Output groups skills by scope:
```
PROJECT SKILLS
────────────────────────────────────
• my-skill
Does X when the user asks about Y.
/home/user/project/.hive/skills/my-skill/SKILL.md
USER SKILLS
────────────────────────────────────
• deep-research
Multi-step web research with source verification.
/home/user/.hive/skills/deep-research/SKILL.md
```
## Where to put skills
Hive scans five directories at startup, in this precedence order:
| Scope | Path | Use case |
|-------|------|----------|
| Project (Hive) | `<project>/.hive/skills/` | Skills specific to this repo |
| Project (cross-client) | `<project>/.agents/skills/` | Skills shared across Claude Code, Cursor, etc. |
| User (Hive) | `~/.hive/skills/` | Personal skills available in all projects |
| User (cross-client) | `~/.agents/skills/` | Personal cross-client skills |
| Framework | *(built-in)* | Default operational skills shipped with Hive |
**Precedence**: If two skills share the same name, the higher-precedence location wins. A project-level `code-review` skill overrides a user-level one with the same name.
**Cross-client paths**: The `.agents/skills/` directories are a convention shared across compatible agents. A skill installed at `~/.agents/skills/pdf-processing/` is visible to Hive, Claude Code, Cursor, and other compatible tools simultaneously.
## Creating a skill
### Directory structure
```
my-skill/
├── SKILL.md # Required — metadata + instructions
├── scripts/ # Optional — executable code
│ └── run.py
├── references/ # Optional — supplementary docs
│ └── api-reference.md
└── assets/ # Optional — templates, data files
└── template.json
```
### SKILL.md format
Every skill needs a `SKILL.md` with YAML frontmatter and a markdown body:
```markdown
---
name: my-skill
description: Extract and summarize PDF documents. Use when the user mentions PDFs or document extraction.
---
# PDF Processing
## When to use
Use this skill when the user needs to extract text from PDFs or merge documents.
## Steps
1. Check if pdfplumber is available...
2. Extract text using...
## Edge cases
- Scanned PDFs need OCR first...
```
### Frontmatter fields
| Field | Required | Description |
|-------|----------|-------------|
| `name` | Yes | Lowercase letters, numbers, hyphens. Must match the parent directory name. Max 64 chars. |
| `description` | Yes | What the skill does and when to use it. Max 1024 chars. Include keywords that help the agent match tasks. |
| `license` | No | License name or reference to a bundled LICENSE file. |
| `compatibility` | No | Environment requirements (e.g., "Requires git, docker"). |
| `metadata` | No | Arbitrary key-value pairs (author, version, etc.). |
| `allowed-tools` | No | Space-delimited list of pre-approved tools. |
### Writing good descriptions
The description is critical — it's what the agent uses to decide whether to activate a skill. Be specific:
```yaml
# Good — tells the agent what and when
description: Extract text and tables from PDF files, fill PDF forms, and merge multiple PDFs. Use when working with PDF documents or when the user mentions PDFs, forms, or document extraction.
# Bad — too vague for the agent to match
description: Helps with PDFs.
```
### Writing good instructions
The markdown body is loaded into the agent's context when the skill is activated. Tips:
- **Be procedural**: Step-by-step instructions work better than abstract descriptions.
- **Keep it focused**: Stay under 500 lines / 5000 tokens. Move detailed reference material to `references/`.
- **Use relative paths**: Reference bundled files with relative paths (`scripts/run.py`, `references/guide.md`).
- **Include examples**: Show sample inputs and expected outputs.
- **Cover edge cases**: Tell the agent what to do when things go wrong.
## How skills are activated
Skills use **progressive disclosure** — three tiers that keep context usage efficient:
### Tier 1: Catalog (always loaded)
At session start, the agent sees a compact catalog of all available skills (name + description only, ~50-100 tokens each). This is how it knows what skills exist.
### Tier 2: Instructions (on demand)
When the agent determines a skill is relevant to the current task, it reads the full `SKILL.md` body into context. This happens automatically — the agent matches the task against skill descriptions and activates the best fit.
### Tier 3: Resources (on demand)
When skill instructions reference supporting files (`scripts/extract.py`, `references/api-docs.md`), the agent reads those individually as needed.
### Pre-activated skills
Some agents are configured to load specific skills at session start (skipping the catalog phase). This is set in the agent's configuration:
```python
# In agent definition
skills = ["code-review", "deep-research"]
```
Pre-activated skills have their full instructions loaded from the start, without waiting for the agent to decide they're relevant.
## Trust and security
### Why trust gating exists
Project-level skills come from the repository being worked on. If you clone an untrusted repo that contains a `.hive/skills/` directory, those skills could inject instructions into the agent's system prompt. Trust gating prevents this.
**User-level and framework skills are always trusted.** Only project-scope skills go through trust gating.
### What happens with untrusted project skills
When Hive encounters project-level skills from a repo you haven't trusted before, it shows a consent prompt:
```
============================================================
SKILL TRUST REQUIRED
============================================================
The project at /home/user/new-project wants to load 2 skill(s)
that will inject instructions into the agent's system prompt.
Source: github.com/org/new-project
Skills requesting access:
• deploy-pipeline
"Automated deployment workflow for this project."
/home/user/new-project/.hive/skills/deploy-pipeline/SKILL.md
• code-standards
"Project-specific coding standards and review checklist."
/home/user/new-project/.hive/skills/code-standards/SKILL.md
Options:
1) Trust this session only
2) Trust permanently — remember for future runs
3) Deny — skip all project-scope skills from this repo
────────────────────────────────────────────────────────────
Select option (1-3):
```
### Trust a repo via CLI
To trust a repo permanently without the interactive prompt:
```bash
hive skill trust /path/to/project
```
This stores the trust decision in `~/.hive/trusted_repos.json`, keyed by the normalized git remote URL (e.g., `github.com/org/repo`).
### Automatic trust
Some repos are trusted automatically:
- **No git repo**: Directories without `.git/` are always trusted.
- **No remote**: Local-only git repos (no `origin` remote) are always trusted.
- **Localhost remotes**: Repos with `localhost`/`127.0.0.1` remotes are always trusted.
- **Own-remote patterns**: Repos matching patterns in `~/.hive/own_remotes` or the `HIVE_OWN_REMOTES` env var are always trusted.
### Configure own-remote patterns
If you trust all repos from your organization:
```bash
# Via file (one pattern per line)
echo "github.com/my-org/*" >> ~/.hive/own_remotes
echo "gitlab.com/my-team/*" >> ~/.hive/own_remotes
# Via environment variable (comma-separated)
export HIVE_OWN_REMOTES="github.com/my-org/*,github.com/my-corp/*"
```
### CI / headless environments
In non-interactive environments, untrusted project skills are silently skipped. To trust them explicitly:
```bash
export HIVE_TRUST_PROJECT_SKILLS=1
hive run my-agent
```
## Default skills
Hive ships with six built-in operational skills that provide runtime resilience. These are always loaded (unless disabled) and appear as "Operational Protocols" in the agent's system prompt.
| Skill | Purpose |
|-------|---------|
| `hive.note-taking` | Structured working notes in shared memory |
| `hive.batch-ledger` | Track per-item status in batch operations |
| `hive.context-preservation` | Save context before context window pruning |
| `hive.quality-monitor` | Self-assess output quality periodically |
| `hive.error-recovery` | Structured error classification and recovery |
| `hive.task-decomposition` | Break complex tasks into subtasks |
### Disable default skills
In your agent configuration:
```python
# Disable a specific default skill
default_skills = {
"hive.quality-monitor": {"enabled": False},
}
# Disable all default skills
default_skills = {
"_all": {"enabled": False},
}
```
## Environment variables
| Variable | Description |
|----------|-------------|
| `HIVE_TRUST_PROJECT_SKILLS=1` | Bypass trust gating for all project-level skills (CI override) |
| `HIVE_OWN_REMOTES` | Comma-separated glob patterns for auto-trusted remotes (e.g., `github.com/myorg/*`) |
## Compatibility with other agents
Skills written for any Agent Skills-compatible agent work in Hive:
- Place them in `.agents/skills/` (cross-client) or `.hive/skills/` (Hive-specific).
- The `SKILL.md` format is identical across Claude Code, Cursor, Gemini CLI, and others.
- Skills installed at `~/.agents/skills/` are visible to all compatible agents on your machine.
See the [Agent Skills specification](https://agentskills.io/specification) for the full format reference.

Some files were not shown because too many files have changed in this diff Show More